abedir committed
Commit 7de41d7 · verified · Parent: ce699d1

Upload 4 files

Files changed (4):
1. Dockerfile +35 -0
2. README.md +201 -5
3. app.py +451 -0
4. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,35 @@

```dockerfile
FROM python:3.10-slim

# Set working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    libsndfile1 \
    ffmpeg \
    git \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY app.py .

# Copy model directory
COPY model/ ./model/

# Expose port
EXPOSE 7860

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV GRADIO_SERVER_NAME=0.0.0.0
ENV GRADIO_SERVER_PORT=7860

# Run the application
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
```
README.md CHANGED
@@ -1,10 +1,206 @@

```diff
 ---
-title: Emotion Detector Api
-emoji: 🐠
-colorFrom: pink
+title: Emotion Detector API
+emoji: 🎧
+colorFrom: blue
 colorTo: purple
 sdk: docker
-pinned: false
+app_port: 7860
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
```

# 🎧 Emotion Detector API

Professional RESTful API for emotion recognition in speech using the fine-tuned HuBERT model: **abedir/emotion-detector**

## 🚀 Quick Start

### Health Check
```bash
curl https://YOUR-SPACE-NAME.hf.space/health
```

### Predict Emotion
```bash
curl -X POST "https://YOUR-SPACE-NAME.hf.space/predict" \
  -F "file=@audio.wav"
```

### Python Example
```python
import requests

# Predict emotion
url = "https://YOUR-SPACE-NAME.hf.space/predict"
files = {"file": open("audio.wav", "rb")}
response = requests.post(url, files=files)
result = response.json()

print(f"Emotion: {result['emotion']}")
print(f"Confidence: {result['confidence']:.2%}")
```

## 🎯 Supported Emotions

1. **Angry/Fearful** - Expressions of anger or fear
2. **Happy/Laugh** - Joyful or laughing expressions
3. **Neutral/Calm** - Neutral or calm speech
4. **Sad/Cry** - Expressions of sadness or crying
5. **Surprised/Amazed** - Surprised or amazed reactions

## 📡 API Endpoints

### Core Endpoints
- `GET /` - API welcome and version info
- `GET /health` - Health check with system status
- `GET /docs` - **Interactive API documentation (Swagger UI)**
- `GET /redoc` - Alternative API documentation
- `GET /model/info` - Model configuration details
- `GET /emotions` - List of supported emotions
- `GET /stats` - API and system statistics
- `GET /version` - API version information

### Prediction Endpoints
- `POST /predict` - Basic emotion prediction
- `POST /predict/detailed` - Prediction with audio metadata
- `POST /predict/base64` - Predict from base64-encoded audio
- `POST /predict/batch` - Batch processing (max 50 files)
- `POST /predict/top-k` - Get top-K predictions
- `POST /predict/threshold` - Confidence-based prediction

### Analysis Endpoints
- `POST /analyze/audio` - Get audio metadata without prediction

## 📦 Response Format

```json
{
  "emotion": "Happy/Laugh",
  "confidence": 0.8745,
  "probabilities": {
    "Angry/Fearful": 0.0234,
    "Happy/Laugh": 0.8745,
    "Neutral/Calm": 0.0521,
    "Sad/Cry": 0.0178,
    "Surprised/Amazed": 0.0322
  }
}
```
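The `probabilities` map makes it easy to rank emotions client-side. A small sketch, using the sample response above as a literal dictionary:

```python
# Sample response body, copied from the format shown above
result = {
    "emotion": "Happy/Laugh",
    "confidence": 0.8745,
    "probabilities": {
        "Angry/Fearful": 0.0234,
        "Happy/Laugh": 0.8745,
        "Neutral/Calm": 0.0521,
        "Sad/Cry": 0.0178,
        "Surprised/Amazed": 0.0322,
    },
}

# Sort the probability map into a best-first list of (emotion, score) pairs
ranked = sorted(result["probabilities"].items(), key=lambda kv: kv[1], reverse=True)
print(ranked[0])  # the top-ranked emotion matches result["emotion"]
```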

## 🛠️ Integration Examples

### cURL
```bash
# Basic prediction
curl -X POST "https://YOUR-SPACE-NAME.hf.space/predict" \
  -F "file=@audio.wav"

# Detailed prediction
curl -X POST "https://YOUR-SPACE-NAME.hf.space/predict/detailed" \
  -F "file=@audio.wav"

# Top 3 predictions
curl -X POST "https://YOUR-SPACE-NAME.hf.space/predict/top-k?k=3" \
  -F "file=@audio.wav"

# Batch prediction
curl -X POST "https://YOUR-SPACE-NAME.hf.space/predict/batch" \
  -F "files=@audio1.wav" \
  -F "files=@audio2.wav" \
  -F "files=@audio3.wav"
```

### Python
```python
import requests

BASE_URL = "https://YOUR-SPACE-NAME.hf.space"

# Basic prediction
with open("audio.wav", "rb") as f:
    response = requests.post(f"{BASE_URL}/predict", files={"file": f})
result = response.json()
print(f"Emotion: {result['emotion']}")
print(f"Confidence: {result['confidence']:.2%}")

# Batch prediction
files = [
    ("files", open("audio1.wav", "rb")),
    ("files", open("audio2.wav", "rb")),
    ("files", open("audio3.wav", "rb")),
]
response = requests.post(f"{BASE_URL}/predict/batch", files=files)
results = response.json()
print(f"Processed {results['total_files']} files in {results['processing_time_seconds']:.2f}s")
```
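Unlike the multipart endpoints, `/predict/base64` takes a JSON body. A hedged sketch of building that payload (the helper name is ours; the field names match the `Base64PredictionRequest` model in `app.py`):

```python
import base64

def build_base64_payload(audio_bytes: bytes, filename: str = "audio.wav") -> dict:
    """Build the JSON body expected by POST /predict/base64."""
    return {
        "audio_base64": base64.b64encode(audio_bytes).decode("ascii"),
        "filename": filename,
    }

# In practice the bytes come from a real file, e.g. open("audio.wav", "rb").read()
payload = build_base64_payload(b"\x00\x01\x02\x03", "clip.wav")
# Send with: requests.post(f"{BASE_URL}/predict/base64", json=payload)
```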

### JavaScript
```javascript
// Using the Fetch API
const formData = new FormData();
formData.append('file', audioFile);

fetch('https://YOUR-SPACE-NAME.hf.space/predict', {
  method: 'POST',
  body: formData
})
  .then(response => response.json())
  .then(data => {
    console.log('Emotion:', data.emotion);
    console.log('Confidence:', data.confidence);
  });
```

## 📚 Documentation

After deployment, visit:
- **Swagger UI**: `/docs` - Interactive API testing
- **ReDoc**: `/redoc` - Beautiful API documentation

## 🔧 Technical Details

- **Model**: HuBERT (Hidden-Unit BERT)
- **Model ID**: abedir/emotion-detector
- **Sample Rate**: 16 kHz (automatic resampling)
- **Max Duration**: 3 seconds
- **Supported Formats**: WAV, MP3, FLAC, OGG, M4A, WebM
- **Framework**: FastAPI + PyTorch + Transformers
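The 3-second limit works by truncating longer clips and zero-padding shorter ones, as `preprocess_audio` in `app.py` does. A minimal standalone sketch of that step:

```python
import numpy as np

SAMPLE_RATE = 16_000            # model's expected input rate (Hz)
MAX_SAMPLES = 3 * SAMPLE_RATE   # 3-second analysis window

def fit_to_window(audio: np.ndarray) -> np.ndarray:
    """Truncate clips longer than 3 s; zero-pad shorter ones."""
    if len(audio) > MAX_SAMPLES:
        return audio[:MAX_SAMPLES]
    return np.pad(audio, (0, MAX_SAMPLES - len(audio)))

short = fit_to_window(np.ones(SAMPLE_RATE))       # 1 s clip -> padded
long_ = fit_to_window(np.ones(5 * SAMPLE_RATE))   # 5 s clip -> truncated
```

Either way the model always sees exactly 48,000 samples.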

## 🎯 Use Cases

✅ Call center sentiment analysis
✅ Mental health monitoring
✅ Voice assistant emotion detection
✅ Gaming and entertainment
✅ Media content analysis
✅ Research in affective computing

## 🚨 Error Handling

All errors return a consistent format:

```json
{
  "error": "Invalid file format",
  "detail": "Supported formats: .wav, .mp3, .flac, .ogg, .m4a, .webm",
  "timestamp": "2024-02-06T10:30:00"
}
```

HTTP status codes:
- `200` - Success
- `400` - Bad Request (invalid input)
- `422` - Validation Error
- `500` - Internal Server Error

## 🔗 Related Links

- **Model**: [abedir/emotion-detector](https://huggingface.co/abedir/emotion-detector)
- **HuBERT Paper**: [arXiv:2106.07447](https://arxiv.org/abs/2106.07447)
- **FastAPI**: [Documentation](https://fastapi.tiangolo.com/)

## 📄 License

Apache 2.0

---

**Built with ❤️ using HuBERT, FastAPI, and Transformers**
app.py ADDED
@@ -0,0 +1,451 @@

```python
import torch
import numpy as np
import librosa
import io
import time
from datetime import datetime
from typing import Optional, List, Dict
from fastapi import FastAPI, UploadFile, File, HTTPException, Query
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
import base64
import os

# -------------------- CONFIG --------------------
MODEL_ID = os.getenv("HF_MODEL_ID", "abedir/emotion-detector")

label_map = {
    0: "Angry/Fearful",
    1: "Happy/Laugh",
    2: "Neutral/Calm",
    3: "Sad/Cry",
    4: "Surprised/Amazed"
}
MAX_DURATION = 3.0  # seconds
API_VERSION = "1.0.0"

# -------------------- LOAD MODEL FROM HF HUB --------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("=" * 60)
print("HuBERT Emotion Recognition API - Starting")
print("=" * 60)
print(f"Device: {device}")
print(f"Loading model from Hugging Face Hub: {MODEL_ID}")

try:
    processor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)
    model = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
    model.to(device)
    model.eval()
    print("✓ Model loaded successfully from Hugging Face Hub")
    print("=" * 60)
except Exception as e:
    print("=" * 60)
    print("✗ ERROR: Failed to load model from Hugging Face Hub")
    print("=" * 60)
    print(f"\nError details: {e}\n")
    print("Please ensure:")
    print(f"1. Model ID is correct: {MODEL_ID}")
    print("2. Model repository exists and is accessible")
    print("3. Model contains all required files:")
    print("   - config.json")
    print("   - preprocessor_config.json")
    print("   - model.safetensors")
    print("=" * 60)
    raise

sampling_rate = processor.sampling_rate
max_length = int(MAX_DURATION * sampling_rate)

# -------------------- PYDANTIC MODELS --------------------
class EmotionPrediction(BaseModel):
    emotion: str = Field(..., description="Predicted emotion label")
    confidence: float = Field(..., description="Confidence score (0-1)")
    probabilities: Dict[str, float] = Field(..., description="Probability distribution across all emotions")

class BatchPredictionResponse(BaseModel):
    predictions: List[EmotionPrediction]
    total_files: int
    processing_time_seconds: float

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    device: str
    supported_emotions: List[str]
    api_version: str
    model_id: str
    timestamp: str

class ModelInfoResponse(BaseModel):
    model_name: str
    model_id: str
    model_type: str
    num_labels: int
    emotion_labels: Dict[int, str]
    sample_rate: int
    max_duration_seconds: float
    device: str

class AudioInfoResponse(BaseModel):
    duration_seconds: float
    sample_rate: int
    num_samples: int
    is_truncated: bool
    is_padded: bool

class Base64PredictionRequest(BaseModel):
    audio_base64: str = Field(..., description="Base64 encoded audio file")
    filename: Optional[str] = Field(None, description="Original filename for reference")

class ErrorResponse(BaseModel):
    error: str
    detail: str
    timestamp: str

# -------------------- FASTAPI APP --------------------
app = FastAPI(
    title="HuBERT Emotion Recognition API",
    description="Advanced emotion recognition API using HuBERT model - Model: abedir/emotion-detector",
    version=API_VERSION,
    docs_url="/docs",
    redoc_url="/redoc"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# -------------------- HELPER FUNCTIONS --------------------
def get_audio_info(audio: np.ndarray, sr: int) -> AudioInfoResponse:
    """Get information about the audio"""
    duration = len(audio) / sr
    is_truncated = duration > MAX_DURATION
    is_padded = duration < MAX_DURATION

    return AudioInfoResponse(
        duration_seconds=float(duration),
        sample_rate=sr,
        num_samples=len(audio),
        is_truncated=is_truncated,
        is_padded=is_padded
    )

def preprocess_audio(file_bytes: bytes) -> tuple[torch.Tensor, AudioInfoResponse]:
    """Preprocess audio bytes for model input and return audio info"""
    try:
        audio, sr = librosa.load(
            io.BytesIO(file_bytes),
            sr=sampling_rate
        )

        audio_info = get_audio_info(audio, sr)

        # Truncate or pad to max_length
        if len(audio) > max_length:
            audio = audio[:max_length]
        else:
            audio = np.pad(audio, (0, max_length - len(audio)))

        inputs = processor(
            audio,
            sampling_rate=sampling_rate,
            return_tensors="pt"
        )
        return inputs.input_values.to(device), audio_info
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing audio: {str(e)}")

def predict_emotion(input_values: torch.Tensor) -> EmotionPrediction:
    """Run emotion prediction"""
    with torch.no_grad():
        outputs = model(input_values)
        probs = torch.softmax(outputs.logits, dim=1)[0]
        pred_id = torch.argmax(probs).item()

    return EmotionPrediction(
        emotion=label_map[pred_id],
        confidence=float(probs[pred_id]),
        probabilities={
            label_map[i]: float(probs[i])
            for i in range(len(label_map))
        }
    )

# -------------------- ENDPOINTS --------------------

@app.get("/", response_model=Dict[str, str])
async def root():
    """Root endpoint - API welcome message"""
    return {
        "message": "HuBERT Emotion Recognition API",
        "version": API_VERSION,
        "model_id": MODEL_ID,
        "docs": "/docs",
        "health": "/health"
    }

@app.get("/health", response_model=HealthResponse)
async def health():
    """Comprehensive health check endpoint"""
    return HealthResponse(
        status="healthy",
        model_loaded=model is not None,
        device=device,
        supported_emotions=list(label_map.values()),
        api_version=API_VERSION,
        model_id=MODEL_ID,
        timestamp=datetime.now().isoformat()
    )

@app.get("/model/info", response_model=ModelInfoResponse)
async def get_model_info():
    """Get detailed model information"""
    return ModelInfoResponse(
        model_name="HuBERT Emotion Detector",
        model_id=MODEL_ID,
        model_type="Audio Classification - Emotion Recognition",
        num_labels=len(label_map),
        emotion_labels=label_map,
        sample_rate=sampling_rate,
        max_duration_seconds=MAX_DURATION,
        device=device
    )

@app.get("/emotions")
async def list_emotions():
    """List all supported emotion labels"""
    # Note: no response_model here -- the payload mixes a list and an int count
    return {
        "emotions": list(label_map.values()),
        "count": len(label_map)
    }

@app.post("/predict", response_model=EmotionPrediction)
async def predict(
    file: UploadFile = File(..., description="Audio file (.wav, .mp3, .flac, .ogg)")
):
    """Predict emotion from uploaded audio file"""
    if not file.filename.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.webm')):
        raise HTTPException(
            status_code=400,
            detail="Invalid file format. Supported: .wav, .mp3, .flac, .ogg, .m4a, .webm"
        )

    audio_bytes = await file.read()
    input_values, _ = preprocess_audio(audio_bytes)
    return predict_emotion(input_values)

@app.post("/predict/detailed", response_model=Dict)
async def predict_detailed(
    file: UploadFile = File(..., description="Audio file")
):
    """Predict emotion with detailed audio information"""
    if not file.filename.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.webm')):
        raise HTTPException(
            status_code=400,
            detail="Invalid file format. Supported: .wav, .mp3, .flac, .ogg, .m4a, .webm"
        )

    audio_bytes = await file.read()
    input_values, audio_info = preprocess_audio(audio_bytes)
    prediction = predict_emotion(input_values)

    return {
        "prediction": prediction.dict(),
        "audio_info": audio_info.dict(),
        "filename": file.filename,
        "timestamp": datetime.now().isoformat()
    }

@app.post("/predict/base64", response_model=EmotionPrediction)
async def predict_base64(request: Base64PredictionRequest):
    """Predict emotion from base64 encoded audio"""
    try:
        audio_bytes = base64.b64decode(request.audio_base64)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid base64 encoding: {str(e)}"
        )

    input_values, _ = preprocess_audio(audio_bytes)
    return predict_emotion(input_values)

@app.post("/predict/batch", response_model=BatchPredictionResponse)
async def predict_batch(
    files: List[UploadFile] = File(..., description="Multiple audio files")
):
    """Batch prediction for multiple audio files"""
    if len(files) > 50:
        raise HTTPException(
            status_code=400,
            detail="Maximum 50 files per batch request"
        )

    start_time = time.time()
    predictions = []

    for file in files:
        if not file.filename.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.webm')):
            continue

        try:
            audio_bytes = await file.read()
            input_values, _ = preprocess_audio(audio_bytes)
            prediction = predict_emotion(input_values)
            predictions.append(prediction)
        except Exception as e:
            print(f"Error processing {file.filename}: {e}")
            continue

    processing_time = time.time() - start_time

    return BatchPredictionResponse(
        predictions=predictions,
        total_files=len(predictions),
        processing_time_seconds=processing_time
    )

@app.post("/analyze/audio", response_model=AudioInfoResponse)
async def analyze_audio(
    file: UploadFile = File(..., description="Audio file to analyze")
):
    """Analyze audio file and return metadata without prediction"""
    try:
        audio_bytes = await file.read()
        audio, sr = librosa.load(io.BytesIO(audio_bytes), sr=sampling_rate)
        return get_audio_info(audio, sr)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error analyzing audio: {str(e)}"
        )

@app.post("/predict/top-k")
async def predict_top_k(
    file: UploadFile = File(...),
    k: int = Query(3, ge=1, le=5, description="Number of top predictions to return")
):
    """Get top-k emotion predictions"""
    if not file.filename.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.webm')):
        raise HTTPException(status_code=400, detail="Invalid file format")

    audio_bytes = await file.read()
    input_values, _ = preprocess_audio(audio_bytes)

    with torch.no_grad():
        outputs = model(input_values)
        probs = torch.softmax(outputs.logits, dim=1)[0]
        top_k_probs, top_k_indices = torch.topk(probs, k)

    top_predictions = [
        {
            "rank": i + 1,
            "emotion": label_map[idx.item()],
            "confidence": prob.item()
        }
        for i, (prob, idx) in enumerate(zip(top_k_probs, top_k_indices))
    ]

    return {
        "top_predictions": top_predictions,
        "total_emotions": len(label_map)
    }

@app.post("/predict/threshold")
async def predict_with_threshold(
    file: UploadFile = File(...),
    threshold: float = Query(0.5, ge=0.0, le=1.0, description="Confidence threshold")
):
    """Predict emotion only if confidence exceeds threshold"""
    if not file.filename.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.webm')):
        raise HTTPException(status_code=400, detail="Invalid file format")

    audio_bytes = await file.read()
    input_values, _ = preprocess_audio(audio_bytes)
    prediction = predict_emotion(input_values)

    if prediction.confidence >= threshold:
        return {
            "status": "confident",
            "prediction": prediction.dict()
        }
    else:
        return {
            "status": "uncertain",
            "message": f"Confidence {prediction.confidence:.3f} below threshold {threshold}",
            "best_guess": prediction.dict()
        }

@app.get("/stats")
async def get_stats():
    """Get API statistics and system information"""
    return {
        "model": {
            "name": "HuBERT Emotion Detector",
            "model_id": MODEL_ID,
            "device": device,
            "loaded": model is not None
        },
        "configuration": {
            "max_duration_seconds": MAX_DURATION,
            "sample_rate": sampling_rate,
            "num_emotions": len(label_map)
        },
        "system": {
            "cuda_available": torch.cuda.is_available(),
            "torch_version": torch.__version__
        },
        "api_version": API_VERSION,
        "timestamp": datetime.now().isoformat()
    }

@app.get("/version")
async def get_version():
    """Get API version information"""
    return {
        "api_version": API_VERSION,
        "framework": "FastAPI",
        "model": "HuBERT Emotion Detector",
        "model_id": MODEL_ID,
        "timestamp": datetime.now().isoformat()
    }

# -------------------- ERROR HANDLERS --------------------
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Custom HTTP exception handler"""
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            error=exc.detail,
            detail=str(exc),
            timestamp=datetime.now().isoformat()
        ).dict()
    )

@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """General exception handler"""
    return JSONResponse(
        status_code=500,
        content=ErrorResponse(
            error="Internal server error",
            detail=str(exc),
            timestamp=datetime.now().isoformat()
        ).dict()
    )

# -------------------- STARTUP EVENT --------------------
@app.on_event("startup")
async def startup_event():
    """Log startup information"""
    print("API is ready to accept requests!")
    print("Visit /docs for interactive API documentation")
```
requirements.txt ADDED
@@ -0,0 +1,9 @@

```text
fastapi==0.109.0
uvicorn[standard]==0.27.0
torch==2.1.2
transformers==4.37.2
librosa==0.10.1
numpy==1.24.3
scipy==1.11.4
soundfile==0.12.1
python-multipart==0.0.6
```