monishaaura committed
Commit af819b6 · verified · 1 Parent(s): 399d8e0

Update app.py

Files changed (1)
  1. app.py +419 -445
app.py CHANGED
@@ -1,445 +1,419 @@
- """
- FastAPI Backend for Wav2Vec2-Emotion Detection
- Uses the superb/wav2vec2-base-superb-er model from Hugging Face
- """
-
- from fastapi import FastAPI, File, UploadFile, HTTPException
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse
- from contextlib import asynccontextmanager
- import torch
- from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor, AutoProcessor, Wav2Vec2FeatureExtractor
- import soundfile as sf
- import io
- import numpy as np
- from pydub import AudioSegment
- import logging
- import os
- from typing import Optional
-
- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- # Lifespan context manager for startup/shutdown
- @asynccontextmanager
- async def lifespan(app: FastAPI):
-     """
-     Lifespan context manager for FastAPI.
-     Loads model on startup and handles cleanup on shutdown.
-     """
-     # Startup: Load model
-     logger.info("🚀 Starting up Wav2Vec2 Emotion Detection API...")
-     load_model()
-     logger.info("✅ Startup complete - Model loaded!")
-     yield
-     # Shutdown: Cleanup (if needed)
-     logger.info("🛑 Shutting down...")
-
- # Initialize FastAPI app with lifespan
- app = FastAPI(
-     title="Wav2Vec2 Emotion Detection API",
-     description="Real-time emotion detection from audio using Wav2Vec2 model",
-     version="1.0.0",
-     lifespan=lifespan
- )
-
- # Configure CORS - Allow requests from React frontend
- # For development, allow all local network origins
- # In production, restrict to specific domains
- import re
-
- def check_origin(origin: str, request) -> bool:
-     """
-     Check if origin is allowed (localhost, local network, or Vercel)
-     For development, allows any local network IP
-     """
-     if not origin:
-         return False
-
-     # Allow localhost
-     if origin.startswith("http://localhost:") or origin.startswith("http://127.0.0.1:"):
-         return True
-
-     # Allow local network IPs (192.168.x.x, 10.x.x.x, 172.16-31.x.x)
-     local_network_pattern = re.compile(
-         r"http://(192\.168\.\d+\.\d+|10\.\d+\.\d+\.\d+|172\.(1[6-9]|2\d|3[01])\.\d+\.\d+):(5173|3000)"
-     )
-     if local_network_pattern.match(origin):
-         return True
-
-     # Allow Vercel deployments
-     if "vercel.app" in origin:
-         return True
-
-     return False
-
- app.add_middleware(
-     CORSMiddleware,
-     allow_origin_func=check_origin,  # Use function for dynamic origin checking
-     allow_credentials=False,
-     allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
-     allow_headers=["*"],
-     expose_headers=["*"],
- )
-
- # Global variables for model and processor
- # These will be loaded once when the app starts
- model: Optional[Wav2Vec2ForSequenceClassification] = None
- processor: Optional[Wav2Vec2Processor] = None
- feature_extractor: Optional[Wav2Vec2FeatureExtractor] = None
-
- # Emotion labels mapping (superb/wav2vec2-base-superb-er outputs)
- # The model outputs 6 emotions based on the Emotion Recognition (ER) task
- EMOTION_LABELS = [
-     "neutral",  # 0
-     "happy",  # 1
-     "sad",  # 2
-     "angry",  # 3
-     "calm",  # 4
-     "excited"  # 5
- ]
-
-
- def load_model():
-     """
-     Load the Wav2Vec2-Emotion model and processor from Hugging Face.
-     This function is called once at startup to initialize the model.
-     """
-     global model, processor, feature_extractor
-
-     try:
-         logger.info("🔄 Loading Wav2Vec2-Emotion model from Hugging Face...")
-         logger.info("Model: superb/wav2vec2-base-superb-er")
-
-         model_name = "superb/wav2vec2-base-superb-er"
-
-         # Try loading feature extractor first (Wav2Vec2 doesn't always need tokenizer)
-         logger.info("📦 Loading feature extractor...")
-         try:
-             feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
-             logger.info("✅ Feature extractor loaded!")
-             processor = feature_extractor  # Use feature extractor as processor
-         except Exception as e_fe:
-             logger.warning(f"⚠️ Feature extractor failed: {e_fe}")
-
-             # Try using AutoProcessor
-             try:
-                 logger.info("📦 Trying AutoProcessor...")
-                 processor = AutoProcessor.from_pretrained(model_name)
-                 logger.info("✅ AutoProcessor loaded successfully!")
-             except Exception as e1:
-                 logger.warning(f"⚠️ AutoProcessor failed: {e1}")
-                 logger.info("📦 Trying Wav2Vec2Processor directly...")
-                 # Fallback to direct processor
-                 try:
-                     processor = Wav2Vec2Processor.from_pretrained(model_name)
-                     logger.info("✅ Wav2Vec2Processor loaded successfully!")
-                 except Exception as e2:
-                     logger.error(f"❌ All processor methods failed!")
-                     logger.error(f" FeatureExtractor: {e_fe}")
-                     logger.error(f" AutoProcessor: {e1}")
-                     logger.error(f" Wav2Vec2Processor: {e2}")
-                     raise
-
-         # Load the model
-         logger.info("📦 Loading model...")
-         model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
-
-         # Set model to evaluation mode (not training)
-         model.eval()
-
-         logger.info("✅ Model loaded successfully!")
-         logger.info(f"📊 Model device: {next(model.parameters()).device}")
-
-     except Exception as e:
-         logger.error(f"❌ Error loading model: {str(e)}")
-         logger.error(f"📋 Full error: {repr(e)}")
-         raise
-
-
- def convert_audio_to_wav(audio_bytes: bytes, input_format: str = "webm") -> bytes:
-     """
-     Convert audio bytes to WAV format (16kHz, mono, 16-bit).
-     The Wav2Vec2 model expects specific audio format.
-
-     Args:
-         audio_bytes: Raw audio data as bytes
-         input_format: Input format (webm, mp3, wav, etc.)
-
-     Returns:
-         WAV audio bytes (16kHz, mono, 16-bit)
-     """
-     try:
-         # If already WAV, just verify format and return
-         if input_format.lower() == "wav":
-             logger.info("Audio is already WAV format")
-             return audio_bytes
-
-         # Try using librosa first (supports more formats, no ffmpeg needed for basic formats)
-         try:
-             import librosa
-             logger.info(f"Attempting to convert {input_format} using librosa...")
-
-             # Load audio with librosa (handles format conversion internally)
-             audio_array, sample_rate = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
-
-             # Normalize audio
-             audio_array = librosa.util.normalize(audio_array)
-
-             # Convert to int16 WAV format
-             audio_int16 = (audio_array * 32767).astype(np.int16)
-
-             # Create WAV file in memory
-             wav_buffer = io.BytesIO()
-             sf.write(wav_buffer, audio_int16, 16000, format='WAV', subtype='PCM_16')
-             wav_bytes = wav_buffer.getvalue()
-
-             logger.info(f"✅ Successfully converted {input_format} to WAV using librosa")
-             return wav_bytes
-
-         except Exception as librosa_error:
-             logger.warning(f"librosa conversion failed: {librosa_error}")
-
-             # Fallback to pydub (requires ffmpeg)
-             logger.info(f"Falling back to pydub for {input_format}...")
-             try:
-                 audio = AudioSegment.from_file(io.BytesIO(audio_bytes), format=input_format)
-
-                 # Convert to required format:
-                 # - 16kHz sample rate (Wav2Vec2 requirement)
-                 # - Mono (single channel)
-                 # - 16-bit depth
-                 audio = audio.set_frame_rate(16000)
-                 audio = audio.set_channels(1)
-                 audio = audio.set_sample_width(2)  # 16-bit = 2 bytes per sample
-
-                 # Export to WAV bytes
-                 wav_buffer = io.BytesIO()
-                 audio.export(wav_buffer, format="wav")
-                 wav_bytes = wav_buffer.getvalue()
-
-                 logger.info(f"✅ Successfully converted {input_format} to WAV using pydub")
-                 return wav_bytes
-
-             except Exception as pydub_error:
-                 logger.error(f"pydub conversion also failed: {pydub_error}")
-                 raise Exception(
-                     f"Audio conversion failed. {input_format} format requires ffmpeg. "
-                     f"Please install ffmpeg or convert audio to WAV format first. "
-                     f"Error details: {pydub_error}"
-                 )
-
-     except Exception as e:
-         logger.error(f"Error converting audio: {str(e)}")
-         raise
-
-
- def preprocess_audio(audio_bytes: bytes) -> np.ndarray:
-     """
-     Preprocess audio for Wav2Vec2 model.
-     Converts audio bytes to numpy array and normalizes.
-
-     Args:
-         audio_bytes: WAV audio bytes (16kHz, mono, 16-bit)
-
-     Returns:
-         Audio array ready for model input (normalized float32, 16kHz)
-     """
-     try:
-         # Read audio using soundfile
-         audio_buffer = io.BytesIO(audio_bytes)
-         audio_array, sample_rate = sf.read(audio_buffer, dtype='float32')
-
-         # Verify sample rate is 16kHz (required by Wav2Vec2)
-         if sample_rate != 16000:
-             logger.warning(f"Sample rate is {sample_rate}Hz, resampling to 16kHz...")
-             # Note: pydub already handles this in convert_audio_to_wav
-
-         # Normalize audio to [-1, 1] range if needed
-         if audio_array.dtype != np.float32:
-             audio_array = audio_array.astype(np.float32)
-
-         # Ensure mono (single channel)
-         if len(audio_array.shape) > 1:
-             audio_array = np.mean(audio_array, axis=1)
-
-         # Normalize to [-1, 1] range
-         max_val = np.abs(audio_array).max()
-         if max_val > 0:
-             audio_array = audio_array / max_val
-
-         return audio_array
-
-     except Exception as e:
-         logger.error(f"Error preprocessing audio: {str(e)}")
-         raise
-
-
- def predict_emotion(audio_array: np.ndarray) -> dict:
-     """
-     Predict emotion from audio array using Wav2Vec2 model.
-
-     Args:
-         audio_array: Preprocessed audio array (float32, 16kHz, mono)
-
-     Returns:
-         Dictionary with emotion label and confidence score
-     """
-     global model, processor
-
-     try:
-         # Use processor to prepare input for model
-         # This handles tokenization and feature extraction
-         inputs = processor(
-             audio_array,
-             sampling_rate=16000,
-             return_tensors="pt",  # Return PyTorch tensors
-             padding=True
-         )
-
-         # Move inputs to same device as model (CPU or GPU)
-         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         inputs = {k: v.to(device) for k, v in inputs.items()}
-
-         # Move model to device if needed
-         if next(model.parameters()).device != device:
-             model = model.to(device)
-
-         # Run inference (no gradient computation)
-         with torch.no_grad():
-             outputs = model(**inputs)
-
-         # Get predicted class (emotion label index)
-         logits = outputs.logits
-         predicted_class = torch.argmax(logits, dim=-1).item()
-
-         # Get probabilities for all emotions using softmax
-         probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()[0]
-
-         # Get confidence (probability of predicted emotion)
-         confidence = float(probabilities[predicted_class])
-
-         # Map class index to emotion label
-         emotion_label = EMOTION_LABELS[predicted_class]
-
-         # Create probability distribution for all emotions
-         emotion_probs = {
-             EMOTION_LABELS[i]: float(prob)
-             for i, prob in enumerate(probabilities)
-         }
-
-         logger.info(f"🎭 Detected emotion: {emotion_label} (confidence: {confidence:.2%})")
-         logger.info(f"📊 Probability distribution: {emotion_probs}")
-
-         return {
-             "emotion": emotion_label,
-             "confidence": confidence,
-             "probabilities": emotion_probs
-         }
-
-     except Exception as e:
-         logger.error(f"Error during prediction: {str(e)}")
-         raise
-
-
- # Model loading is now handled by lifespan context manager above
-
-
- @app.get("/")
- async def root():
-     """Health check endpoint."""
-     return {
-         "status": "healthy",
-         "service": "Wav2Vec2 Emotion Detection API",
-         "model": "superb/wav2vec2-base-superb-er",
-         "emotions": EMOTION_LABELS
-     }
-
-
- @app.get("/health")
- async def health_check():
-     """Detailed health check endpoint."""
-     return {
-         "status": "healthy",
-         "model_loaded": model is not None and processor is not None,
-         "device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
-         "model_name": "superb/wav2vec2-base-superb-er"
-     }
-
-
- @app.post("/predict")
- async def predict_emotion_endpoint(
-     audio: UploadFile = File(..., description="Audio file (WAV, MP3, WebM, etc.)")
- ):
-     """
-     Predict emotion from uploaded audio file.
-
-     Steps:
-     1. Receive audio file from frontend
-     2. Convert to WAV format (16kHz, mono, 16-bit)
-     3. Preprocess audio for model
-     4. Run Wav2Vec2 model inference
-     5. Return detected emotion and confidence
-
-     Args:
-         audio: Audio file uploaded from frontend
-
-     Returns:
-         JSON response with emotion, confidence, and probability distribution
-     """
-     try:
-         # Read uploaded audio file
-         audio_bytes = await audio.read()
-         logger.info(f"📥 Received audio file: {audio.filename}, size: {len(audio_bytes)} bytes")
-
-         # Determine input format from file extension or MIME type
-         input_format = "webm"  # Default (browser recordings are usually WebM)
-         if audio.filename:
-             ext = audio.filename.split(".")[-1].lower()
-             if ext in ["mp3", "wav", "m4a", "ogg"]:
-                 input_format = ext
-
-         # Convert audio to WAV format (16kHz, mono, 16-bit)
-         logger.info("🔄 Converting audio to WAV format...")
-         wav_bytes = convert_audio_to_wav(audio_bytes, input_format=input_format)
-
-         # Preprocess audio for model
-         logger.info("🔄 Preprocessing audio...")
-         audio_array = preprocess_audio(wav_bytes)
-         logger.info(f"✅ Audio preprocessed: {len(audio_array)} samples at 16kHz")
-
-         # Predict emotion
-         logger.info("🧠 Running emotion prediction...")
-         result = predict_emotion(audio_array)
-
-         # Return result
-         return JSONResponse(content=result)
-
-     except Exception as e:
-         logger.error(f"❌ Error in predict endpoint: {str(e)}")
-         raise HTTPException(
-             status_code=500,
-             detail=f"Error processing audio: {str(e)}"
-         )
-
-
- if __name__ == "__main__":
-     import uvicorn
-     import os
-
-     # Get port from environment (cloud platforms like Render set this automatically)
-     # Default to 8000 for local development
-     port = int(os.environ.get("PORT", 8000))
-
-     # Check if running in production (cloud environment)
-     is_production = os.environ.get("ENVIRONMENT", "development") == "production"
-
-     # Run the FastAPI server
-     uvicorn.run(
-         "app:app",
-         host="0.0.0.0",  # Listen on all interfaces
-         port=port,  # Use environment port or 8000 for local
-         reload=not is_production  # Only reload in development
-     )
-
+ """
+ FastAPI Backend for Wav2Vec2-Emotion Detection
+ Uses the superb/wav2vec2-base-superb-er model from Hugging Face
+ """
+
+ from fastapi import FastAPI, File, UploadFile, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from contextlib import asynccontextmanager
+ import torch
+ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor, AutoProcessor, Wav2Vec2FeatureExtractor
+ import soundfile as sf
+ import io
+ import numpy as np
+ from pydub import AudioSegment
+ import logging
+ import os
+ from typing import Optional
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Lifespan context manager for startup/shutdown
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """
+     Lifespan context manager for FastAPI.
+     Loads model on startup and handles cleanup on shutdown.
+     """
+     # Startup: Load model
+     logger.info("🚀 Starting up Wav2Vec2 Emotion Detection API...")
+     load_model()
+     logger.info("✅ Startup complete - Model loaded!")
+     yield
+     # Shutdown: Cleanup (if needed)
+     logger.info("🛑 Shutting down...")
+
+ # Initialize FastAPI app with lifespan
+ app = FastAPI(
+     title="Wav2Vec2 Emotion Detection API",
+     description="Real-time emotion detection from audio using Wav2Vec2 model",
+     version="1.0.0",
+     lifespan=lifespan
+ )
+
+ # Configure CORS - Allow requests from React frontend
+ # For public API, allow all origins (common for ML APIs)
+ # Using allow_origins=["*"] for maximum compatibility
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Allow all origins for public API
+     allow_credentials=False,
+     allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+     allow_headers=["*"],
+     expose_headers=["*"],
+ )
+
+ # Global variables for model and processor
+ # These will be loaded once when the app starts
+ model: Optional[Wav2Vec2ForSequenceClassification] = None
+ processor: Optional[Wav2Vec2Processor] = None
+ feature_extractor: Optional[Wav2Vec2FeatureExtractor] = None
+
+ # Emotion labels mapping (superb/wav2vec2-base-superb-er outputs)
+ # The model outputs 6 emotions based on the Emotion Recognition (ER) task
+ EMOTION_LABELS = [
+     "neutral",  # 0
+     "happy",  # 1
+     "sad",  # 2
+     "angry",  # 3
+     "calm",  # 4
+     "excited"  # 5
+ ]
+
+
+ def load_model():
+     """
+     Load the Wav2Vec2-Emotion model and processor from Hugging Face.
+     This function is called once at startup to initialize the model.
+     """
+     global model, processor, feature_extractor
+
+     try:
+         logger.info("🔄 Loading Wav2Vec2-Emotion model from Hugging Face...")
+         logger.info("Model: superb/wav2vec2-base-superb-er")
+
+         model_name = "superb/wav2vec2-base-superb-er"
+
+         # Try loading feature extractor first (Wav2Vec2 doesn't always need tokenizer)
+         logger.info("📦 Loading feature extractor...")
+         try:
+             feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+             logger.info("✅ Feature extractor loaded!")
+             processor = feature_extractor  # Use feature extractor as processor
+         except Exception as e_fe:
+             logger.warning(f"⚠️ Feature extractor failed: {e_fe}")
+
+             # Try using AutoProcessor
+             try:
+                 logger.info("📦 Trying AutoProcessor...")
+                 processor = AutoProcessor.from_pretrained(model_name)
+                 logger.info("✅ AutoProcessor loaded successfully!")
+             except Exception as e1:
+                 logger.warning(f"⚠️ AutoProcessor failed: {e1}")
+                 logger.info("📦 Trying Wav2Vec2Processor directly...")
+                 # Fallback to direct processor
+                 try:
+                     processor = Wav2Vec2Processor.from_pretrained(model_name)
+                     logger.info("✅ Wav2Vec2Processor loaded successfully!")
+                 except Exception as e2:
+                     logger.error(f"❌ All processor methods failed!")
+                     logger.error(f" FeatureExtractor: {e_fe}")
+                     logger.error(f" AutoProcessor: {e1}")
+                     logger.error(f" Wav2Vec2Processor: {e2}")
+                     raise
+
+         # Load the model
+         logger.info("📦 Loading model...")
+         model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
+
+         # Set model to evaluation mode (not training)
+         model.eval()
+
+         logger.info("✅ Model loaded successfully!")
+         logger.info(f"📊 Model device: {next(model.parameters()).device}")
+
+     except Exception as e:
+         logger.error(f"❌ Error loading model: {str(e)}")
+         logger.error(f"📋 Full error: {repr(e)}")
+         raise
+
+
+ def convert_audio_to_wav(audio_bytes: bytes, input_format: str = "webm") -> bytes:
+     """
+     Convert audio bytes to WAV format (16kHz, mono, 16-bit).
+     The Wav2Vec2 model expects specific audio format.
+
+     Args:
+         audio_bytes: Raw audio data as bytes
+         input_format: Input format (webm, mp3, wav, etc.)
+
+     Returns:
+         WAV audio bytes (16kHz, mono, 16-bit)
+     """
+     try:
+         # If already WAV, just verify format and return
+         if input_format.lower() == "wav":
+             logger.info("Audio is already WAV format")
+             return audio_bytes
+
+         # Try using librosa first (supports more formats, no ffmpeg needed for basic formats)
+         try:
+             import librosa
+             logger.info(f"Attempting to convert {input_format} using librosa...")
+
+             # Load audio with librosa (handles format conversion internally)
+             audio_array, sample_rate = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
+
+             # Normalize audio
+             audio_array = librosa.util.normalize(audio_array)
+
+             # Convert to int16 WAV format
+             audio_int16 = (audio_array * 32767).astype(np.int16)
+
+             # Create WAV file in memory
+             wav_buffer = io.BytesIO()
+             sf.write(wav_buffer, audio_int16, 16000, format='WAV', subtype='PCM_16')
+             wav_bytes = wav_buffer.getvalue()
+
+             logger.info(f"✅ Successfully converted {input_format} to WAV using librosa")
+             return wav_bytes
+
+         except Exception as librosa_error:
+             logger.warning(f"librosa conversion failed: {librosa_error}")
+
+             # Fallback to pydub (requires ffmpeg)
+             logger.info(f"Falling back to pydub for {input_format}...")
+             try:
+                 audio = AudioSegment.from_file(io.BytesIO(audio_bytes), format=input_format)
+
+                 # Convert to required format:
+                 # - 16kHz sample rate (Wav2Vec2 requirement)
+                 # - Mono (single channel)
+                 # - 16-bit depth
+                 audio = audio.set_frame_rate(16000)
+                 audio = audio.set_channels(1)
+                 audio = audio.set_sample_width(2)  # 16-bit = 2 bytes per sample
+
+                 # Export to WAV bytes
+                 wav_buffer = io.BytesIO()
+                 audio.export(wav_buffer, format="wav")
+                 wav_bytes = wav_buffer.getvalue()
+
+                 logger.info(f"✅ Successfully converted {input_format} to WAV using pydub")
+                 return wav_bytes
+
+             except Exception as pydub_error:
+                 logger.error(f"pydub conversion also failed: {pydub_error}")
+                 raise Exception(
+                     f"Audio conversion failed. {input_format} format requires ffmpeg. "
+                     f"Please install ffmpeg or convert audio to WAV format first. "
+                     f"Error details: {pydub_error}"
+                 )
+
+     except Exception as e:
+         logger.error(f"Error converting audio: {str(e)}")
+         raise
+
+
+ def preprocess_audio(audio_bytes: bytes) -> np.ndarray:
+     """
+     Preprocess audio for Wav2Vec2 model.
+     Converts audio bytes to numpy array and normalizes.
+
+     Args:
+         audio_bytes: WAV audio bytes (16kHz, mono, 16-bit)
+
+     Returns:
+         Audio array ready for model input (normalized float32, 16kHz)
+     """
+     try:
+         # Read audio using soundfile
+         audio_buffer = io.BytesIO(audio_bytes)
+         audio_array, sample_rate = sf.read(audio_buffer, dtype='float32')
+
+         # Verify sample rate is 16kHz (required by Wav2Vec2)
+         if sample_rate != 16000:
+             logger.warning(f"Sample rate is {sample_rate}Hz, resampling to 16kHz...")
+             # Note: pydub already handles this in convert_audio_to_wav
+
+         # Normalize audio to [-1, 1] range if needed
+         if audio_array.dtype != np.float32:
+             audio_array = audio_array.astype(np.float32)
+
+         # Ensure mono (single channel)
+         if len(audio_array.shape) > 1:
+             audio_array = np.mean(audio_array, axis=1)
+
+         # Normalize to [-1, 1] range
+         max_val = np.abs(audio_array).max()
+         if max_val > 0:
+             audio_array = audio_array / max_val
+
+         return audio_array
+
+     except Exception as e:
+         logger.error(f"Error preprocessing audio: {str(e)}")
+         raise
+
+
+ def predict_emotion(audio_array: np.ndarray) -> dict:
+     """
+     Predict emotion from audio array using Wav2Vec2 model.
+
+     Args:
+         audio_array: Preprocessed audio array (float32, 16kHz, mono)
+
+     Returns:
+         Dictionary with emotion label and confidence score
+     """
+     global model, processor
+
+     try:
+         # Use processor to prepare input for model
+         # This handles tokenization and feature extraction
+         inputs = processor(
+             audio_array,
+             sampling_rate=16000,
+             return_tensors="pt",  # Return PyTorch tensors
+             padding=True
+         )
+
+         # Move inputs to same device as model (CPU or GPU)
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         # Move model to device if needed
+         if next(model.parameters()).device != device:
+             model = model.to(device)
+
+         # Run inference (no gradient computation)
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         # Get predicted class (emotion label index)
+         logits = outputs.logits
+         predicted_class = torch.argmax(logits, dim=-1).item()
+
+         # Get probabilities for all emotions using softmax
+         probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()[0]
+
+         # Get confidence (probability of predicted emotion)
+         confidence = float(probabilities[predicted_class])
+
+         # Map class index to emotion label
+         emotion_label = EMOTION_LABELS[predicted_class]
+
+         # Create probability distribution for all emotions
+         emotion_probs = {
+             EMOTION_LABELS[i]: float(prob)
+             for i, prob in enumerate(probabilities)
+         }
+
+         logger.info(f"🎭 Detected emotion: {emotion_label} (confidence: {confidence:.2%})")
+         logger.info(f"📊 Probability distribution: {emotion_probs}")
+
+         return {
+             "emotion": emotion_label,
+             "confidence": confidence,
+             "probabilities": emotion_probs
+         }
+
+     except Exception as e:
+         logger.error(f"Error during prediction: {str(e)}")
+         raise
+
+
+ # Model loading is now handled by lifespan context manager above
+
+
+ @app.get("/")
+ async def root():
+     """Health check endpoint."""
+     return {
+         "status": "healthy",
+         "service": "Wav2Vec2 Emotion Detection API",
+         "model": "superb/wav2vec2-base-superb-er",
+         "emotions": EMOTION_LABELS
+     }
+
+
+ @app.get("/health")
+ async def health_check():
+     """Detailed health check endpoint."""
+     return {
+         "status": "healthy",
+         "model_loaded": model is not None and processor is not None,
+         "device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
+         "model_name": "superb/wav2vec2-base-superb-er"
+     }
+
+
+ @app.post("/predict")
+ async def predict_emotion_endpoint(
+     audio: UploadFile = File(..., description="Audio file (WAV, MP3, WebM, etc.)")
+ ):
+     """
+     Predict emotion from uploaded audio file.
+
+     Steps:
+     1. Receive audio file from frontend
+     2. Convert to WAV format (16kHz, mono, 16-bit)
+     3. Preprocess audio for model
+     4. Run Wav2Vec2 model inference
+     5. Return detected emotion and confidence
+
+     Args:
+         audio: Audio file uploaded from frontend
+
+     Returns:
+         JSON response with emotion, confidence, and probability distribution
+     """
+     try:
+         # Read uploaded audio file
+         audio_bytes = await audio.read()
+         logger.info(f"📥 Received audio file: {audio.filename}, size: {len(audio_bytes)} bytes")
+
+         # Determine input format from file extension or MIME type
+         input_format = "webm"  # Default (browser recordings are usually WebM)
+         if audio.filename:
+             ext = audio.filename.split(".")[-1].lower()
+             if ext in ["mp3", "wav", "m4a", "ogg"]:
+                 input_format = ext
+
+         # Convert audio to WAV format (16kHz, mono, 16-bit)
+         logger.info("🔄 Converting audio to WAV format...")
+         wav_bytes = convert_audio_to_wav(audio_bytes, input_format=input_format)
+
+         # Preprocess audio for model
+         logger.info("🔄 Preprocessing audio...")
+         audio_array = preprocess_audio(wav_bytes)
+         logger.info(f"✅ Audio preprocessed: {len(audio_array)} samples at 16kHz")
+
+         # Predict emotion
+         logger.info("🧠 Running emotion prediction...")
+         result = predict_emotion(audio_array)
+
+         # Return result
+         return JSONResponse(content=result)
+
+     except Exception as e:
+         logger.error(f"❌ Error in predict endpoint: {str(e)}")
+         raise HTTPException(
+             status_code=500,
+             detail=f"Error processing audio: {str(e)}"
+         )
+
+
+ if __name__ == "__main__":
+     import uvicorn
+     import os
+
+     # Get port from environment (cloud platforms like Render set this automatically)
+     # Default to 8000 for local development
+     port = int(os.environ.get("PORT", 8000))
+
+     # Check if running in production (cloud environment)
+     is_production = os.environ.get("ENVIRONMENT", "development") == "production"
+
+     # Run the FastAPI server
+     uvicorn.run(
+         "app:app",
+         host="0.0.0.0",  # Listen on all interfaces
+         port=port,  # Use environment port or 8000 for local
+         reload=not is_production  # Only reload in development
+     )
+
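
A note on the change itself: Starlette's CORSMiddleware has no allow_origin_func parameter, so the removed registration would raise a TypeError once the middleware stack was built; allow_origins=["*"] avoids that, and is coherent here since allow_credentials is False. If the dynamic localhost/LAN/Vercel matching is still wanted, the supported hook is allow_origin_regex. A minimal sketch, assuming a recent FastAPI/Starlette and reusing the patterns from the removed check_origin:

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    # One regex replaces the removed check_origin(): localhost, private-network
    # IPs on the dev ports, and any *.vercel.app deployment.
    allow_origin_regex=(
        r"http://(localhost|127\.0\.0\.1):\d+"
        r"|http://(192\.168\.\d+\.\d+|10\.\d+\.\d+\.\d+|172\.(1[6-9]|2\d|3[01])\.\d+\.\d+):(5173|3000)"
        r"|https://.*\.vercel\.app"
    ),
    allow_credentials=False,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)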
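
The hard-coded EMOTION_LABELS list is also worth a second look: the superb/wav2vec2-base-superb-er model card describes the SUPERB ER task as four IEMOCAP classes (neu, hap, ang, sad), not six, so indices 2 and 3 above may be swapped and "calm"/"excited" would never be emitted. Reading the mapping from the checkpoint avoids guessing; a sketch under that assumption:

from transformers import Wav2Vec2ForSequenceClassification

model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er")
id2label = model.config.id2label  # label names shipped with the checkpoint
EMOTION_LABELS = [id2label[i] for i in range(len(id2label))]
print(EMOTION_LABELS)  # prints the checkpoint's own class names, in index order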
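
Similarly, preprocess_audio only logs a warning when the WAV it receives is not 16 kHz, and WAV uploads bypass convert_audio_to_wav entirely, so a 44.1 kHz WAV could reach the model unresampled. A hedged fix using librosa (already an optional dependency above); it assumes mono float input, so the mono mixdown would need to happen before this call:

import librosa
import numpy as np

def ensure_16khz(audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
    """Resample mono float audio to the 16 kHz rate Wav2Vec2 expects."""
    if sample_rate != 16000:
        audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)
    return audio_array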
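
For completeness, a hypothetical client call against /predict (the multipart field name must be "audio" to match the UploadFile parameter; the host, port, and sample.wav filename are assumptions):

import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/predict",
        files={"audio": ("sample.wav", f, "audio/wav")},  # field name matches UploadFile
    )
resp.raise_for_status()
result = resp.json()
print(result["emotion"], f"{result['confidence']:.2%}")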
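
Finally, most of load_model's three-way processor fallback could be sidestepped with the transformers audio-classification pipeline, which bundles the feature extractor and the checkpoint's own labels; a sketch, assuming ffmpeg is available for decoding file inputs:

from transformers import pipeline

# Downloads the checkpoint on first use; labels come from the model config.
classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-er")
print(classifier("sample.wav", top_k=4))  # hypothetical local file; returns label/score dicts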