Nishvaraj committed on
Commit
02fcea5
·
1 Parent(s): 129e624

deploy MMER FastAPI backend on HF Spaces

Browse files
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the MMER FastAPI backend; the server listens on port 7860.
FROM python:3.12-slim

# System packages for the Python dependencies: X11/xcb and glib shared libraries,
# the OpenMP runtime, ffmpeg, and a C compiler for packages built from source.
RUN apt-get update && apt-get install -y \
    libxcb1 libxcb-render0 libxcb-shm0 libxcb-xfixes0 \
    libglib2.0-0 libsm6 libxext6 libxrender-dev \
    libgomp1 ffmpeg gcc \
    && rm -rf /var/lib/apt/lists/*

# Run everything below as an unprivileged user (UID 1000); pip then installs
# into /home/user/.local, which is why that bin directory is put on PATH.
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Install dependencies before copying the source so this layer is cached
# across code-only changes.
COPY --chown=user requirements.txt .
RUN pip install --no-cache-dir --timeout=300 --retries=5 -r requirements.txt
# Drop the GUI build of OpenCV if a dependency pulled it in, then force the headless build.
RUN pip uninstall -y opencv-python || true
# The requirement MUST be quoted: in shell form an unquoted ">=4.10.0" is parsed
# as an output redirection to a file named "=4.10.0", silently dropping the
# version bound and pip's output.
RUN pip install --no-cache-dir --timeout=300 --force-reinstall "opencv-python-headless>=4.10.0"

COPY --chown=user . .

CMD ["python", "-m", "gunicorn", "backend.main:app", "-w", "1", "-k", "uvicorn.workers.UvicornWorker", "--timeout", "600", "--bind", "0.0.0.0:7860"]
backend/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Backend package for the FastAPI emotion recognition service."""
backend/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (249 Bytes). View file
 
backend/__pycache__/main.cpython-314.pyc ADDED
Binary file (43.9 kB). View file
 
backend/backend/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Backend package for the FastAPI emotion recognition service."""
backend/backend/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (249 Bytes). View file
 
backend/backend/__pycache__/main.cpython-314.pyc ADDED
Binary file (43.9 kB). View file
 
backend/backend/main.py ADDED
@@ -0,0 +1,914 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI backend for multimodal (facial + speech) emotion inference."""
2
+
3
+ from fastapi import FastAPI, File, UploadFile
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from fastapi.responses import JSONResponse
6
+ import torch
7
+ import numpy as np
8
+ import cv2
9
+ import librosa
10
+ import base64
11
+ from PIL import Image, ImageOps
12
+ from io import BytesIO
13
+ from pathlib import Path
14
+ from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoFeatureExtractor, AutoModelForAudioClassification
15
+ from huggingface_hub import hf_hub_download
16
+ import tempfile
17
+ import os
18
+ import logging
19
+ from threading import Lock
20
+ from dotenv import load_dotenv
21
+
22
try:
    # facenet-pytorch is optional; when the import fails for any reason MTCNN is
    # set to None and face detection relies on the Haar cascade alone.
    from facenet_pytorch import MTCNN  # type: ignore[import-not-found]
except Exception:
    MTCNN = None

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] [%(levelname)s] %(message)s'
)
logger = logging.getLogger(__name__)

# Explainability helpers
from backend.services.explainability import generate_grad_cam, generate_audio_saliency

# Environment-driven settings. Boolean flags are enabled only by the literal
# string "true" (case-insensitive).
ENV = os.getenv("ENV", "development")
FRONTEND_URL = os.getenv(
    "FRONTEND_URL",
    os.getenv("REACT_APP_VERCEL_URL", "http://localhost:3000")
)
CORS_ORIGINS = os.getenv("CORS_ORIGINS", "")
USE_GPU = os.getenv("USE_GPU", "true").lower() == "true"
PRELOAD_MODELS = os.getenv("PRELOAD_MODELS", "false").lower() == "true"
ENABLE_FACE_ROTATION = os.getenv("ENABLE_FACE_ROTATION", "false").lower() == "true"
MAX_FACE_ROTATION_DEGREES = float(os.getenv("MAX_FACE_ROTATION_DEGREES", "8"))
HAAR_MIN_NEIGHBORS = int(os.getenv("HAAR_MIN_NEIGHBORS", "5"))
HAAR_MIN_SIZE = int(os.getenv("HAAR_MIN_SIZE", "40"))

app = FastAPI(title="Multi-Modal Emotion Recognition API", version="2.0.0")

# Configure CORS based on environment: production uses the comma-separated
# CORS_ORIGINS list (falling back to FRONTEND_URL); anything else allows all origins.
if ENV == "production":
    if CORS_ORIGINS.strip():
        allowed_origins = [origin.strip() for origin in CORS_ORIGINS.split(",") if origin.strip()]
    else:
        allowed_origins = [FRONTEND_URL]
else:
    allowed_origins = ["*"]

# Credentialed requests must not be combined with a wildcard origin: that would
# let any website make authenticated cross-origin calls. Credentials are only
# enabled when the origin list is explicit.
allow_credentials = "*" not in allowed_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)

logger.info(f"CORS enabled for: {allowed_origins}")
logger.info(
    "Face detection config: rotation=%s max_rotation=%.1f haar_min_neighbors=%d haar_min_size=%d",
    ENABLE_FACE_ROTATION,
    MAX_FACE_ROTATION_DEGREES,
    HAAR_MIN_NEIGHBORS,
    HAAR_MIN_SIZE,
)

# Runtime configuration
# NOTE: the two label lists name some emotions differently
# ('fear'/'fearful', 'surprise'/'surprised').
EMOTIONS_FACIAL = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
EMOTIONS_SPEECH = ['angry', 'calm', 'disgust', 'fearful', 'happy', 'neutral', 'sad', 'surprised']
DEVICE = torch.device('cuda' if (torch.cuda.is_available() and USE_GPU) else 'cpu')
MAX_SPEECH_INFER_SECONDS = int(os.getenv('MAX_SPEECH_INFER_SECONDS', '15'))
MAX_SPEECH_XAI_SECONDS = int(os.getenv('MAX_SPEECH_XAI_SECONDS', '8'))
CONCORDANCE_SCORE_MAP = {
    'MATCH': 100,
    'PARTIAL': 65,
    'MISMATCH': 30,
    'UNKNOWN': 0,
}

# In-memory model state
vit_model = None
facial_processor = None
speech_model = None
speech_processor = None
facial_loaded = False
speech_loaded = False

_facial_model_lock = Lock()
_speech_model_lock = Lock()

# Paths — download from HuggingFace Hub
# NOTE: these calls run at import time, regardless of PRELOAD_MODELS.
logger.info("Resolving model paths from HuggingFace Hub...")
FACIAL_MODEL_PATH = hf_hub_download(
    repo_id="Nishvaraj/emotion-models",
    filename="vit_emotion_model.pt"
)
SPEECH_MODEL_PATH = hf_hub_download(
    repo_id="Nishvaraj/emotion-models",
    filename="hubert_emotion_model.pt"
)
logger.info(f"Facial model path: {FACIAL_MODEL_PATH}")
logger.info(f"Speech model path: {SPEECH_MODEL_PATH}")
117
+
118
+
119
+ def _upload_suffix(filename: str, default_suffix: str) -> str:
120
+ # Preserve the original extension when the browser provides one, otherwise fall back to a safe default.
121
+ suffix = Path(filename or '').suffix.lower()
122
+ return suffix if suffix else default_suffix
123
+
124
+
125
+ def _calculate_concordance(facial_emotion, speech_emotion, facial_confidence, speech_confidence):
126
+ # Match/partial/mismatch is derived from whether both models agree and how confident they are.
127
+ if facial_emotion == speech_emotion:
128
+ # When the modalities agree, the average confidence controls the concordance band.
129
+ score = (facial_confidence + speech_confidence) / 2
130
+ if score > 0.7:
131
+ concordance = "MATCH"
132
+ elif score >= 0.4:
133
+ concordance = "PARTIAL"
134
+ else:
135
+ concordance = "MISMATCH"
136
+ else:
137
+ # Different emotions can never be a full match, so we score by how close the confidences are.
138
+ score = 1 - abs(facial_confidence - speech_confidence)
139
+ if score >= 0.5:
140
+ concordance = "PARTIAL"
141
+ else:
142
+ concordance = "MISMATCH"
143
+
144
+ concordance_score = round(score * 100)
145
+ return concordance, concordance_score
146
+
147
+
148
# Frontal-face Haar cascade loaded from the data directory bundled with OpenCV.
FACE_CASCADE = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

# The MTCNN detector exists only when the optional MTCNN class is available.
if MTCNN is not None:
    MTCNN_DETECTOR = MTCNN(keep_all=False, device=DEVICE)
else:
    MTCNN_DETECTOR = None
152
+
153
+
154
def _encode_image_base64(image_array: np.ndarray) -> str:
    """Encode an image array as a base64 string of its PNG representation.

    The array is cast to ``uint8`` before being handed to PIL.
    """
    png_buffer = BytesIO()
    Image.fromarray(image_array.astype(np.uint8)).save(png_buffer, format='PNG')
    png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
159
+
160
+
161
def _detect_primary_face(image: Image.Image):
    """Locate the main face in ``image``.

    MTCNN is tried first when it is available; the Haar cascade is used when
    MTCNN is missing, finds nothing, or raises.

    Returns:
        ``((x, y, w, h), landmarks)``. ``landmarks`` comes from MTCNN and is
        ``None`` on the Haar path. ``(None, None)`` when no face is found.
    """
    detector = MTCNN_DETECTOR
    if detector is not None:
        try:
            boxes, probs, points = detector.detect(image, landmarks=True)
            if boxes is not None and len(boxes) > 0:
                # With several faces, keep the detection with the highest probability.
                best_idx = 0 if probs is None else int(np.argmax(probs))
                left, top, right, bottom = boxes[best_idx]
                # Corner format [x1, y1, x2, y2] -> origin/size format [x, y, w, h].
                box = (int(left), int(top), int(right - left), int(bottom - top))
                landmarks = None if points is None else points[best_idx]
                return box, landmarks
        except Exception as e:
            logger.debug(f"MTCNN face detection fallback: {e}")

    # Haar cascade fallback keeps detection working without facenet-pytorch.
    gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
    candidates = FACE_CASCADE.detectMultiScale(
        gray,
        scaleFactor=1.1,
        minNeighbors=HAAR_MIN_NEIGHBORS,
        minSize=(HAAR_MIN_SIZE, HAAR_MIN_SIZE)
    )

    if candidates is None or len(candidates) == 0:
        return None, None

    # Among Haar candidates, the largest area wins.
    largest = max(candidates, key=lambda rect: rect[2] * rect[3])
    return tuple(int(v) for v in largest), None
190
+
191
+
192
def _rotate_image_to_level(image: Image.Image, points) -> Image.Image:
    """Rotate ``image`` so the line between the eyes becomes horizontal.

    Args:
        image: Source image (the black fill colour is given as an RGB triple).
        points: Facial landmarks, or ``None``. Assumes ``points[0]`` and
            ``points[1]`` are the eye positions as ``(x, y)``, with the eye
            that appears on the left of the image first — TODO confirm
            against the detector's landmark order.

    Returns:
        The very same ``image`` object when rotation is disabled, no landmarks
        are given, the tilt is under 1 degree, the tilt exceeds
        ``MAX_FACE_ROTATION_DEGREES``, or rotation fails; otherwise a new,
        expanded, rotated image. Callers rely on the identity check.
    """
    if not ENABLE_FACE_ROTATION or points is None:
        return image

    try:
        # Estimate head tilt from the eye landmarks and keep the correction bounded.
        left_eye, right_eye = points[0], points[1]
        angle = float(np.degrees(np.arctan2(right_eye[1] - left_eye[1], right_eye[0] - left_eye[0])))
        if abs(angle) < 1.0:
            return image
        if abs(angle) > MAX_FACE_ROTATION_DEGREES:
            logger.debug("Skipping face rotation due to large angle: %.2f", angle)
            return image

        center = (image.width / 2, image.height / 2)
        # Image y grows downward, so a positive angle means the right-hand eye
        # sits lower (a clockwise tilt on screen). PIL rotates counter-clockwise
        # for positive angles, so rotating by +angle levels the eyes; rotating
        # by -angle would double the tilt.
        return image.rotate(
            angle,
            resample=Image.Resampling.BICUBIC,
            expand=True,
            center=center,
            fillcolor=(0, 0, 0),
        )
    except Exception as e:
        # Rotation is best-effort: fall back to the unrotated image, but leave a trace.
        logger.debug("Face rotation skipped after error: %s", e)
        return image
213
+
214
+
215
+ def _crop_face_with_margin(image_array: np.ndarray, face_box, margin_ratio: float = 0.12):
216
+ # Expand the detected face slightly so the classifier keeps some surrounding context.
217
+ x, y, w, h = [int(v) for v in face_box]
218
+ h_img, w_img = image_array.shape[:2]
219
+ mx = int(w * margin_ratio)
220
+ my = int(h * margin_ratio)
221
+
222
+ x1 = max(0, x - mx)
223
+ y1 = max(0, y - my)
224
+ x2 = min(w_img, x + w + mx)
225
+ y2 = min(h_img, y + h + my)
226
+
227
+ return image_array[y1:y2, x1:x2], (x1, y1, x2 - x1, y2 - y1)
228
+
229
+
230
+ def _shrink_box(face_box, shrink_ratio: float = 0.12):
231
+ # Draw a tighter outline for annotation so the face box looks cleaner on the preview image.
232
+ x, y, w, h = [int(v) for v in face_box]
233
+ dx = int(w * shrink_ratio / 2)
234
+ dy = int(h * shrink_ratio / 2)
235
+ x1 = x + dx
236
+ y1 = y + dy
237
+ width = max(1, w - (dx * 2))
238
+ height = max(1, h - (dy * 2))
239
+ return x1, y1, width, height
240
+
241
+
242
+ def _trim_audio_window(audio: np.ndarray, sr: int, max_seconds: int) -> np.ndarray:
243
+ # Long recordings are centered and clipped so inference stays fast and consistent.
244
+ if audio is None or sr <= 0:
245
+ return audio
246
+ max_len = int(sr * max_seconds)
247
+ if max_len <= 0 or len(audio) <= max_len:
248
+ return audio
249
+ start = (len(audio) - max_len) // 2
250
+ end = start + max_len
251
+ return audio[start:end]
252
+
253
+
254
# Record the selected device and environment once at import time.
logger.info("Device: %s", DEVICE)
logger.info("Environment: %s", ENV)
256
+
257
+ # ========== MODEL LOADING ==========
258
+
259
def load_facial_model():
    """Load the ViT facial-emotion model and its image processor into module state.

    Safe to call repeatedly: once both globals are set, it returns immediately.
    Loading is serialized by ``_facial_model_lock``.

    Returns:
        ``True`` when the model is ready, ``False`` when loading failed. On
        failure the module globals are left untouched so a later call retries.
    """
    global vit_model, facial_processor, facial_loaded
    if vit_model is not None and facial_processor is not None:
        facial_loaded = True
        return True

    with _facial_model_lock:
        # Re-check under the lock: another thread may have finished loading while we waited.
        if vit_model is not None and facial_processor is not None:
            facial_loaded = True
            return True

        try:
            logger.info("Loading Facial Emotion Model (ViT)...")
            # Build everything in locals first. Assigning the globals before the
            # checkpoint is applied would, on failure, leave a model with an
            # untrained classification head that later calls treat as ready.
            processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
            # Keep the pretrained ViT backbone but swap in the emotion-class head size.
            model = AutoModelForImageClassification.from_pretrained(
                'google/vit-base-patch16-224-in21k',
                num_labels=len(EMOTIONS_FACIAL),
                ignore_mismatched_sizes=True,
                attn_implementation='eager'
            )

            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from a trusted repository.
            checkpoint = torch.load(FACIAL_MODEL_PATH, map_location=DEVICE)
            # Accept either a full training checkpoint or a plain state_dict.
            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict)
            logger.info("✓ Loaded ViT checkpoint")

            model = model.to(DEVICE)
            model.eval()
        except Exception as e:
            facial_loaded = False
            logger.error(f"❌ Error loading facial model: {e}", exc_info=True)
            return False

        # Publish only the fully initialized objects.
        facial_processor = processor
        vit_model = model
        facial_loaded = True
        logger.info("✓ Facial model ready")
        return True
299
+
300
+
301
def load_speech_model():
    """Load the HuBERT speech-emotion model and its feature extractor into module state.

    Safe to call repeatedly: once both globals are set, it returns immediately.
    Loading is serialized by ``_speech_model_lock``.

    Returns:
        ``True`` when the model is ready, ``False`` when loading failed. On
        failure the module globals are left untouched so a later call retries.
    """
    global speech_model, speech_processor, speech_loaded
    if speech_model is not None and speech_processor is not None:
        speech_loaded = True
        return True

    with _speech_model_lock:
        # Re-check under the lock: another thread may have finished loading while we waited.
        if speech_model is not None and speech_processor is not None:
            speech_loaded = True
            return True

        try:
            logger.info("Loading Speech Emotion Model (HuBERT)...")
            # Build everything in locals first. Assigning the globals before the
            # checkpoint is applied would, on failure, leave a model with an
            # untrained classification head that later calls treat as ready.
            processor = AutoFeatureExtractor.from_pretrained('facebook/hubert-large-ls960-ft')
            # Match the pretrained audio backbone to the project-specific emotion label set.
            model = AutoModelForAudioClassification.from_pretrained(
                'facebook/hubert-large-ls960-ft',
                num_labels=len(EMOTIONS_SPEECH),
                ignore_mismatched_sizes=True
            )

            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from a trusted repository.
            checkpoint = torch.load(SPEECH_MODEL_PATH, map_location=DEVICE)
            # Accept either a full training checkpoint or a plain state_dict.
            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict)
            logger.info("✓ Loaded HuBERT checkpoint")

            model = model.to(DEVICE)
            model.eval()
        except Exception as e:
            speech_loaded = False
            logger.error(f"❌ Error loading speech model: {e}", exc_info=True)
            return False

        # Publish only the fully initialized objects.
        speech_processor = processor
        speech_model = model
        speech_loaded = True
        logger.info("✓ Speech model ready")
        return True
340
+
341
+
342
def ensure_facial_model_loaded() -> bool:
    """Return True when the facial model is usable, loading it on first use."""
    already_loaded = vit_model is not None and facial_processor is not None
    return True if already_loaded else load_facial_model()
346
+
347
+
348
def ensure_speech_model_loaded() -> bool:
    """Return True when the speech model is usable, loading it on first use."""
    already_loaded = speech_model is not None and speech_processor is not None
    return True if already_loaded else load_speech_model()
352
+
353
+
354
+ # Optional eager loading for environments that prefer warm startup.
355
if PRELOAD_MODELS:
    # Facial model first, then speech; each loader reports its own success flag.
    facial_loaded, speech_loaded = load_facial_model(), load_speech_model()
358
+
359
+ # ========== VIDEO PROCESSOR ==========
360
+
361
class VideoProcessor:
    """Helpers for splitting a video file into image frames and an audio track."""

    @staticmethod
    def extract_frames_and_audio(video_path: str, fps_sample: int = 5):
        """Extract sampled frames and the audio track from a video file.

        Args:
            video_path: Path to the video file on disk.
            fps_sample: Keep every ``fps_sample``-th frame (must be >= 1).

        Returns:
            ``(frames, audio, sr, fps)``: a list of RGB ``PIL.Image`` frames,
            the mono audio loaded at 16 kHz, its sample rate, and the video
            frame rate (30.0 when the container reports a value outside
            ``(0, 120]``).

        Raises:
            ValueError: If ``fps_sample`` is below 1 or the video cannot be opened.
        """
        if fps_sample < 1:
            # Fail with a clear message instead of a ZeroDivisionError in the loop.
            raise ValueError(f"fps_sample must be >= 1, got {fps_sample}")

        frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise ValueError(f"Cannot open video: {video_path}")

            fps = cap.get(cv2.CAP_PROP_FPS)
            if fps <= 0 or fps > 120:
                fps = 30.0

            frame_count = 0
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % fps_sample == 0:
                    # Sample every Nth frame so we analyze representative facial
                    # expressions without processing the full video.
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame_rgb))

                frame_count += 1
        finally:
            # Always free the capture handle, including when decoding raises.
            cap.release()

        # librosa reads the audio track directly from the same file, giving us a
        # single mono stream for speech inference.
        audio, sr = librosa.load(video_path, sr=16000, mono=True)

        return frames, audio, sr, fps
395
+
396
+ # ========== PREDICTION FUNCTIONS ==========
397
+
398
def _locate_face_with_alignment(image: Image.Image):
    """Detect the primary face, re-detecting on a tilt-corrected copy when one is produced.

    Returns ``(image, face_box)``; ``image`` is the rotated copy only when a
    face is still found on it, and ``face_box`` is ``None`` when no face is found.
    """
    face_box, face_points = _detect_primary_face(image)

    # If we have eye landmarks, try a small rotation pass to correct head tilt.
    rotated_image = _rotate_image_to_level(image, face_points)
    if rotated_image is not image:
        rotated_box, _ = _detect_primary_face(rotated_image)
        if rotated_box is not None:
            return rotated_image, rotated_box
    return image, face_box


def _draw_face_annotation(image_array: np.ndarray, face_box) -> np.ndarray:
    """Return a copy of ``image_array`` with the face outline and label drawn on it."""
    annotated = image_array.copy()
    if face_box is None:
        return annotated

    x, y, w, h = _shrink_box(face_box, shrink_ratio=0.08)
    cv2.rectangle(annotated, (x, y), (x + w, y + h), (255, 128, 0), 2)
    cv2.putText(
        annotated,
        'Face detected',
        (x, max(20, y - 8)),  # keep the label inside the image for faces near the top
        cv2.FONT_HERSHEY_SIMPLEX,
        0.6,
        (255, 128, 0),
        2,
        cv2.LINE_AA
    )
    return annotated


def predict_facial_emotion(image: Image.Image, generate_explainability: bool = False):
    """Predict the facial emotion shown in ``image``.

    Args:
        image: Input picture; EXIF orientation is applied and it is converted to RGB.
        generate_explainability: When True, also attempt a Grad-CAM heatmap.

    Returns:
        A dict with ``emotion``, ``confidence``, ``probabilities``,
        ``face_detected`` and ``annotated_image`` (base64 PNG), plus
        ``face_box`` when a face was found and, on request,
        ``explainability_status`` / ``original_image`` / ``grad_cam``.
        ``None`` when the model cannot be loaded or prediction fails.
    """
    try:
        if not ensure_facial_model_loaded():
            return None

        # Normalize EXIF orientation first so mobile uploads and camera captures behave consistently.
        image = ImageOps.exif_transpose(image).convert('RGB')

        image, face_box = _locate_face_with_alignment(image)
        input_array = np.array(image)

        # Crop to the detected face when possible so the classifier sees the most
        # relevant region; otherwise classify the whole image.
        model_image = image
        if face_box is not None:
            face_crop, _ = _crop_face_with_margin(input_array, face_box)
            if face_crop.size > 0:
                model_image = Image.fromarray(face_crop)

        # Draw the face box on the preview image to make the detection step visible to the user.
        annotated = _draw_face_annotation(input_array, face_box)

        inputs = facial_processor(model_image, return_tensors='pt').to(DEVICE)
        with torch.no_grad():
            outputs = vit_model(**inputs)
            # Convert raw logits into probabilities for easier interpretation in the UI.
            probs = torch.softmax(outputs.logits[0].cpu(), dim=0).numpy()

        # Plain int so the index is a builtin wherever it is passed on.
        top_idx = int(np.argmax(probs))
        result = {
            "emotion": EMOTIONS_FACIAL[top_idx],
            "confidence": float(probs[top_idx]),
            "probabilities": {e: float(p) for e, p in zip(EMOTIONS_FACIAL, probs)},
            "face_detected": face_box is not None,
            "annotated_image": _encode_image_base64(annotated)
        }

        if face_box is not None:
            x, y, w, h = [int(v) for v in face_box]
            result["face_box"] = {"x": x, "y": y, "width": w, "height": h}

        if generate_explainability:
            # Explainability is optional because Grad-CAM adds compute cost.
            status = {
                "requested": True,
                "generated": False,
                "error": None
            }
            result["explainability_status"] = status
            try:
                original_base64, heatmap_base64 = generate_grad_cam(
                    model_image,
                    vit_model,
                    facial_processor,
                    top_idx,
                    EMOTIONS_FACIAL,
                    DEVICE
                )
                if original_base64:
                    result["original_image"] = original_base64
                if heatmap_base64:
                    result["grad_cam"] = heatmap_base64
                    status["generated"] = True
                else:
                    status["error"] = "Grad-CAM map returned empty output"
            except Exception as e:
                # A failed heatmap must not discard the prediction itself.
                logger.warning(f"Could not generate Grad-CAM: {e}")
                status["error"] = str(e)

        return result
    except Exception as e:
        logger.error(f"Error predicting facial emotion: {e}", exc_info=True)
        return None
499
+
500
def predict_speech_emotion(audio: np.ndarray, sr: int = 16000, generate_explainability: bool = False):
    """Predict the emotion expressed in an audio signal.

    Args:
        audio: Audio samples.
        sr: Sample rate of ``audio``; anything other than 16000 is resampled.
        generate_explainability: When True, also attempt a saliency map.

    Returns:
        A dict with ``emotion``, ``confidence`` and ``probabilities`` plus, on
        request, ``explainability_status`` / ``waveform`` / ``saliency``.
        ``None`` when the model cannot be loaded or prediction fails.
    """
    try:
        if not ensure_speech_model_loaded():
            return None

        if sr != 16000:
            # Resample every input to the model's expected sampling rate.
            audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

        # Keep inference fast and stable for long recordings.
        audio_for_infer = _trim_audio_window(audio, 16000, MAX_SPEECH_INFER_SECONDS)

        inputs = speech_processor(audio_for_infer, sampling_rate=16000, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = speech_model(inputs['input_values'].to(DEVICE))
            logits = outputs.logits.cpu().numpy()[0]

        # Softmax with the maximum subtracted first: mathematically identical,
        # but np.exp cannot overflow to inf and produce NaN probabilities.
        exp_logits = np.exp(logits - np.max(logits))
        probs = exp_logits / np.sum(exp_logits)

        # Plain int so the index is a builtin wherever it is passed on.
        top_idx = int(np.argmax(probs))
        result = {
            "emotion": EMOTIONS_SPEECH[top_idx],
            "confidence": float(probs[top_idx]),
            "probabilities": {e: float(p) for e, p in zip(EMOTIONS_SPEECH, probs)}
        }

        if generate_explainability:
            status = {
                "requested": True,
                "generated": False,
                "error": None
            }
            result["explainability_status"] = status
            try:
                # Saliency on a shorter centered chunk avoids multi-minute stalls.
                audio_for_xai = _trim_audio_window(audio_for_infer, 16000, MAX_SPEECH_XAI_SECONDS)
                spec_base64, saliency_base64 = generate_audio_saliency(
                    audio_for_xai,
                    speech_model,
                    speech_processor,
                    top_idx,
                    EMOTIONS_SPEECH,
                    DEVICE,
                    sr=16000
                )
                if spec_base64:
                    result["waveform"] = spec_base64
                if saliency_base64:
                    result["saliency"] = saliency_base64
                    status["generated"] = True
                else:
                    status["error"] = "Audio saliency map returned empty output"
            except Exception as e:
                # A failed saliency map must not discard the prediction itself.
                logger.warning(f"Could not generate audio saliency: {e}")
                status["error"] = str(e)

        return result
    except Exception as e:
        logger.error(f"Error predicting speech emotion: {e}", exc_info=True)
        return None
561
+
562
+ # ========== API ENDPOINTS ==========
563
+
564
@app.get("/")
async def root():
    """Return a short banner showing that the API is up."""
    banner = {"message": "Multi-Modal Emotion Recognition API v2.0", "status": "active"}
    return banner
567
+
568
@app.get("/health")
async def health():
    """Report service status, which models are currently in memory, and the device.

    ``lazy_loading`` is the negation of ``PRELOAD_MODELS``.
    """
    facial_missing = vit_model is None or facial_processor is None
    speech_missing = speech_model is None or speech_processor is None
    return {
        "status": "healthy",
        "facial_model": not facial_missing,
        "speech_model": not speech_missing,
        "lazy_loading": not PRELOAD_MODELS,
        "device": str(DEVICE)
    }
579
+
580
@app.post("/api/predict/facial")
async def predict_facial(file: UploadFile = File(...), explain: bool = False):
    """Predict the facial emotion in an uploaded image.

    Responds with ``{"success": True, ...prediction}`` on success,
    ``{"success": False, "error": ...}`` when prediction yields nothing, and
    HTTP 400 for an empty upload or any raised exception.
    """
    try:
        logger.info(f"Received file: {file.filename}, content_type: {file.content_type}")
        contents = await file.read()
        logger.info(f"File size: {len(contents)} bytes")
        if not contents:
            return JSONResponse(status_code=400, content={"error": "Empty file received"})

        decoded = Image.open(BytesIO(contents))
        image = ImageOps.exif_transpose(decoded).convert('RGB')
        result = predict_facial_emotion(image, generate_explainability=explain)
        if not result:
            return {"success": False, "error": "Prediction failed"}
        return {"success": True, **result}
    except Exception as e:
        logger.error(f"Error in predict_facial: {e}", exc_info=True)
        return JSONResponse(status_code=400, content={"error": str(e)})
595
+
596
@app.post("/api/predict/speech")
async def predict_speech(file: UploadFile = File(...), explain: bool = False):
    """Predict the emotion in an uploaded audio file.

    Responds with ``{"success": True, ...prediction}`` on success,
    ``{"success": False, "error": ...}`` when prediction yields nothing, and
    HTTP 400 for an empty upload or any raised exception.
    """
    try:
        contents = await file.read()
        # Reject empty uploads up front, matching the facial endpoint, rather
        # than surfacing an opaque decoder error.
        if not contents:
            return JSONResponse(status_code=400, content={"error": "Empty file received"})

        # The upload is written to a temporary file, keeping its extension, and
        # loaded from that path.
        suffix = _upload_suffix(file.filename, '.wav')
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(contents)
            tmp_path = tmp.name

        try:
            audio, sr = librosa.load(tmp_path, sr=16000)
            result = predict_speech_emotion(audio, sr, generate_explainability=explain)
            return {"success": True, **result} if result else {"success": False, "error": "Prediction failed"}
        finally:
            # delete=False above means the file must be removed explicitly.
            os.unlink(tmp_path)
    except Exception as e:
        logger.error(f"Error in predict_speech: {e}", exc_info=True)
        return JSONResponse(status_code=400, content={"error": str(e)})
614
+
615
@app.post("/api/predict/combined")
async def predict_combined(image_file: UploadFile = File(...), audio_file: UploadFile = File(...), explain: bool = False):
    """Predict emotions from both an image and an audio clip, then compare them.

    The response carries the per-modality predictions, a combined label taken
    from the more confident modality, and a concordance rating. When ``explain``
    is set, an ``explainability_status`` section (and ``explainability`` payload
    when anything was generated) is added. An empty upload or any raised
    exception yields HTTP 400.
    """
    try:
        image_contents = await image_file.read()
        audio_contents = await audio_file.read()
        # Reject empty uploads up front, matching the facial endpoint.
        if not image_contents or not audio_contents:
            return JSONResponse(status_code=400, content={"error": "Empty file received"})

        image = ImageOps.exif_transpose(Image.open(BytesIO(image_contents))).convert('RGB')
        facial_result = predict_facial_emotion(image, generate_explainability=explain)

        # The audio upload is written to a temporary file, keeping its
        # extension, and loaded from that path.
        audio_suffix = _upload_suffix(audio_file.filename, '.wav')
        with tempfile.NamedTemporaryFile(suffix=audio_suffix, delete=False) as tmp:
            tmp.write(audio_contents)
            tmp_path = tmp.name

        try:
            audio, sr = librosa.load(tmp_path, sr=16000)
            speech_result = predict_speech_emotion(audio, sr, generate_explainability=explain)
        finally:
            # delete=False above means the file must be removed explicitly.
            os.unlink(tmp_path)

        facial_emotion = facial_result["emotion"] if facial_result else None
        facial_confidence = facial_result["confidence"] if facial_result else 0.0

        speech_emotion = speech_result["emotion"] if speech_result else None
        speech_confidence = speech_result["confidence"] if speech_result else 0.0

        if facial_result and speech_result:
            concordance, concordance_score = _calculate_concordance(
                facial_emotion,
                speech_emotion,
                facial_confidence,
                speech_confidence,
            )
        else:
            # Agreement cannot be judged when a modality failed; scoring a None
            # label against a real one would report a spurious PARTIAL/MISMATCH.
            concordance = "UNKNOWN"
            concordance_score = CONCORDANCE_SCORE_MAP["UNKNOWN"]

        # The combined label should prefer the more confident modality when both are present.
        combined_emotion = None
        combined_confidence = 0.0

        if facial_emotion and speech_emotion:
            if facial_confidence > speech_confidence:
                combined_emotion = facial_emotion
                combined_confidence = facial_confidence
            else:
                combined_emotion = speech_emotion
                combined_confidence = speech_confidence
        elif facial_emotion:
            combined_emotion = facial_emotion
            combined_confidence = facial_confidence
        elif speech_emotion:
            combined_emotion = speech_emotion
            combined_confidence = speech_confidence

        response = {
            "success": True,
            "facial_emotion": {
                "emotion": facial_emotion or "unknown",
                "confidence": float(facial_confidence),
                "probabilities": facial_result["probabilities"] if facial_result else {},
                "face_detected": facial_result.get("face_detected", False) if facial_result else False,
                "face_box": facial_result.get("face_box") if facial_result else None,
                "annotated_image": facial_result.get("annotated_image") if facial_result else None
            },
            "speech_emotion": {
                "emotion": speech_emotion or "unknown",
                "confidence": float(speech_confidence),
                "probabilities": speech_result["probabilities"] if speech_result else {}
            },
            "combined_emotion": combined_emotion or "unknown",
            "combined_confidence": float(combined_confidence),
            "concordance": concordance,
            "concordance_score": concordance_score,
            "analysis": {
                "match": concordance == "MATCH",
                "agreement_details": f"Face: {facial_emotion} (conf: {facial_confidence:.2f}) | Voice: {speech_emotion} (conf: {speech_confidence:.2f})"
            }
        }

        if explain:
            # Keep the response shape stable even when one modality fails to generate XAI output.
            explainability = {}
            errors = []

            facial_status = (facial_result or {}).get("explainability_status") or {
                "requested": True,
                "generated": False,
                "error": "Facial explainability unavailable"
            }
            speech_status = (speech_result or {}).get("explainability_status") or {
                "requested": True,
                "generated": False,
                "error": "Speech explainability unavailable"
            }

            if facial_result and facial_result.get("grad_cam"):
                explainability["grad_cam"] = facial_result.get("grad_cam")
            elif facial_status.get("error"):
                errors.append(f"Facial: {facial_status.get('error')}")

            if speech_result and speech_result.get("saliency"):
                explainability["saliency"] = speech_result.get("saliency")
            elif speech_status.get("error"):
                errors.append(f"Speech: {speech_status.get('error')}")

            if speech_result and speech_result.get("waveform"):
                explainability["waveform"] = speech_result.get("waveform")

            response["explainability_status"] = {
                "requested": True,
                "generated": bool(explainability),
                "facial": facial_status,
                "speech": speech_status,
                "errors": errors
            }

            if explainability:
                response["explainability"] = explainability

        return response
    except Exception as e:
        logger.error(f"Error in predict_combined: {e}", exc_info=True)
        return JSONResponse(status_code=400, content={"error": str(e)})
734
+
735
@app.post("/api/predict/video")
async def predict_video_emotion(file: UploadFile = File(...), explain: bool = False):
    """Predict facial and speech emotions for an uploaded video.

    The upload is written to a temporary file, frames and audio are extracted
    with ``VideoProcessor``, up to ten sampled frames are classified and
    aggregated (majority emotion, mean confidence, mean per-class
    probability), and the audio track is classified once.

    Args:
        file: Uploaded video. Its extension is kept for the temporary file,
            falling back to ``.mp4``.
        explain: When true, also attach a Grad-CAM overlay for the most
            confident frame and a spectrogram/saliency pair for the audio,
            together with an ``explainability_status`` report.

    Returns:
        The response dict on success, or a 400 ``JSONResponse`` whose body is
        ``{"error": <message>}`` when any step raises.
    """
    tmp_path = None
    try:
        try:
            video_suffix = _upload_suffix(file.filename, '.mp4')
            with tempfile.NamedTemporaryFile(suffix=video_suffix, delete=False) as tmp:
                # Record the path before writing so the finally-clause removes
                # the file even when reading or writing the upload fails.
                tmp_path = tmp.name
                tmp.write(await file.read())

            processor = VideoProcessor()
            frames, audio, sr, fps = processor.extract_frames_and_audio(tmp_path, fps_sample=5)

            # Keep every frame next to its prediction so the explainability
            # step can reuse the results instead of re-running the model.
            frame_predictions = []
            for frame in frames[:10]:
                result = predict_facial_emotion(frame)
                if result:
                    frame_predictions.append((frame, result))
            facial_results = [prediction for _, prediction in frame_predictions]

            if facial_results:
                facial_emotions = [r["emotion"] for r in facial_results]
                facial_confidence = float(np.mean([r["confidence"] for r in facial_results]))
                # dict.fromkeys keeps first-seen order, so ties in the majority
                # vote resolve to the earliest emotion instead of depending on
                # set iteration order.
                facial_emotion = max(dict.fromkeys(facial_emotions), key=facial_emotions.count)
                facial_probs = {
                    emotion: float(np.mean([r["probabilities"].get(emotion, 0) for r in facial_results]))
                    for emotion in EMOTIONS_FACIAL
                }
            else:
                facial_emotion = "unknown"
                facial_confidence = 0.0
                facial_probs = {e: 0.0 for e in EMOTIONS_FACIAL}

            speech_result = predict_speech_emotion(audio, sr)
            speech_emotion = speech_result["emotion"] if speech_result else "unknown"
            speech_confidence = float(speech_result["confidence"]) if speech_result else 0.0
            concordance, concordance_score = _calculate_concordance(
                facial_emotion,
                speech_emotion,
                facial_confidence,
                speech_confidence,
            )

            response = {
                "success": True,
                "facial_emotion": {
                    "emotion": facial_emotion,
                    "confidence": facial_confidence,
                    "frames_analyzed": len(facial_results),
                    "probabilities": facial_probs
                },
                "speech_emotion": {
                    "emotion": speech_emotion,
                    "confidence": speech_confidence,
                    "probabilities": speech_result["probabilities"] if speech_result else {e: 0.0 for e in EMOTIONS_SPEECH}
                },
                # The face wins only when it is reasonably confident;
                # otherwise fall back to the speech label.
                "combined_emotion": facial_emotion if facial_confidence > 0.5 else speech_emotion,
                "concordance": concordance,
                "concordance_score": concordance_score,
                "video_duration": float(len(audio) / sr),
                "frames_processed": len(frames),
                "fps": float(fps)
            }

            if explain:
                explainability = {}
                errors = []

                facial_exp_status = {"requested": True, "generated": False, "error": None}
                speech_exp_status = {"requested": True, "generated": False, "error": None}

                if frames and facial_emotion != "unknown":
                    try:
                        # Explain the frame that predicted the aggregated
                        # emotion with the highest confidence.
                        best_frame = None
                        best_conf = 0
                        for frame, prediction in frame_predictions:
                            if (
                                prediction.get("emotion") == facial_emotion
                                and prediction.get("confidence", 0) > best_conf
                            ):
                                best_conf = prediction["confidence"]
                                best_frame = frame

                        # If no frame qualified, fall back to the first frame.
                        if best_frame is None and frames:
                            best_frame = frames[0]

                        if best_frame is not None:
                            top_idx = EMOTIONS_FACIAL.index(facial_emotion) \
                                if facial_emotion in EMOTIONS_FACIAL else 0
                            # Crop to the detected face before running Grad-CAM.
                            face_box, _ = _detect_primary_face(best_frame)
                            if face_box is not None:
                                frame_array = np.array(best_frame)
                                face_crop_array, _ = _crop_face_with_margin(frame_array, face_box)
                                gradcam_input = Image.fromarray(face_crop_array) if face_crop_array.size > 0 else best_frame
                            else:
                                gradcam_input = best_frame
                            orig_b64, heatmap_b64 = generate_grad_cam(
                                gradcam_input, vit_model, facial_processor,
                                top_idx, EMOTIONS_FACIAL, DEVICE
                            )
                            if heatmap_b64:
                                explainability["grad_cam"] = heatmap_b64
                                facial_exp_status["generated"] = True
                            else:
                                facial_exp_status["error"] = "GradCAM returned empty output"
                    except Exception as e:
                        facial_exp_status["error"] = str(e)
                else:
                    facial_exp_status["error"] = "No valid frame prediction found for facial explainability"

                if speech_result and speech_emotion != "unknown":
                    try:
                        top_idx = EMOTIONS_SPEECH.index(speech_emotion) \
                            if speech_emotion in EMOTIONS_SPEECH else 0
                        audio_for_xai = _trim_audio_window(audio, sr, max_seconds=MAX_SPEECH_XAI_SECONDS)
                        # NOTE(review): the saliency helper is told sr=16000
                        # while the audio was trimmed using the extracted
                        # ``sr`` - confirm the extractor always yields 16 kHz.
                        spec_b64, saliency_b64 = generate_audio_saliency(
                            audio_for_xai,
                            speech_model,
                            speech_processor,
                            top_idx,
                            EMOTIONS_SPEECH,
                            DEVICE,
                            sr=16000
                        )
                        if spec_b64:
                            explainability["waveform"] = spec_b64
                        if saliency_b64:
                            explainability["saliency"] = saliency_b64
                            speech_exp_status["generated"] = True
                        else:
                            speech_exp_status["error"] = "Audio saliency map returned empty output"
                    except Exception as e:
                        speech_exp_status["error"] = str(e)
                else:
                    speech_exp_status["error"] = "No valid audio prediction found for explainability"

                if facial_exp_status.get("error"):
                    errors.append(f"Facial: {facial_exp_status.get('error')}")
                if speech_exp_status.get("error"):
                    errors.append(f"Speech: {speech_exp_status.get('error')}")

                response["explainability_status"] = {
                    "requested": True,
                    "generated": bool(explainability),
                    "facial": facial_exp_status,
                    "speech": speech_exp_status,
                    "errors": errors
                }

                if explainability:
                    response["explainability"] = explainability

            return response
        finally:
            if tmp_path is not None:
                os.unlink(tmp_path)
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
896
+
897
@app.get("/api/emotions/facial")
async def get_facial_emotions():
    """Return the facial emotion labels held in ``EMOTIONS_FACIAL``."""
    payload = {"emotions": EMOTIONS_FACIAL}
    return payload
900
+
901
@app.get("/api/emotions/speech")
async def get_speech_emotions():
    """Return the speech emotion labels held in ``EMOTIONS_SPEECH``."""
    payload = {"emotions": EMOTIONS_SPEECH}
    return payload
904
+
905
@app.get("/api/models/status")
async def get_models_status():
    """Report whether each model/processor pair is loaded, plus runtime settings.

    A modality counts as loaded only when both its model and its processor
    globals are set. The accuracy figures are fixed constants, not values
    measured at runtime.
    """
    facial_ready = all(part is not None for part in (vit_model, facial_processor))
    speech_ready = all(part is not None for part in (speech_model, speech_processor))

    facial_info = {
        "loaded": facial_ready,
        "accuracy": 0.7129,
        "emotions": len(EMOTIONS_FACIAL),
    }
    speech_info = {
        "loaded": speech_ready,
        "accuracy": 0.8750,
        "emotions": len(EMOTIONS_SPEECH),
    }

    status = {
        "facial": facial_info,
        "speech": speech_info,
        "lazy_loading": not PRELOAD_MODELS,
        "device": str(DEVICE),
    }
    return status
backend/backend/services/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Service-layer utilities for backend inference and explainability."""
backend/backend/services/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (263 Bytes). View file
 
backend/backend/services/__pycache__/explainability.cpython-314.pyc ADDED
Binary file (15 kB). View file
 
backend/backend/services/data_loader.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset loaders used by the training and experimentation workflows."""
2
+
3
+ import os
4
+ import numpy as np
5
+ import cv2
6
+ import librosa
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from torchvision import transforms
10
+
11
+
12
+ # ==================== FACIAL DATASET ====================
13
class FER2013Dataset(Dataset):
    """FER2013 facial emotion dataset loader.

    Expects ``root_dir/<split>/<emotion>/<image>`` with one folder per label
    in ``self.emotions``. Samples are indexed once at construction time as
    ``(image_path, label_index)`` pairs in a deterministic (sorted) order.
    """

    # Extensions are compared case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir: str, split: str = "train", transform=None):
        """
        Initialize FER2013 dataset.

        Args:
            root_dir: Root directory containing 'train' and 'test' folders
            split: 'train' or 'test'
            transform: Torchvision transforms to apply
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.emotions = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
        self.emotion2idx = {e: i for i, e in enumerate(self.emotions)}

        self.samples = []
        self._load_samples()

    def _load_samples(self):
        """Index every image under the split, skipping missing emotion folders."""
        split_dir = os.path.join(self.root_dir, self.split)

        for emotion in self.emotions:
            emotion_dir = os.path.join(split_dir, emotion)
            if not os.path.isdir(emotion_dir):
                continue

            label = self.emotion2idx[emotion]
            # os.listdir order is arbitrary; sort so indices are reproducible.
            for img_file in sorted(os.listdir(emotion_dir)):
                if img_file.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(emotion_dir, img_file), label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``."""
        img_path, label = self.samples[idx]
        label_tensor = torch.tensor(label, dtype=torch.long)

        # Load image as single-channel grayscale.
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # Unreadable file: return a fixed-size black placeholder so the
            # epoch can continue. The size does not follow the transform.
            return torch.zeros(3, 224, 224), label_tensor

        # Replicate the gray channel to get a 3-channel image.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        if self.transform:
            image = self.transform(image)
        else:
            # Default: HWC uint8 -> CHW float scaled to [0, 1], no resizing.
            image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0

        return image, label_tensor
69
+
70
+
71
+ # ==================== AUDIO DATASET ====================
72
class RAVDESSDataset(Dataset):
    """RAVDESS audio emotion dataset loader.

    Indexes the ``.wav`` files directly inside ``root_dir`` and reads the
    emotion from the third dash-separated field of each filename. Each item
    is a normalised MFCC matrix of shape ``(n_mfcc, 100)`` with its label.
    """

    def __init__(self, root_dir: str, n_mfcc: int = 13, target_sr: int = 22050):
        """
        Initialize RAVDESS dataset.

        Args:
            root_dir: Root directory containing audio files
            n_mfcc: Number of MFCCs to extract
            target_sr: Target sampling rate
        """
        self.root_dir = root_dir
        self.n_mfcc = n_mfcc
        self.target_sr = target_sr
        self.emotion_map = {
            '01': 'neutral',
            '02': 'calm',
            '03': 'happy',
            '04': 'sad',
            '05': 'angry',
            '06': 'fear',
            '07': 'disgust',
            '08': 'surprise'
        }
        # Sort the label names: iterating a bare set of strings changes order
        # between interpreter runs, which would scramble the label indices.
        self.emotion2idx = {
            name: i for i, name in enumerate(sorted(set(self.emotion_map.values())))
        }

        self.samples = []
        self._load_samples()

    def _load_samples(self):
        """Index every recognisable ``.wav`` file in ``root_dir`` in sorted order."""
        for file in sorted(os.listdir(self.root_dir)):
            if not file.lower().endswith('.wav'):
                continue

            fields = file.split('-')
            if len(fields) < 3:
                # Not a dash-coded filename; skip it rather than crash.
                continue

            emotion = self.emotion_map.get(fields[2])
            if emotion is None:
                continue

            audio_path = os.path.join(self.root_dir, file)
            self.samples.append((audio_path, self.emotion2idx[emotion]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(mfcc_tensor, label_tensor)``; zeros are returned if loading fails."""
        audio_path, label = self.samples[idx]

        try:
            y, sr = librosa.load(audio_path, sr=self.target_sr, mono=True)
            mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=self.n_mfcc)

            # Normalize MFCC (epsilon guards against a zero standard deviation).
            mfcc = (mfcc - mfcc.mean()) / (mfcc.std() + 1e-8)

            # Pad or truncate to a fixed size (100 time steps).
            if mfcc.shape[1] < 100:
                mfcc = np.pad(mfcc, ((0, 0), (0, 100 - mfcc.shape[1])), mode='constant')
            else:
                mfcc = mfcc[:, :100]

            return torch.from_numpy(mfcc).float(), torch.tensor(label, dtype=torch.long)
        except Exception as e:
            # Best effort: report the failure and keep the batch shape intact.
            print(f"Error loading {audio_path}: {e}")
            return torch.zeros(self.n_mfcc, 100), torch.tensor(label, dtype=torch.long)
135
+
136
+
137
+ # ==================== DATALOADER FACTORY ====================
138
def create_dataloaders(
    fer2013_dir: str = None,
    ravdess_dir: str = None,
    batch_size: int = 32,
    num_workers: int = 0,
    img_size: int = 224
) -> dict:
    """
    Build the DataLoaders for whichever dataset directories exist.

    Args:
        fer2013_dir: Path to FER2013 dataset root
        ravdess_dir: Path to RAVDESS dataset root
        batch_size: Batch size for training
        num_workers: Number of workers for data loading
        img_size: Image size for FER2013

    Returns:
        Dictionary that may contain 'fer2013_train', 'fer2013_test' and
        'ravdess'; a dataset whose directory is missing or unset is omitted.
    """
    # Image pipeline: array -> PIL -> square resize -> tensor -> normalise.
    image_pipeline = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    def _loader(dataset, shuffle):
        # All loaders share the batch size and worker count.
        return DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
        )

    loaders = {}

    if fer2013_dir and os.path.exists(fer2013_dir):
        fer_train = FER2013Dataset(fer2013_dir, split='train', transform=image_pipeline)
        fer_test = FER2013Dataset(fer2013_dir, split='test', transform=image_pipeline)
        # Only the training split is shuffled.
        loaders['fer2013_train'] = _loader(fer_train, True)
        loaders['fer2013_test'] = _loader(fer_test, False)

    if ravdess_dir and os.path.exists(ravdess_dir):
        loaders['ravdess'] = _loader(RAVDESSDataset(ravdess_dir), True)

    return loaders
backend/backend/services/explainability.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Explainability utilities for multimodal emotion recognition outputs."""
2
+ import os
3
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "0"
4
+ os.environ["QT_QPA_PLATFORM"] = "offscreen"
5
+ import torch
6
+ import torch.nn as nn
7
+ import numpy as np
8
+ import cv2
9
+ import librosa
10
+ import librosa.display
11
+ import matplotlib
12
+ matplotlib.use('Agg')
13
+ import matplotlib.pyplot as plt
14
+ from PIL import Image
15
+ from io import BytesIO
16
+ import base64
17
+
18
+ from pytorch_grad_cam import GradCAM, EigenCAM
19
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
20
+ from pytorch_grad_cam.utils.reshape_transforms import vit_reshape_transform
21
+
22
+
23
+ # ==================== MODEL WRAPPER ====================
24
class ViTLogitsWrapper(nn.Module):
    """Adapter that turns a ``pixel_values``-keyword model into tensor-in, logits-out.

    The CAM code calls ``forward(x)`` and expects raw logits back, whereas the
    wrapped model takes ``pixel_values=...`` and returns an object carrying a
    ``.logits`` attribute.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Unwrap the model output so callers receive only the logits tensor.
        outputs = self.model(pixel_values=x)
        return outputs.logits
32
+
33
+
34
+ # ==================== FACIAL EXPLAINABILITY ====================
35
def generate_grad_cam(image, model, processor, emotion_idx, emotions_list, device):
    """Render a class-activation overlay for one facial-emotion prediction.

    Grad-CAM is tried on the last encoder blocks (up to three, deepest first).
    The first map whose maximum exceeds 0.01 is used; if none qualifies,
    EigenCAM on the last block is the fallback. The map is resized to the
    image, smoothed, colour-mapped and blended over the image where it is
    strongest.

    Args:
        image: PIL image to explain (converted to RGB here).
        model: Classifier exposing ``model.vit.encoder.layer``.
        processor: Callable producing ``pixel_values`` tensors for the model.
        emotion_idx: Index of the class whose evidence is visualised.
        emotions_list: Label list; accepted for interface symmetry and not
            read by this function.
        device: Torch device the inputs are moved to.

    Returns:
        ``(original_png_b64, overlay_png_b64)``, or ``(None, None)`` when
        generation fails.
    """
    try:
        img_rgb = np.array(image.convert('RGB'))
        h, w = img_rgb.shape[:2]
        img_pil = Image.fromarray(img_rgb)

        inputs = processor(img_pil, return_tensors='pt').to(device)
        input_tensor = inputs['pixel_values']

        wrapped_model = ViTLogitsWrapper(model)
        wrapped_model.eval()

        # Try several layers because the last block can become too saturated
        # for a usable heatmap. Indices come from the real encoder depth, so
        # the log label is right for any model and shallow encoders still work.
        encoder_layers = model.vit.encoder.layer
        num_layers = len(encoder_layers)
        candidate_indices = range(num_layers - 1, max(num_layers - 4, -1), -1)

        cam_map = None
        method_used = None

        for offset, layer_idx in enumerate(candidate_indices, start=1):
            try:
                cam = GradCAM(
                    model=wrapped_model,
                    target_layers=[encoder_layers[layer_idx].layernorm_after],
                    reshape_transform=vit_reshape_transform,
                )
                targets = [ClassifierOutputTarget(emotion_idx)]
                grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
                result = grayscale_cam[0]

                # Reject degenerate maps so the UI never shows a blank
                # explanation as if it were valid.
                if result.max() > 0.01:
                    cam_map = result
                    method_used = f"GradCAM (encoder block {layer_idx})"
                    break
                else:
                    print(f"[explainability] layer[-{offset}] all zeros, trying next")

            except Exception as e:
                print(f"[explainability] GradCAM layer[-{offset}] failed: {e}")

        # Final fallback: EigenCAM does not need a target class.
        if cam_map is None:
            print("[explainability] All GradCAM layers zero, using EigenCAM")
            try:
                eigen = EigenCAM(
                    model=wrapped_model,
                    target_layers=[encoder_layers[-1].layernorm_after],
                    reshape_transform=vit_reshape_transform,
                )
                grayscale_cam = eigen(input_tensor=input_tensor)
                cam_map = grayscale_cam[0]
                method_used = "EigenCAM"
            except Exception as e:
                print(f"[explainability] EigenCAM failed: {e}")
                return None, None

        print(f"[explainability] {method_used} — min={cam_map.min():.3f}, max={cam_map.max():.3f}")

        # Upscale and smooth the heatmap so it overlays cleanly on the source image.
        cam_resized = cv2.resize(cam_map.astype(np.float32), (w, h), interpolation=cv2.INTER_CUBIC)
        cam_resized = cv2.GaussianBlur(cam_resized, (13, 13), 0)

        c_min, c_max = cam_resized.min(), cam_resized.max()
        if c_max > c_min:
            cam_resized = (cam_resized - c_min) / (c_max - c_min)
        # A constant map skips the rescale; clamp so the uint8 cast cannot wrap.
        cam_resized = np.clip(cam_resized, 0.0, 1.0)

        # Build the coloured overlay (OpenCV colour maps are BGR).
        cam_uint8 = np.uint8(255 * cam_resized)
        heatmap_bgr = cv2.applyColorMap(cam_uint8, cv2.COLORMAP_JET)
        heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

        # Blend only the top 30% of activations, with a feathered mask edge.
        threshold = np.percentile(cam_resized, 70)
        blend_mask = (cam_resized > threshold).astype(np.float32)
        blend_mask = cv2.GaussianBlur(blend_mask, (31, 31), 0)[..., None]

        blended = (
            (1 - blend_mask * 0.65) * img_rgb.astype(np.float32)
            + blend_mask * 0.65 * heatmap_rgb.astype(np.float32)
        ).clip(0, 255).astype(np.uint8)

        orig_buf = BytesIO()
        Image.fromarray(img_rgb).save(orig_buf, format='PNG')
        orig_b64 = base64.b64encode(orig_buf.getvalue()).decode()

        blend_buf = BytesIO()
        Image.fromarray(blended).save(blend_buf, format='PNG')
        blend_b64 = base64.b64encode(blend_buf.getvalue()).decode()

        return orig_b64, blend_b64

    except Exception as e:
        print(f"[explainability] GradCAM generation failed: {e}")
        return None, None
132
+
133
+
134
+ # ==================== AUDIO EXPLAINABILITY ====================
135
def generate_audio_saliency(audio, model, processor, emotion_idx, emotions_list, device, sr=16000):
    """Render a mel spectrogram and a gradient-saliency plot for one audio clip.

    Saliency is the absolute gradient of the selected class logit with respect
    to the model input, smoothed with an 11-sample moving average and rescaled
    to ``[0, 1]``.

    Args:
        audio: 1-D waveform samples; must be non-empty.
        model: Audio classifier returning an object with ``.logits``.
        processor: Feature extractor producing ``input_values``.
        emotion_idx: Index of the class to explain.
        emotions_list: Labels; ``emotions_list[emotion_idx]`` titles the plots.
        device: Torch device for the forward/backward pass.
        sr: Sampling rate passed to the processor and the spectrogram code.

    Returns:
        ``(spectrogram_png_b64, saliency_png_b64)``, or ``(None, None)`` when
        any step fails.
    """
    try:
        if audio is None or len(audio) == 0:
            raise ValueError("Audio input is empty")

        # Sanitize the audio before passing it into the speech backbone.
        audio = np.asarray(audio, dtype=np.float32)
        audio = np.nan_to_num(audio)

        inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
        input_values = inputs['input_values'].to(device)

        input_values.requires_grad = True
        model.zero_grad()

        try:
            outputs = model(input_values)
            score = outputs.logits[0, emotion_idx]
            score.backward()

            if input_values.grad is None:
                raise RuntimeError("No gradients captured")

            saliency = torch.abs(input_values.grad).cpu().detach().numpy()
        finally:
            # backward() also fills the parameter gradients; drop them so they
            # do not stay resident after the request.
            model.zero_grad()

        # Collapse to a single 1-D curve whatever rank the gradient has.
        if saliency.ndim == 3:
            saliency = np.mean(saliency, axis=1)[0]
        elif saliency.ndim == 2:
            saliency = saliency[0]
        saliency = saliency.reshape(-1).astype(np.float32)

        if saliency.size > 11:
            # Smooth the gradient spikes so the curve is readable in the plot.
            kernel = np.ones(11, dtype=np.float32) / 11.0
            saliency = np.convolve(saliency, kernel, mode='same')

        s_min, s_max = saliency.min(), saliency.max()
        if s_max > s_min:
            saliency = (saliency - s_min) / (s_max - s_min)
        else:
            saliency = np.zeros_like(saliency)

        # Build both the spectrogram and the saliency overlay.
        S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
        S_db = librosa.power_to_db(S, ref=np.max)

        fig1, ax1 = plt.subplots(figsize=(10, 4), dpi=100)
        try:
            img1 = librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='mel', ax=ax1)
            ax1.set_title(f'Audio Spectrogram — {emotions_list[emotion_idx]}')
            fig1.colorbar(img1, ax=ax1, format='%+2.0f dB')
            spec_buf = BytesIO()
            fig1.savefig(spec_buf, format='PNG', bbox_inches='tight', dpi=100)
            spec_b64 = base64.b64encode(spec_buf.getvalue()).decode()
        finally:
            # Close even when plotting raises, so the figure is not left open.
            plt.close(fig1)

        fig2, (ax2, ax3) = plt.subplots(2, 1, figsize=(10, 5.5), dpi=100,
                                        gridspec_kw={'height_ratios': [3, 1]}, sharex=False)
        try:
            # Normalize the spectrogram so the saliency colors stay visible
            # across different recordings.
            S_norm = (S_db - S_db.min()) / max(S_db.max() - S_db.min(), 1e-8)
            # Resample the curve to one value per spectrogram column.
            sal_resized = np.interp(np.linspace(0, 1, S_db.shape[1]),
                                    np.linspace(0, 1, saliency.shape[0]), saliency)
            sal_map = np.tile(sal_resized, (S_db.shape[0], 1))

            ax2.imshow(S_norm, aspect='auto', origin='lower', cmap='viridis', interpolation='bilinear')
            ax2.imshow(sal_map, aspect='auto', origin='lower', cmap='magma', alpha=0.6, interpolation='bilinear')
            ax2.set_title(f'Audio Saliency — {emotions_list[emotion_idx]} (bright = important)')
            ax2.set_ylabel('Mel Frequency')

            # Highlight the strongest time steps (top 15%).
            peak_thr = np.percentile(sal_resized, 85)
            x = np.arange(len(sal_resized))
            ax3.plot(x, sal_resized, color='#f97316', linewidth=1.5)
            ax3.fill_between(x, 0, sal_resized, where=sal_resized >= peak_thr, color='#ef4444', alpha=0.4)
            ax3.axhline(peak_thr, color='#ef4444', linestyle='--', linewidth=1, alpha=0.8)
            ax3.set_ylim(0, 1.05)
            ax3.set_ylabel('Saliency')
            ax3.set_xlabel('Time steps')
            ax3.grid(alpha=0.2)

            sal_buf = BytesIO()
            fig2.tight_layout()
            fig2.savefig(sal_buf, format='PNG', bbox_inches='tight', dpi=100)
            sal_b64 = base64.b64encode(sal_buf.getvalue()).decode()
        finally:
            plt.close(fig2)

        return spec_b64, sal_b64

    except Exception as e:
        print(f"[explainability] Audio saliency failed: {e}")
        return None, None
225
+
226
+
227
+ # ==================== COMBINED VISUALIZATION ====================
228
def create_combined_visualization(grad_cam_base64, saliency_base64, facial_emotion, speech_emotion, concordance):
    """Build a two-panel HTML report and return it base64-encoded.

    The facial Grad-CAM and speech saliency images are embedded as data URIs
    side by side, followed by a concordance banner that is tinted green for
    ``'MATCH'`` and red otherwise. Returns ``None`` if building fails.
    """
    try:
        # Use a soft status tint so agreement is visible at a glance.
        if concordance == 'MATCH':
            bg_color = '#d4edda'
        else:
            bg_color = '#f8d7da'

        report_html = f"""
    <div style="display:flex;gap:20px;padding:20px;background:#f5f5f5;border-radius:10px;">
    <div style="flex:1;">
    <h3>Facial GradCAM — {facial_emotion}</h3>
    <img src="data:image/png;base64,{grad_cam_base64}" style="width:100%;border-radius:8px;">
    <p style="font-size:12px;color:#666;">Red/warm = regions that most influenced the {facial_emotion} prediction.</p>
    </div>
    <div style="flex:1;">
    <h3>Speech Saliency — {speech_emotion}</h3>
    <img src="data:image/png;base64,{saliency_base64}" style="width:100%;border-radius:8px;">
    <p style="font-size:12px;color:#666;">Bright = time-frequency regions with strongest influence.</p>
    </div>
    </div>
    <div style="margin-top:20px;padding:15px;background:{bg_color};border-radius:8px;text-align:center;">
    <h4 style="margin:0;">Concordance: <strong>{concordance}</strong></h4>
    </div>
    """
        encoded = base64.b64encode(report_html.encode())
        return encoded.decode()
    except Exception as e:
        print(f"[explainability] Combined visualisation failed: {e}")
        return None
backend/main.py ADDED
@@ -0,0 +1,914 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI backend for multimodal (facial + speech) emotion inference."""
2
+
3
+ from fastapi import FastAPI, File, UploadFile
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from fastapi.responses import JSONResponse
6
+ import torch
7
+ import numpy as np
8
+ import cv2
9
+ import librosa
10
+ import base64
11
+ from PIL import Image, ImageOps
12
+ from io import BytesIO
13
+ from pathlib import Path
14
+ from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoFeatureExtractor, AutoModelForAudioClassification
15
+ from huggingface_hub import hf_hub_download
16
+ import tempfile
17
+ import os
18
+ import logging
19
+ from threading import Lock
20
+ from dotenv import load_dotenv
21
+
22
+ try:
23
+ from facenet_pytorch import MTCNN # type: ignore[import-not-found]
24
+ except Exception:
25
+ MTCNN = None
26
+
27
+ # Load environment variables
28
+ load_dotenv()
29
+
30
+ # Configure logging
31
+ logging.basicConfig(
32
+ level=logging.INFO,
33
+ format='[%(asctime)s] [%(levelname)s] %(message)s'
34
+ )
35
+ logger = logging.getLogger(__name__)
36
+
37
+ # Explainability helpers
38
+ from backend.services.explainability import generate_grad_cam, generate_audio_saliency
39
+
40
+ ENV = os.getenv("ENV", "development")
41
+ FRONTEND_URL = os.getenv(
42
+ "FRONTEND_URL",
43
+ os.getenv("REACT_APP_VERCEL_URL", "http://localhost:3000")
44
+ )
45
+ CORS_ORIGINS = os.getenv("CORS_ORIGINS", "")
46
+ USE_GPU = os.getenv("USE_GPU", "true").lower() == "true"
47
+ PRELOAD_MODELS = os.getenv("PRELOAD_MODELS", "false").lower() == "true"
48
+ ENABLE_FACE_ROTATION = os.getenv("ENABLE_FACE_ROTATION", "false").lower() == "true"
49
+ MAX_FACE_ROTATION_DEGREES = float(os.getenv("MAX_FACE_ROTATION_DEGREES", "8"))
50
+ HAAR_MIN_NEIGHBORS = int(os.getenv("HAAR_MIN_NEIGHBORS", "5"))
51
+ HAAR_MIN_SIZE = int(os.getenv("HAAR_MIN_SIZE", "40"))
52
+
53
+ app = FastAPI(title="Multi-Modal Emotion Recognition API", version="2.0.0")
54
+
55
+ # Configure CORS based on environment
56
+ if ENV == "production":
57
+ if CORS_ORIGINS.strip():
58
+ allowed_origins = [origin.strip() for origin in CORS_ORIGINS.split(",") if origin.strip()]
59
+ else:
60
+ allowed_origins = [FRONTEND_URL]
61
+ else:
62
+ allowed_origins = ["*"]
63
+
64
+ app.add_middleware(
65
+ CORSMiddleware,
66
+ allow_origins=allowed_origins,
67
+ allow_credentials=True,
68
+ allow_methods=["*"],
69
+ allow_headers=["*"],
70
+ )
71
+
72
+ logger.info(f"CORS enabled for: {allowed_origins}")
73
+ logger.info(
74
+ "Face detection config: rotation=%s max_rotation=%.1f haar_min_neighbors=%d haar_min_size=%d",
75
+ ENABLE_FACE_ROTATION,
76
+ MAX_FACE_ROTATION_DEGREES,
77
+ HAAR_MIN_NEIGHBORS,
78
+ HAAR_MIN_SIZE,
79
+ )
80
+
81
+ # Runtime configuration
82
+ EMOTIONS_FACIAL = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
83
+ EMOTIONS_SPEECH = ['angry', 'calm', 'disgust', 'fearful', 'happy', 'neutral', 'sad', 'surprised']
84
+ DEVICE = torch.device('cuda' if (torch.cuda.is_available() and USE_GPU) else 'cpu')
85
+ MAX_SPEECH_INFER_SECONDS = int(os.getenv('MAX_SPEECH_INFER_SECONDS', '15'))
86
+ MAX_SPEECH_XAI_SECONDS = int(os.getenv('MAX_SPEECH_XAI_SECONDS', '8'))
87
+ CONCORDANCE_SCORE_MAP = {
88
+ 'MATCH': 100,
89
+ 'PARTIAL': 65,
90
+ 'MISMATCH': 30,
91
+ 'UNKNOWN': 0,
92
+ }
93
+
94
+ # In-memory model state
95
+ vit_model = None
96
+ facial_processor = None
97
+ speech_model = None
98
+ speech_processor = None
99
+ facial_loaded = False
100
+ speech_loaded = False
101
+
102
+ _facial_model_lock = Lock()
103
+ _speech_model_lock = Lock()
104
+
105
+ # Paths — download from HuggingFace Hub
106
+ logger.info("Resolving model paths from HuggingFace Hub...")
107
+ FACIAL_MODEL_PATH = hf_hub_download(
108
+ repo_id="Nishvaraj/emotion-models",
109
+ filename="vit_emotion_model.pt"
110
+ )
111
+ SPEECH_MODEL_PATH = hf_hub_download(
112
+ repo_id="Nishvaraj/emotion-models",
113
+ filename="hubert_emotion_model.pt"
114
+ )
115
+ logger.info(f"Facial model path: {FACIAL_MODEL_PATH}")
116
+ logger.info(f"Speech model path: {SPEECH_MODEL_PATH}")
117
+
118
+
119
+ def _upload_suffix(filename: str, default_suffix: str) -> str:
120
+ # Preserve the original extension when the browser provides one, otherwise fall back to a safe default.
121
+ suffix = Path(filename or '').suffix.lower()
122
+ return suffix if suffix else default_suffix
123
+
124
+
125
+ def _calculate_concordance(facial_emotion, speech_emotion, facial_confidence, speech_confidence):
126
+ # Match/partial/mismatch is derived from whether both models agree and how confident they are.
127
+ if facial_emotion == speech_emotion:
128
+ # When the modalities agree, the average confidence controls the concordance band.
129
+ score = (facial_confidence + speech_confidence) / 2
130
+ if score > 0.7:
131
+ concordance = "MATCH"
132
+ elif score >= 0.4:
133
+ concordance = "PARTIAL"
134
+ else:
135
+ concordance = "MISMATCH"
136
+ else:
137
+ # Different emotions can never be a full match, so we score by how close the confidences are.
138
+ score = 1 - abs(facial_confidence - speech_confidence)
139
+ if score >= 0.5:
140
+ concordance = "PARTIAL"
141
+ else:
142
+ concordance = "MISMATCH"
143
+
144
+ concordance_score = round(score * 100)
145
+ return concordance, concordance_score
146
+
147
+
148
+ FACE_CASCADE = cv2.CascadeClassifier(
149
+ cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
150
+ )
151
+ MTCNN_DETECTOR = MTCNN(keep_all=False, device=DEVICE) if MTCNN is not None else None
152
+
153
+
154
+ def _encode_image_base64(image_array: np.ndarray) -> str:
155
+ image_pil = Image.fromarray(image_array.astype(np.uint8))
156
+ buf = BytesIO()
157
+ image_pil.save(buf, format='PNG')
158
+ return base64.b64encode(buf.getvalue()).decode()
159
+
160
+
161
+ def _detect_primary_face(image: Image.Image):
162
+ # Prefer MTCNN when available because it gives stronger boxes and landmark points.
163
+ if MTCNN_DETECTOR is not None:
164
+ try:
165
+ boxes, probs, points = MTCNN_DETECTOR.detect(image, landmarks=True)
166
+ if boxes is not None and len(boxes) > 0:
167
+ # Use the highest-probability detection when multiple faces appear.
168
+ best_idx = int(np.argmax(probs)) if probs is not None else 0
169
+ x1, y1, x2, y2 = boxes[best_idx]
170
+ # Convert from [x1,y1,x2,y2] to [x,y,w,h]
171
+ x, y, w, h = int(x1), int(y1), int(x2 - x1), int(y2 - y1)
172
+ return (x, y, w, h), (points[best_idx] if points is not None else None)
173
+ except Exception as e:
174
+ logger.debug(f"MTCNN face detection fallback: {e}")
175
+
176
+ # Haar cascade is the fallback path so the app still works without facenet-pytorch.
177
+ img_array = np.array(image)
178
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
179
+ faces = FACE_CASCADE.detectMultiScale(
180
+ gray,
181
+ scaleFactor=1.1,
182
+ minNeighbors=HAAR_MIN_NEIGHBORS,
183
+ minSize=(HAAR_MIN_SIZE, HAAR_MIN_SIZE)
184
+ )
185
+
186
+ if faces is None or len(faces) == 0:
187
+ return None, None
188
+ best_face = max(faces, key=lambda b: b[2] * b[3])
189
+ return tuple(int(v) for v in best_face), None
190
+
191
+
192
def _rotate_image_to_level(image: Image.Image, points) -> Image.Image:
    """Rotate *image* so the line between the two eye landmarks becomes horizontal.

    Args:
        image: Source image.
        points: Landmark sequence whose first two entries are read as the two
            eye positions ``(x, y)``; ``None`` disables the correction.

    Returns:
        A rotated copy, or the *same* ``image`` object when rotation is
        disabled, unnecessary (< 1 degree), too large
        (> ``MAX_FACE_ROTATION_DEGREES``) or fails. Callers rely on the
        identity check ``result is not image`` to detect that a rotation
        actually happened.
    """
    if not ENABLE_FACE_ROTATION or points is None:
        return image

    try:
        left_eye, right_eye = points[0], points[1]
        # Image coordinates have y pointing down, so a positive angle means the
        # second eye sits lower than the first one.
        angle = float(np.degrees(np.arctan2(
            right_eye[1] - left_eye[1],
            right_eye[0] - left_eye[0],
        )))
        if abs(angle) < 1.0:
            return image
        if abs(angle) > MAX_FACE_ROTATION_DEGREES:
            logger.debug("Skipping face rotation due to large angle: %.2f", angle)
            return image

        # PIL rotates counter-clockwise for positive angles, which raises the
        # right-hand side of the picture. Rotating by +angle therefore levels
        # the eye line; rotating by -angle would double the tilt.
        return image.rotate(
            angle,
            resample=Image.Resampling.BICUBIC,
            expand=True,
            center=(image.width / 2, image.height / 2),
            fillcolor=(0, 0, 0),
        )
    except Exception as e:
        # Rotation is a best-effort refinement: keep the original image, but
        # leave a trace instead of swallowing the error silently.
        logger.debug("Face rotation skipped after error: %s", e)
        return image
213
+
214
+
215
+ def _crop_face_with_margin(image_array: np.ndarray, face_box, margin_ratio: float = 0.12):
216
+ # Expand the detected face slightly so the classifier keeps some surrounding context.
217
+ x, y, w, h = [int(v) for v in face_box]
218
+ h_img, w_img = image_array.shape[:2]
219
+ mx = int(w * margin_ratio)
220
+ my = int(h * margin_ratio)
221
+
222
+ x1 = max(0, x - mx)
223
+ y1 = max(0, y - my)
224
+ x2 = min(w_img, x + w + mx)
225
+ y2 = min(h_img, y + h + my)
226
+
227
+ return image_array[y1:y2, x1:x2], (x1, y1, x2 - x1, y2 - y1)
228
+
229
+
230
+ def _shrink_box(face_box, shrink_ratio: float = 0.12):
231
+ # Draw a tighter outline for annotation so the face box looks cleaner on the preview image.
232
+ x, y, w, h = [int(v) for v in face_box]
233
+ dx = int(w * shrink_ratio / 2)
234
+ dy = int(h * shrink_ratio / 2)
235
+ x1 = x + dx
236
+ y1 = y + dy
237
+ width = max(1, w - (dx * 2))
238
+ height = max(1, h - (dy * 2))
239
+ return x1, y1, width, height
240
+
241
+
242
+ def _trim_audio_window(audio: np.ndarray, sr: int, max_seconds: int) -> np.ndarray:
243
+ # Long recordings are centered and clipped so inference stays fast and consistent.
244
+ if audio is None or sr <= 0:
245
+ return audio
246
+ max_len = int(sr * max_seconds)
247
+ if max_len <= 0 or len(audio) <= max_len:
248
+ return audio
249
+ start = (len(audio) - max_len) // 2
250
+ end = start + max_len
251
+ return audio[start:end]
252
+
253
+
254
# Record the compute device and environment name once at import time.
logger.info("Device: %s", DEVICE)
logger.info("Environment: %s", ENV)
256
+
257
+ # ========== MODEL LOADING ==========
258
+
259
def load_facial_model():
    """Load the ViT facial-emotion model and its image processor.

    The model is built and its checkpoint applied on local variables; the
    module-level ``vit_model`` / ``facial_processor`` globals are only
    assigned once every step has succeeded. Otherwise a failed checkpoint
    load would leave a backbone with a randomly initialised head in the
    globals, and the "already loaded" checks would then treat it as ready.

    Returns:
        True when the model is ready (already loaded or loaded now),
        False when loading failed; the error is logged, not raised.
    """
    global vit_model, facial_processor, facial_loaded
    # Fast path without taking the lock.
    if vit_model is not None and facial_processor is not None:
        facial_loaded = True
        return True

    with _facial_model_lock:
        # Re-check under the lock: another thread may have finished loading.
        if vit_model is not None and facial_processor is not None:
            facial_loaded = True
            return True

        try:
            logger.info("Loading Facial Emotion Model (ViT)...")
            # Keep the pretrained ViT backbone but swap in the emotion-class head size.
            processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
            model = AutoModelForImageClassification.from_pretrained(
                'google/vit-base-patch16-224-in21k',
                num_labels=len(EMOTIONS_FACIAL),
                ignore_mismatched_sizes=True,
                attn_implementation='eager'
            )

            # Load either a full checkpoint or a plain state_dict depending on how the file was saved.
            checkpoint = torch.load(FACIAL_MODEL_PATH, map_location=DEVICE)
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint)
            logger.info("✓ Loaded ViT checkpoint")

            model = model.to(DEVICE)
            model.eval()

            # Publish only the fully initialised objects.
            facial_processor = processor
            vit_model = model
            facial_loaded = True
            logger.info("✓ Facial model ready")
            return True
        except Exception as e:
            facial_loaded = False
            logger.error(f"❌ Error loading facial model: {e}")
            return False
299
+
300
+
301
def load_speech_model():
    """Load the HuBERT speech-emotion model and its feature extractor.

    The model is built and its checkpoint applied on local variables; the
    module-level ``speech_model`` / ``speech_processor`` globals are only
    assigned once every step has succeeded. Otherwise a failed checkpoint
    load would leave a backbone with a randomly initialised head in the
    globals, and the "already loaded" checks would then treat it as ready.

    Returns:
        True when the model is ready (already loaded or loaded now),
        False when loading failed; the error is logged, not raised.
    """
    global speech_model, speech_processor, speech_loaded
    # Fast path without taking the lock.
    if speech_model is not None and speech_processor is not None:
        speech_loaded = True
        return True

    with _speech_model_lock:
        # Re-check under the lock: another thread may have finished loading.
        if speech_model is not None and speech_processor is not None:
            speech_loaded = True
            return True

        try:
            logger.info("Loading Speech Emotion Model (HuBERT)...")
            # Match the pretrained audio backbone to the project-specific emotion label set.
            processor = AutoFeatureExtractor.from_pretrained('facebook/hubert-large-ls960-ft')
            model = AutoModelForAudioClassification.from_pretrained(
                'facebook/hubert-large-ls960-ft',
                num_labels=len(EMOTIONS_SPEECH),
                ignore_mismatched_sizes=True
            )

            # Support both checkpoint formats used across training experiments.
            checkpoint = torch.load(SPEECH_MODEL_PATH, map_location=DEVICE)
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint)
            logger.info("✓ Loaded HuBERT checkpoint")

            model = model.to(DEVICE)
            model.eval()

            # Publish only the fully initialised objects.
            speech_processor = processor
            speech_model = model
            speech_loaded = True
            logger.info("✓ Speech model ready")
            return True
        except Exception as e:
            speech_loaded = False
            logger.error(f"❌ Error loading speech model: {e}")
            return False
340
+
341
+
342
def ensure_facial_model_loaded() -> bool:
    """Return True when the facial model is usable, loading it on first use."""
    already_loaded = vit_model is not None and facial_processor is not None
    if not already_loaded:
        return load_facial_model()
    return True
346
+
347
+
348
def ensure_speech_model_loaded() -> bool:
    """Return True when the speech model is usable, loading it on first use."""
    already_loaded = speech_model is not None and speech_processor is not None
    if not already_loaded:
        return load_speech_model()
    return True
352
+
353
+
354
# Optional eager loading for environments that prefer warm startup.
# When PRELOAD_MODELS is truthy both loaders run at import time. Each loader
# returns a bool and logs its own errors, so a failed load leaves the
# corresponding flag False here instead of raising.
if PRELOAD_MODELS:
    facial_loaded = load_facial_model()
    speech_loaded = load_speech_model()
358
+
359
+ # ========== VIDEO PROCESSOR ==========
360
+
361
class VideoProcessor:
    """Helpers for pulling frames and the audio track out of a video file."""

    @staticmethod
    def extract_frames_and_audio(video_path: str, fps_sample: int = 5):
        """Extract sampled frames and the mono audio track from a video.

        Args:
            video_path: Path of the video file on disk.
            fps_sample: Keep every ``fps_sample``-th decoded frame; must be >= 1.

        Returns:
            ``(frames, audio, sr, fps)`` where ``frames`` is a list of RGB PIL
            images, ``audio``/``sr`` come from ``librosa.load`` at 16 kHz mono,
            and ``fps`` is the reported frame rate (30.0 when the container
            reports a value outside ``(0, 120]``).

        Raises:
            ValueError: If ``fps_sample`` is below 1 or the video cannot be opened.
        """
        # Validate first: a zero step would otherwise surface as ZeroDivisionError
        # in the sampling modulo below.
        if fps_sample < 1:
            raise ValueError(f"fps_sample must be >= 1, got {fps_sample}")

        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            if not cap.isOpened():
                raise ValueError(f"Cannot open video: {video_path}")

            fps = cap.get(cv2.CAP_PROP_FPS)
            # Written as a positive range test so NaN is rejected as well.
            if not (0 < fps <= 120):
                fps = 30.0

            frame_count = 0
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % fps_sample == 0:
                    # Sample every Nth frame so we analyze representative facial
                    # expressions without processing the full video.
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame_rgb))

                frame_count += 1
        finally:
            # Release the capture handle even when decoding raises.
            cap.release()

        # librosa reads the audio track directly from the same file, giving us a
        # single mono stream for speech inference.
        audio, sr = librosa.load(video_path, sr=16000, mono=True)

        return frames, audio, sr, fps
395
+
396
+ # ========== PREDICTION FUNCTIONS ==========
397
+
398
def predict_facial_emotion(image: Image.Image, generate_explainability: bool = False):
    """Predict the facial emotion shown in *image*.

    Args:
        image: Input picture; it is EXIF-transposed and converted to RGB here.
        generate_explainability: When True, also attempt a Grad-CAM heatmap
            and report the outcome under ``explainability_status``.

    Returns:
        A dict with ``emotion``, ``confidence``, ``probabilities``,
        ``face_detected`` and ``annotated_image`` (base64 PNG), plus
        ``face_box`` when a face was found and the Grad-CAM fields when
        requested. Returns ``None`` when the model cannot be loaded or any
        step raises; the error is logged.
    """
    try:
        if not ensure_facial_model_loaded():
            return None

        # Normalize EXIF orientation first so mobile uploads and camera captures behave consistently.
        image = ImageOps.exif_transpose(image).convert('RGB')

        # Detect the most likely face before deciding whether to crop or rotate the input.
        detected = _detect_primary_face(image)
        face_box, face_points = detected if isinstance(detected, tuple) else (None, None)

        # If we have eye landmarks, try a small rotation pass to correct head tilt.
        rotated_image = _rotate_image_to_level(image, face_points)
        # The helper returns the same object when it did not rotate, hence the identity test.
        if rotated_image is not image:
            rotated_detected = _detect_primary_face(rotated_image)
            if isinstance(rotated_detected, tuple):
                rotated_box, rotated_points = rotated_detected
                # Only adopt the rotated image when a face is still found in it.
                if rotated_box is not None:
                    image = rotated_image
                    face_box = rotated_box
                    face_points = rotated_points

        input_array = np.array(image)

        # Default to the whole image; replaced by the face crop below when available.
        model_image = image

        # Crop to the detected face when possible so the classifier sees the most relevant region.
        if face_box is not None:
            face_crop, _ = _crop_face_with_margin(input_array, face_box)
            if face_crop.size > 0:
                model_image = Image.fromarray(face_crop)

        # Draw the face box on the preview image to make the detection step visible to the user.
        annotated = input_array.copy()
        if face_box is not None:
            x, y, w, h = _shrink_box(face_box, shrink_ratio=0.08)
            cv2.rectangle(annotated, (x, y), (x + w, y + h), (255, 128, 0), 2)
            cv2.putText(
                annotated,
                'Face detected',
                # Keep the label at least 20 px from the top edge.
                (x, max(20, y - 8)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (255, 128, 0),
                2,
                cv2.LINE_AA
            )

        inputs = facial_processor(model_image, return_tensors='pt').to(DEVICE)
        with torch.no_grad():
            outputs = vit_model(**inputs)
            # Take the logits of the single image in the batch.
            logits = outputs.logits.cpu().numpy()[0]
            # Convert raw logits into probabilities for easier interpretation in the UI.
            probs = torch.softmax(torch.from_numpy(logits), dim=0).numpy()

        top_idx = np.argmax(probs)
        result = {
            "emotion": EMOTIONS_FACIAL[top_idx],
            "confidence": float(probs[top_idx]),
            "probabilities": {e: float(p) for e, p in zip(EMOTIONS_FACIAL, probs)},
            "face_detected": face_box is not None,
            "annotated_image": _encode_image_base64(annotated)
        }

        # Report the unshrunk detection box, not the tighter outline drawn above.
        if face_box is not None:
            x, y, w, h = [int(v) for v in face_box]
            result["face_box"] = {"x": x, "y": y, "width": w, "height": h}

        if generate_explainability:
            # Explainability is optional because Grad-CAM adds compute cost.
            result["explainability_status"] = {
                "requested": True,
                "generated": False,
                "error": None
            }
            try:
                # Grad-CAM runs on the same image the classifier saw (crop or full frame).
                original_base64, heatmap_base64 = generate_grad_cam(
                    model_image,
                    vit_model,
                    facial_processor,
                    top_idx,
                    EMOTIONS_FACIAL,
                    DEVICE
                )
                if original_base64:
                    result["original_image"] = original_base64
                if heatmap_base64:
                    result["grad_cam"] = heatmap_base64
                    result["explainability_status"]["generated"] = True
                else:
                    result["explainability_status"]["error"] = "Grad-CAM map returned empty output"
            except Exception as e:
                # A Grad-CAM failure must not discard the prediction itself.
                logger.warning(f"Could not generate Grad-CAM: {e}")
                result["explainability_status"]["error"] = str(e)

        return result
    except Exception as e:
        logger.error(f"Error predicting facial emotion: {e}")
        return None
499
+
500
def predict_speech_emotion(audio: np.ndarray, sr: int = 16000, generate_explainability: bool = False):
    """Predict the emotion expressed in an audio clip.

    Args:
        audio: Waveform samples.
        sr: Sampling rate of *audio*; anything other than 16000 is resampled.
        generate_explainability: When True, also attempt an audio saliency
            map and report the outcome under ``explainability_status``.

    Returns:
        A dict with ``emotion``, ``confidence`` and ``probabilities``, plus
        the saliency fields when requested. Returns ``None`` when the model
        cannot be loaded or any step raises; the error is logged.
    """
    try:
        if not ensure_speech_model_loaded():
            return None

        if sr != 16000:
            # Resample every input to the model's expected sampling rate.
            audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

        # Keep inference fast and stable for long recordings.
        audio_for_infer = _trim_audio_window(audio, 16000, MAX_SPEECH_INFER_SECONDS)

        inputs = speech_processor(audio_for_infer, sampling_rate=16000, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = speech_model(inputs['input_values'].to(DEVICE))
            logits = outputs.logits.cpu().numpy()[0]

        # Softmax with the maximum subtracted first: the result is mathematically
        # identical, but exp() can no longer overflow to inf (and yield NaN
        # probabilities) when a logit is large.
        exp_shifted = np.exp(logits - np.max(logits))
        probs = exp_shifted / np.sum(exp_shifted)

        top_idx = np.argmax(probs)
        result = {
            "emotion": EMOTIONS_SPEECH[top_idx],
            "confidence": float(probs[top_idx]),
            "probabilities": {e: float(p) for e, p in zip(EMOTIONS_SPEECH, probs)}
        }

        if generate_explainability:
            result["explainability_status"] = {
                "requested": True,
                "generated": False,
                "error": None
            }
            try:
                # Saliency on a shorter centered chunk avoids multi-minute stalls.
                audio_for_xai = _trim_audio_window(audio_for_infer, 16000, MAX_SPEECH_XAI_SECONDS)
                spec_base64, saliency_base64 = generate_audio_saliency(
                    audio_for_xai,
                    speech_model,
                    speech_processor,
                    top_idx,
                    EMOTIONS_SPEECH,
                    DEVICE,
                    sr=16000
                )
                if spec_base64:
                    result["waveform"] = spec_base64
                if saliency_base64:
                    result["saliency"] = saliency_base64
                    result["explainability_status"]["generated"] = True
                else:
                    result["explainability_status"]["error"] = "Audio saliency map returned empty output"
            except Exception as e:
                # A saliency failure must not discard the prediction itself.
                logger.warning(f"Could not generate audio saliency: {e}")
                result["explainability_status"]["error"] = str(e)

        return result
    except Exception as e:
        logger.error(f"Error predicting speech emotion: {e}")
        return None
561
+
562
+ # ========== API ENDPOINTS ==========
563
+
564
@app.get("/")
async def root():
    """Return a static banner with the API name and status."""
    banner = {
        "message": "Multi-Modal Emotion Recognition API v2.0",
        "status": "active",
    }
    return banner
567
+
568
@app.get("/health")
async def health():
    """Report service health and whether each model is currently in memory."""
    facial_ready = all(part is not None for part in (vit_model, facial_processor))
    speech_ready = all(part is not None for part in (speech_model, speech_processor))

    payload = {"status": "healthy"}
    payload["facial_model"] = facial_ready
    payload["speech_model"] = speech_ready
    # Lazy loading is simply the inverse of the preload switch.
    payload["lazy_loading"] = not PRELOAD_MODELS
    payload["device"] = str(DEVICE)
    return payload
579
+
580
@app.post("/api/predict/facial")
async def predict_facial(file: UploadFile = File(...), explain: bool = False):
    """Predict the facial emotion for an uploaded image.

    Responds with HTTP 400 for an empty upload or any raised error, and with
    ``{"success": False, ...}`` when the prediction helper returns nothing.
    """
    try:
        logger.info(f"Received file: {file.filename}, content_type: {file.content_type}")
        contents = await file.read()
        logger.info(f"File size: {len(contents)} bytes")
        if not contents:
            return JSONResponse(status_code=400, content={"error": "Empty file received"})

        uploaded = Image.open(BytesIO(contents))
        image = ImageOps.exif_transpose(uploaded).convert('RGB')
        result = predict_facial_emotion(image, generate_explainability=explain)
        if result:
            return {"success": True, **result}
        return {"success": False, "error": "Prediction failed"}
    except Exception as e:
        logger.error(f"Error in predict_facial: {e}", exc_info=True)
        return JSONResponse(status_code=400, content={"error": str(e)})
595
+
596
@app.post("/api/predict/speech")
async def predict_speech(file: UploadFile = File(...), explain: bool = False):
    """Predict the speech emotion for an uploaded audio file.

    The upload is written to a temporary file so librosa can decode it; the
    file is removed afterwards. Responds with HTTP 400 for an empty upload or
    any raised error, and with ``{"success": False, ...}`` when the
    prediction helper returns nothing.
    """
    try:
        contents = await file.read()
        # Reject empty uploads up front, matching the facial endpoint.
        if not contents:
            return JSONResponse(status_code=400, content={"error": "Empty file received"})

        suffix = _upload_suffix(file.filename, '.wav')
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(contents)
            tmp_path = tmp.name

        try:
            audio, sr = librosa.load(tmp_path, sr=16000)
            result = predict_speech_emotion(audio, sr, generate_explainability=explain)
            return {"success": True, **result} if result else {"success": False, "error": "Prediction failed"}
        finally:
            # delete=False above means the file must be removed by hand.
            os.unlink(tmp_path)
    except Exception as e:
        # Log with traceback like predict_facial does; previously failures here left no trace.
        logger.error(f"Error in predict_speech: {e}", exc_info=True)
        return JSONResponse(status_code=400, content={"error": str(e)})
614
+
615
@app.post("/api/predict/combined")
async def predict_combined(image_file: UploadFile = File(...), audio_file: UploadFile = File(...), explain: bool = False):
    """Predict emotions from an image and an audio clip, then compare them.

    The response carries the per-modality predictions, a combined label taken
    from the more confident modality, and a concordance verdict. When either
    modality produced no prediction the concordance is reported as
    ``"MISMATCH"`` with score 0, because there is nothing to agree with.
    Responds with HTTP 400 when any step raises.
    """
    try:
        image_contents = await image_file.read()
        image = ImageOps.exif_transpose(Image.open(BytesIO(image_contents))).convert('RGB')
        facial_result = predict_facial_emotion(image, generate_explainability=explain)

        audio_contents = await audio_file.read()
        audio_suffix = _upload_suffix(audio_file.filename, '.wav')
        with tempfile.NamedTemporaryFile(suffix=audio_suffix, delete=False) as tmp:
            tmp.write(audio_contents)
            tmp_path = tmp.name

        try:
            audio, sr = librosa.load(tmp_path, sr=16000)
            speech_result = predict_speech_emotion(audio, sr, generate_explainability=explain)
        finally:
            # delete=False above means the file must be removed by hand.
            os.unlink(tmp_path)

        facial_emotion = facial_result["emotion"] if facial_result else None
        facial_confidence = facial_result["confidence"] if facial_result else 0.0

        speech_emotion = speech_result["emotion"] if speech_result else None
        speech_confidence = speech_result["confidence"] if speech_result else 0.0

        if facial_emotion is None or speech_emotion is None:
            # A failed modality has a placeholder confidence of 0.0. Feeding that
            # into the concordance formula could yield "PARTIAL" agreement with a
            # prediction that does not exist.
            concordance, concordance_score = "MISMATCH", 0
        else:
            concordance, concordance_score = _calculate_concordance(
                facial_emotion,
                speech_emotion,
                facial_confidence,
                speech_confidence,
            )

        # The combined label should prefer the more confident modality when both are present.
        combined_emotion = None
        combined_confidence = 0.0

        if facial_emotion and speech_emotion:
            if facial_confidence > speech_confidence:
                combined_emotion = facial_emotion
                combined_confidence = facial_confidence
            else:
                combined_emotion = speech_emotion
                combined_confidence = speech_confidence
        elif facial_emotion:
            combined_emotion = facial_emotion
            combined_confidence = facial_confidence
        elif speech_emotion:
            combined_emotion = speech_emotion
            combined_confidence = speech_confidence

        response = {
            "success": True,
            "facial_emotion": {
                "emotion": facial_emotion or "unknown",
                "confidence": float(facial_confidence),
                "probabilities": facial_result["probabilities"] if facial_result else {},
                "face_detected": facial_result.get("face_detected", False) if facial_result else False,
                "face_box": facial_result.get("face_box") if facial_result else None,
                "annotated_image": facial_result.get("annotated_image") if facial_result else None
            },
            "speech_emotion": {
                "emotion": speech_emotion or "unknown",
                "confidence": float(speech_confidence),
                "probabilities": speech_result["probabilities"] if speech_result else {}
            },
            "combined_emotion": combined_emotion or "unknown",
            "combined_confidence": float(combined_confidence),
            "concordance": concordance,
            "concordance_score": concordance_score,
            "analysis": {
                "match": concordance == "MATCH",
                "agreement_details": f"Face: {facial_emotion} (conf: {facial_confidence:.2f}) | Voice: {speech_emotion} (conf: {speech_confidence:.2f})"
            }
        }

        if explain:
            # Keep the response shape stable even when one modality fails to generate XAI output.
            explainability = {}
            errors = []

            facial_status = (facial_result or {}).get("explainability_status") or {
                "requested": True,
                "generated": False,
                "error": "Facial explainability unavailable"
            }
            speech_status = (speech_result or {}).get("explainability_status") or {
                "requested": True,
                "generated": False,
                "error": "Speech explainability unavailable"
            }

            if facial_result and facial_result.get("grad_cam"):
                explainability["grad_cam"] = facial_result.get("grad_cam")
            elif facial_status.get("error"):
                errors.append(f"Facial: {facial_status.get('error')}")

            if speech_result and speech_result.get("saliency"):
                explainability["saliency"] = speech_result.get("saliency")
            elif speech_status.get("error"):
                errors.append(f"Speech: {speech_status.get('error')}")

            if speech_result and speech_result.get("waveform"):
                explainability["waveform"] = speech_result.get("waveform")

            response["explainability_status"] = {
                "requested": True,
                "generated": bool(explainability),
                "facial": facial_status,
                "speech": speech_status,
                "errors": errors
            }

            if explainability:
                response["explainability"] = explainability

        return response
    except Exception as e:
        # Log with traceback like predict_facial does; previously failures here left no trace.
        logger.error(f"Error in predict_combined: {e}", exc_info=True)
        return JSONResponse(status_code=400, content={"error": str(e)})
734
+
735
@app.post("/api/predict/video")
async def predict_video_emotion(file: UploadFile = File(...), explain: bool = False):
    """Predict emotions from a video using both its frames and its audio track.

    Up to the first ten sampled frames are classified and aggregated by
    majority vote; the audio track is classified once. When ``explain`` is
    set, Grad-CAM runs on the most confident frame that agrees with the
    aggregated facial emotion, and a saliency map is produced for the audio.
    Responds with HTTP 400 when any step raises.
    """
    # Local import keeps this block self-contained.
    from collections import Counter

    try:
        video_suffix = _upload_suffix(file.filename, '.mp4')
        with tempfile.NamedTemporaryFile(suffix=video_suffix, delete=False) as tmp:
            contents = await file.read()
            tmp.write(contents)
            tmp_path = tmp.name

        try:
            processor = VideoProcessor()
            frames, audio, sr, fps = processor.extract_frames_and_audio(tmp_path, fps_sample=5)

            # Keep each frame next to its prediction so the explainability pass
            # can reuse these results instead of running inference again.
            frame_results = []
            for frame in frames[:10]:
                result = predict_facial_emotion(frame)
                if result:
                    frame_results.append((frame, result))
            facial_results = [result for _, result in frame_results]

            if facial_results:
                facial_emotions = [r["emotion"] for r in facial_results]
                facial_confidence = float(np.mean([r["confidence"] for r in facial_results]))
                # Counter breaks ties by first occurrence, so the vote is
                # deterministic; max() over a set depended on hash ordering.
                facial_emotion = Counter(facial_emotions).most_common(1)[0][0]
                facial_probs = {
                    emotion: float(np.mean([r["probabilities"].get(emotion, 0) for r in facial_results]))
                    for emotion in EMOTIONS_FACIAL
                }
            else:
                facial_emotion = "unknown"
                facial_confidence = 0.0
                facial_probs = {e: 0.0 for e in EMOTIONS_FACIAL}

            speech_result = predict_speech_emotion(audio, sr)
            speech_emotion = speech_result["emotion"] if speech_result else "unknown"
            speech_confidence = float(speech_result["confidence"]) if speech_result else 0.0

            if facial_emotion == "unknown" or speech_emotion == "unknown":
                # A missing modality has a placeholder confidence of 0.0; scoring
                # against it could report "PARTIAL" agreement with nothing.
                concordance, concordance_score = "MISMATCH", 0
            else:
                concordance, concordance_score = _calculate_concordance(
                    facial_emotion,
                    speech_emotion,
                    facial_confidence,
                    speech_confidence,
                )

            # Prefer a confident facial label; otherwise use speech. If speech is
            # unavailable, fall back to the facial label rather than "unknown".
            if facial_confidence > 0.5 or not speech_result:
                combined_emotion = facial_emotion
            else:
                combined_emotion = speech_emotion

            response = {
                "success": True,
                "facial_emotion": {
                    "emotion": facial_emotion,
                    "confidence": float(facial_confidence),
                    "frames_analyzed": len(facial_results),
                    "probabilities": facial_probs
                },
                "speech_emotion": {
                    "emotion": speech_emotion,
                    "confidence": speech_confidence,
                    "probabilities": speech_result["probabilities"] if speech_result else {e: 0.0 for e in EMOTIONS_SPEECH}
                },
                "combined_emotion": combined_emotion,
                "concordance": concordance,
                "concordance_score": concordance_score,
                # Duration is derived from the decoded audio track.
                "video_duration": float(len(audio) / sr),
                "frames_processed": len(frames),
                "fps": float(fps)
            }

            if explain:
                explainability = {}
                errors = []

                facial_exp_status = {"requested": True, "generated": False, "error": None}
                speech_exp_status = {"requested": True, "generated": False, "error": None}

                if frames and facial_emotion != "unknown":
                    try:
                        # Pick the frame that predicted the aggregated emotion with the
                        # highest confidence, reusing the first-pass predictions.
                        matching = [
                            (frame, result) for frame, result in frame_results
                            if result.get("emotion") == facial_emotion
                        ]
                        if matching:
                            best_frame, _ = max(matching, key=lambda pair: pair[1].get("confidence", 0))
                        else:
                            # Defensive fallback; the vote guarantees at least one match.
                            best_frame = frames[0]

                        top_idx = EMOTIONS_FACIAL.index(facial_emotion) \
                            if facial_emotion in EMOTIONS_FACIAL else 0
                        # Crop face before passing to GradCAM
                        face_box, _ = _detect_primary_face(best_frame)
                        if face_box is not None:
                            frame_array = np.array(best_frame)
                            face_crop_array, _ = _crop_face_with_margin(frame_array, face_box)
                            gradcam_input = Image.fromarray(face_crop_array) if face_crop_array.size > 0 else best_frame
                        else:
                            gradcam_input = best_frame
                        _, heatmap_b64 = generate_grad_cam(
                            gradcam_input, vit_model, facial_processor,
                            top_idx, EMOTIONS_FACIAL, DEVICE
                        )
                        if heatmap_b64:
                            explainability["grad_cam"] = heatmap_b64
                            facial_exp_status["generated"] = True
                        else:
                            facial_exp_status["error"] = "GradCAM returned empty output"
                    except Exception as e:
                        facial_exp_status["error"] = str(e)
                else:
                    facial_exp_status["error"] = "No valid frame prediction found for facial explainability"

                if speech_result and speech_emotion != "unknown":
                    try:
                        top_idx = EMOTIONS_SPEECH.index(speech_emotion) \
                            if speech_emotion in EMOTIONS_SPEECH else 0
                        audio_for_xai = _trim_audio_window(audio, sr, max_seconds=MAX_SPEECH_XAI_SECONDS)
                        spec_b64, saliency_b64 = generate_audio_saliency(
                            audio_for_xai,
                            speech_model,
                            speech_processor,
                            top_idx,
                            EMOTIONS_SPEECH,
                            DEVICE,
                            sr=16000
                        )
                        if spec_b64:
                            explainability["waveform"] = spec_b64
                        if saliency_b64:
                            explainability["saliency"] = saliency_b64
                            speech_exp_status["generated"] = True
                        else:
                            speech_exp_status["error"] = "Audio saliency map returned empty output"
                    except Exception as e:
                        speech_exp_status["error"] = str(e)
                else:
                    speech_exp_status["error"] = "No valid audio prediction found for explainability"

                if facial_exp_status.get("error"):
                    errors.append(f"Facial: {facial_exp_status.get('error')}")
                if speech_exp_status.get("error"):
                    errors.append(f"Speech: {speech_exp_status.get('error')}")

                response["explainability_status"] = {
                    "requested": True,
                    "generated": bool(explainability),
                    "facial": facial_exp_status,
                    "speech": speech_exp_status,
                    "errors": errors
                }

                if explainability:
                    response["explainability"] = explainability

            return response
        finally:
            # delete=False above means the file must be removed by hand.
            os.unlink(tmp_path)
    except Exception as e:
        # Log with traceback like predict_facial does; previously failures here left no trace.
        logger.error(f"Error in predict_video_emotion: {e}", exc_info=True)
        return JSONResponse(status_code=400, content={"error": str(e)})
896
+
897
@app.get("/api/emotions/facial")
async def get_facial_emotions():
    """List the emotion labels the facial model can output."""
    labels = EMOTIONS_FACIAL
    return {"emotions": labels}
900
+
901
@app.get("/api/emotions/speech")
async def get_speech_emotions():
    """List the emotion labels the speech model can output."""
    labels = EMOTIONS_SPEECH
    return {"emotions": labels}
904
+
905
@app.get("/api/models/status")
async def get_models_status():
    """Report, per model, whether it is loaded, a fixed accuracy figure and its label count."""
    facial_ready = all(part is not None for part in (vit_model, facial_processor))
    speech_ready = all(part is not None for part in (speech_model, speech_processor))

    # The accuracy values are constants written into this endpoint.
    facial_info = {"loaded": facial_ready, "accuracy": 0.7129, "emotions": len(EMOTIONS_FACIAL)}
    speech_info = {"loaded": speech_ready, "accuracy": 0.8750, "emotions": len(EMOTIONS_SPEECH)}

    return {
        "facial": facial_info,
        "speech": speech_info,
        "lazy_loading": not PRELOAD_MODELS,
        "device": str(DEVICE),
    }
backend/services/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Service-layer utilities for backend inference and explainability."""
backend/services/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (263 Bytes). View file
 
backend/services/__pycache__/explainability.cpython-314.pyc ADDED
Binary file (15 kB). View file
 
backend/services/data_loader.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset loaders used by the training and experimentation workflows."""
2
+
3
+ import os
4
+ import numpy as np
5
+ import cv2
6
+ import librosa
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from torchvision import transforms
10
+
11
+
12
+ # ==================== FACIAL DATASET ====================
13
class FER2013Dataset(Dataset):
    """FER2013 facial emotion dataset loader.

    Expects ``root_dir/<split>/<emotion>/<image>`` where ``<emotion>`` is one
    of the seven names in ``self.emotions``; the label is the index of the
    emotion in that list.
    """

    # Accepted image extensions, matched case-insensitively.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir: str, split: str = "train", transform=None):
        """
        Initialize FER2013 dataset.

        Args:
            root_dir: Root directory containing 'train' and 'test' folders
            split: 'train' or 'test'
            transform: Torchvision transforms to apply
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.emotions = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
        self.emotion2idx = {e: i for i, e in enumerate(self.emotions)}

        self.samples = []
        self._load_samples()

    def _load_samples(self):
        """Collect ``(image_path, label)`` pairs for every emotion folder that exists."""
        split_dir = os.path.join(self.root_dir, self.split)

        for emotion in self.emotions:
            emotion_dir = os.path.join(split_dir, emotion)
            if not os.path.exists(emotion_dir):
                continue

            # os.listdir order is unspecified; sorting keeps sample indices
            # identical between runs and machines.
            for img_file in sorted(os.listdir(emotion_dir)):
                # Lower-case before matching so files such as "x.PNG" are kept.
                if img_file.lower().endswith(self._IMAGE_EXTENSIONS):
                    img_path = os.path.join(emotion_dir, img_file)
                    self.samples.append((img_path, self.emotion2idx[emotion]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample *idx*."""
        img_path, label = self.samples[idx]

        # Load image
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # Unreadable file: substitute a black 3x224x224 image.
            # NOTE(review): this size is fixed and may not match the size the
            # transform produces for readable images - confirm.
            return torch.zeros(3, 224, 224), torch.tensor(label, dtype=torch.long)

        # Convert to RGB (3 channels)
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        if self.transform:
            image = self.transform(image)
        else:
            # Default transform: HWC uint8 -> CHW float scaled to [0, 1].
            image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0

        return image, torch.tensor(label, dtype=torch.long)
69
+
70
+
71
+ # ==================== AUDIO DATASET ====================
72
class RAVDESSDataset(Dataset):
    """RAVDESS audio emotion dataset loader.

    Reads ``.wav`` files directly inside ``root_dir`` and takes the emotion
    from the third dash-separated field of the file name. Each item is a
    normalized MFCC matrix of shape ``(n_mfcc, 100)`` plus its label.
    """

    def __init__(self, root_dir: str, n_mfcc: int = 13, target_sr: int = 22050):
        """
        Initialize RAVDESS dataset.

        Args:
            root_dir: Root directory containing audio files
            n_mfcc: Number of MFCCs to extract
            target_sr: Target sampling rate
        """
        self.root_dir = root_dir
        self.n_mfcc = n_mfcc
        self.target_sr = target_sr
        self.emotion_map = {
            '01': 'neutral',
            '02': 'calm',
            '03': 'happy',
            '04': 'sad',
            '05': 'angry',
            '06': 'fear',
            '07': 'disgust',
            '08': 'surprise'
        }
        # Sort before enumerating: iterating a set of strings follows hash order,
        # which changes between interpreter runs, so the same emotion could get
        # a different label index on every run.
        self.emotion2idx = {v: i for i, v in enumerate(sorted(set(self.emotion_map.values())))}

        self.samples = []
        self._load_samples()

    def _load_samples(self):
        """Collect ``(audio_path, label)`` pairs for every recognised ``.wav`` file."""
        # Sorted so that sample indices are reproducible across runs and machines.
        for file in sorted(os.listdir(self.root_dir)):
            if not file.endswith('.wav'):
                continue
            fields = file.split('-')
            # Skip files that do not have an emotion field instead of raising IndexError.
            if len(fields) < 3:
                continue
            emotion_code = fields[2]
            if emotion_code in self.emotion_map:
                emotion = self.emotion_map[emotion_code]
                audio_path = os.path.join(self.root_dir, file)
                self.samples.append((audio_path, self.emotion2idx[emotion]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(mfcc_tensor, label_tensor)``; a zero tensor replaces unreadable audio."""
        audio_path, label = self.samples[idx]

        try:
            y, sr = librosa.load(audio_path, sr=self.target_sr, mono=True)
            mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=self.n_mfcc)

            # Normalize MFCC; the epsilon guards against a zero standard deviation.
            mfcc = (mfcc - mfcc.mean()) / (mfcc.std() + 1e-8)

            # Pad or truncate to fixed size (100 time steps)
            if mfcc.shape[1] < 100:
                mfcc = np.pad(mfcc, ((0, 0), (0, 100 - mfcc.shape[1])), mode='constant')
            else:
                mfcc = mfcc[:, :100]

            return torch.from_numpy(mfcc).float(), torch.tensor(label, dtype=torch.long)
        except Exception as e:
            # Best effort: report the failure and return a zero placeholder item.
            print(f"Error loading {audio_path}: {e}")
            return torch.zeros(self.n_mfcc, 100), torch.tensor(label, dtype=torch.long)
135
+
136
+
137
+ # ==================== DATALOADER FACTORY ====================
138
def create_dataloaders(
    fer2013_dir: str = None,
    ravdess_dir: str = None,
    batch_size: int = 32,
    num_workers: int = 0,
    img_size: int = 224
) -> dict:
    """
    Build DataLoaders for whichever of the two datasets is present on disk.

    Args:
        fer2013_dir: Path to FER2013 dataset root
        ravdess_dir: Path to RAVDESS dataset root
        batch_size: Batch size for training
        num_workers: Number of workers for data loading
        img_size: Image size for FER2013

    Returns:
        Dictionary that may contain 'fer2013_train', 'fer2013_test' and
        'ravdess'; a dataset whose directory is missing or not given is
        simply left out.
    """
    # Image pipeline: array -> PIL -> resize -> tensor -> channel-wise normalization.
    preprocessing = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    shared_options = {"batch_size": batch_size, "num_workers": num_workers}

    loaders = {}

    if fer2013_dir and os.path.exists(fer2013_dir):
        # Only the training split is shuffled.
        for split_name, shuffle in (("train", True), ("test", False)):
            split_dataset = FER2013Dataset(fer2013_dir, split=split_name, transform=preprocessing)
            loaders[f"fer2013_{split_name}"] = DataLoader(
                split_dataset, shuffle=shuffle, **shared_options
            )

    if ravdess_dir and os.path.exists(ravdess_dir):
        loaders['ravdess'] = DataLoader(
            RAVDESSDataset(ravdess_dir), shuffle=True, **shared_options
        )

    return loaders
backend/services/explainability.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Explainability utilities for multimodal emotion recognition outputs."""
2
+ import os
3
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "0"
4
+ os.environ["QT_QPA_PLATFORM"] = "offscreen"
5
+ import torch
6
+ import torch.nn as nn
7
+ import numpy as np
8
+ import cv2
9
+ import librosa
10
+ import librosa.display
11
+ import matplotlib
12
+ matplotlib.use('Agg')
13
+ import matplotlib.pyplot as plt
14
+ from PIL import Image
15
+ from io import BytesIO
16
+ import base64
17
+
18
+ from pytorch_grad_cam import GradCAM, EigenCAM
19
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
20
+ from pytorch_grad_cam.utils.reshape_transforms import vit_reshape_transform
21
+
22
+
23
+ # ==================== MODEL WRAPPER ====================
24
class ViTLogitsWrapper(nn.Module):
    """Adapter that turns a keyword-driven classifier into a tensor-in, logits-out module.

    The wrapped model is called as ``model(pixel_values=x)`` and is expected to
    return an object exposing a ``.logits`` attribute; this wrapper returns
    that attribute directly.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # CAM tooling calls forward(x) positionally and needs the raw logits
        # rather than the model's structured output object.
        outputs = self.model(pixel_values=x)
        return outputs.logits
32
+
33
+
34
+ # ==================== FACIAL EXPLAINABILITY ====================
35
def _rgb_array_to_png_base64(rgb_array):
    """Encode an RGB uint8 image array as a base64 PNG string."""
    buffer = BytesIO()
    Image.fromarray(rgb_array).save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode()


def generate_grad_cam(image, model, processor, emotion_idx, emotions_list, device):
    """Build a Grad-CAM overlay explaining a facial-emotion prediction.

    Grad-CAM is attempted on the ``layernorm_after`` module of the last (up to)
    three encoder blocks, newest first; the first map whose maximum exceeds
    0.01 is used. If none qualifies, EigenCAM on the last block is used as a
    fallback. The chosen map is resized to the image, smoothed, normalised and
    blended onto the image where it exceeds its 70th percentile.

    Args:
        image: PIL image (anything exposing ``convert('RGB')``).
        model: Classifier exposing ``model.vit.encoder.layer`` and accepting
            ``pixel_values=...`` (see ViTLogitsWrapper).
        processor: Callable turning the PIL image into ``pixel_values`` tensors.
        emotion_idx: Index of the class whose evidence is visualised.
        emotions_list: Unused here; kept so the signature mirrors
            generate_audio_saliency.
        device: Device the inputs are moved to.

    Returns:
        ``(original_png_b64, overlay_png_b64)``, or ``(None, None)`` when the
        explanation cannot be produced. Errors are logged, never raised.
    """
    try:
        img_rgb = np.array(image.convert('RGB'))
        h, w = img_rgb.shape[:2]
        img_pil = Image.fromarray(img_rgb)

        inputs = processor(img_pil, return_tensors='pt').to(device)
        input_tensor = inputs['pixel_values']

        wrapped_model = ViTLogitsWrapper(model)
        wrapped_model.eval()

        # Try multiple layers because the last block can become too saturated for a usable heatmap.
        encoder_blocks = model.vit.encoder.layer
        num_blocks = len(encoder_blocks)
        layers_to_try = [
            encoder_blocks[-offset].layernorm_after
            for offset in range(1, min(3, num_blocks) + 1)
        ]

        cam_map = None
        method_used = None

        for i, layer in enumerate(layers_to_try):
            result = None
            try:
                # The context manager removes the hooks CAM registers on the
                # shared model as soon as this attempt is over.
                with GradCAM(
                    model=wrapped_model,
                    target_layers=[layer],
                    reshape_transform=vit_reshape_transform,
                ) as cam:
                    targets = [ClassifierOutputTarget(emotion_idx)]
                    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
                    result = grayscale_cam[0]
            except Exception as e:
                print(f"[explainability] GradCAM layer[-{i+1}] failed: {e}")
                continue

            if result is None:
                # The CAM context manager may suppress an IndexError itself,
                # in which case no map was produced for this layer.
                print(f"[explainability] GradCAM layer[-{i+1}] failed: no CAM output")
                continue

            # Reject degenerate maps so the UI never shows a blank explanation as if it were valid.
            if result.max() > 0.01:
                cam_map = result
                method_used = f"GradCAM (encoder block {num_blocks - (i+1)})"
                break
            print(f"[explainability] layer[-{i+1}] all zeros, trying next")

        # Final fallback: EigenCAM gives a stable PCA-based map when gradients are unhelpful.
        if cam_map is None:
            print("[explainability] All GradCAM layers zero, using EigenCAM")
            try:
                with EigenCAM(
                    model=wrapped_model,
                    target_layers=[encoder_blocks[-1].layernorm_after],
                    reshape_transform=vit_reshape_transform,
                ) as eigen:
                    grayscale_cam = eigen(input_tensor=input_tensor)
                    cam_map = grayscale_cam[0]
                    method_used = "EigenCAM"
            except Exception as e:
                print(f"[explainability] EigenCAM failed: {e}")
                return None, None
            if cam_map is None:
                print("[explainability] EigenCAM failed: no CAM output")
                return None, None

        print(f"[explainability] {method_used} — min={cam_map.min():.3f}, max={cam_map.max():.3f}")

        # Upscale and smooth the heatmap so it overlays cleanly on the source image.
        cam_resized = cv2.resize(cam_map.astype(np.float32), (w, h), interpolation=cv2.INTER_CUBIC)
        cam_resized = cv2.GaussianBlur(cam_resized, (13, 13), 0)

        # Rescale to [0, 1]; a constant map is left untouched to avoid dividing by zero.
        c_min, c_max = cam_resized.min(), cam_resized.max()
        if c_max > c_min:
            cam_resized = (cam_resized - c_min) / (c_max - c_min)

        # Build the colored overlay and blend only the most salient regions.
        cam_uint8 = np.uint8(255 * cam_resized)
        heatmap_bgr = cv2.applyColorMap(cam_uint8, cv2.COLORMAP_JET)
        heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

        # Soft-edged mask over the top 30% of the map; the trailing axis lets
        # it broadcast across the RGB channels.
        threshold = np.percentile(cam_resized, 70)
        blend_mask = (cam_resized > threshold).astype(np.float32)
        blend_mask = cv2.GaussianBlur(blend_mask, (31, 31), 0)[..., None]

        blended = (
            (1 - blend_mask * 0.65) * img_rgb.astype(np.float32)
            + blend_mask * 0.65 * heatmap_rgb.astype(np.float32)
        ).clip(0, 255).astype(np.uint8)

        orig_b64 = _rgb_array_to_png_base64(img_rgb)
        blend_b64 = _rgb_array_to_png_base64(blended)

        return orig_b64, blend_b64

    except Exception as e:
        print(f"[explainability] GradCAM generation failed: {e}")
        return None, None
132
+
133
+
134
+ # ==================== AUDIO EXPLAINABILITY ====================
135
def generate_audio_saliency(audio, model, processor, emotion_idx, emotions_list, device, sr=16000):
    """Render a spectrogram and an input-gradient saliency plot for a speech prediction.

    Saliency is the absolute gradient of the ``emotion_idx`` logit with respect
    to the processed waveform, smoothed with an 11-sample moving average and
    min-max normalised. It is then resampled to the number of spectrogram
    frames and drawn over the normalised mel spectrogram, with a curve below
    that highlights values at or above the 85th percentile.

    Args:
        audio: 1-D sequence/array of samples; must be non-empty.
        model: Model called as ``model(input_values)`` returning ``.logits``.
        processor: Callable producing ``input_values`` from the waveform.
        emotion_idx: Index of the class to explain (also indexes emotions_list).
        emotions_list: Class labels used in the plot titles.
        device: Device the inputs are moved to.
        sr: Sampling rate passed to the processor and to librosa.

    Returns:
        ``(spectrogram_png_b64, saliency_png_b64)``, or ``(None, None)`` when
        the explanation cannot be produced. Errors are logged, never raised.
    """
    try:
        if audio is None or len(audio) == 0:
            raise ValueError("Audio input is empty")

        # Sanitize the audio before passing it into the speech backbone.
        audio = np.asarray(audio, dtype=np.float32)
        audio = np.nan_to_num(audio)

        inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
        input_values = inputs['input_values'].to(device)

        # Force autograd on: callers often run inference under torch.no_grad(),
        # which would otherwise make backward() fail.
        with torch.enable_grad():
            input_values.requires_grad = True
            model.zero_grad()

            outputs = model(input_values)
            score = outputs.logits[0, emotion_idx]
            score.backward()

        if input_values.grad is None:
            raise RuntimeError("No gradients captured")

        saliency = torch.abs(input_values.grad).cpu().detach().numpy()

        # Only the input gradient is needed; drop the parameter gradients that
        # backward() also populated so they do not stay allocated.
        model.zero_grad()

        # Collapse to a single 1-D curve for the first batch item.
        if saliency.ndim == 3:
            saliency = np.mean(saliency, axis=1)[0]
        elif saliency.ndim == 2:
            saliency = saliency[0]
        saliency = saliency.reshape(-1).astype(np.float32)

        if saliency.size > 11:
            # Smooth the gradient spikes so the curve is readable in the plot.
            kernel = np.ones(11, dtype=np.float32) / 11.0
            saliency = np.convolve(saliency, kernel, mode='same')

        s_min, s_max = saliency.min(), saliency.max()
        if s_max > s_min:
            saliency = (saliency - s_min) / (s_max - s_min)
        else:
            saliency = np.zeros_like(saliency)

        emotion_label = emotions_list[emotion_idx]

        # Build both the spectrogram and the saliency overlay for a side-by-side explanation.
        S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
        S_db = librosa.power_to_db(S, ref=np.max)

        fig1, ax1 = plt.subplots(figsize=(10, 4), dpi=100)
        try:
            img1 = librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='mel', ax=ax1)
            ax1.set_title(f'Audio Spectrogram — {emotion_label}')
            fig1.colorbar(img1, ax=ax1, format='%+2.0f dB')
            spec_buf = BytesIO()
            fig1.savefig(spec_buf, format='PNG', bbox_inches='tight', dpi=100)
            spec_b64 = base64.b64encode(spec_buf.getvalue()).decode()
        finally:
            # pyplot keeps every open figure alive; close even when plotting
            # fails so a long-running server does not accumulate figures.
            plt.close(fig1)

        fig2, (ax2, ax3) = plt.subplots(2, 1, figsize=(10, 5.5), dpi=100,
                                        gridspec_kw={'height_ratios': [3, 1]}, sharex=False)
        try:
            # Normalize the spectrogram so the saliency colors stay visible across different recordings.
            S_norm = (S_db - S_db.min()) / max(S_db.max() - S_db.min(), 1e-8)
            # Resample the per-sample saliency to one value per spectrogram
            # frame, then repeat it across all mel bins.
            sal_resized = np.interp(np.linspace(0, 1, S_db.shape[1]),
                                    np.linspace(0, 1, saliency.shape[0]), saliency)
            sal_map = np.tile(sal_resized, (S_db.shape[0], 1))

            ax2.imshow(S_norm, aspect='auto', origin='lower', cmap='viridis', interpolation='bilinear')
            ax2.imshow(sal_map, aspect='auto', origin='lower', cmap='magma', alpha=0.6, interpolation='bilinear')
            ax2.set_title(f'Audio Saliency — {emotion_label} (bright = important)')
            ax2.set_ylabel('Mel Frequency')

            # Highlight the strongest time steps to give the user a clear peak view.
            peak_thr = np.percentile(sal_resized, 85)
            x = np.arange(len(sal_resized))
            ax3.plot(x, sal_resized, color='#f97316', linewidth=1.5)
            ax3.fill_between(x, 0, sal_resized, where=sal_resized >= peak_thr, color='#ef4444', alpha=0.4)
            ax3.axhline(peak_thr, color='#ef4444', linestyle='--', linewidth=1, alpha=0.8)
            ax3.set_ylim(0, 1.05)
            ax3.set_ylabel('Saliency')
            ax3.set_xlabel('Time steps')
            ax3.grid(alpha=0.2)

            sal_buf = BytesIO()
            fig2.tight_layout()
            fig2.savefig(sal_buf, format='PNG', bbox_inches='tight', dpi=100)
            sal_b64 = base64.b64encode(sal_buf.getvalue()).decode()
        finally:
            plt.close(fig2)

        return spec_b64, sal_b64

    except Exception as e:
        print(f"[explainability] Audio saliency failed: {e}")
        return None, None
225
+
226
+
227
+ # ==================== COMBINED VISUALIZATION ====================
228
def create_combined_visualization(grad_cam_base64, saliency_base64, facial_emotion, speech_emotion, concordance):
    """Build a base64-encoded HTML report pairing the facial and speech explanations.

    Args:
        grad_cam_base64: Base64 PNG of the facial overlay, or None/empty when
            it could not be generated.
        saliency_base64: Base64 PNG of the audio saliency plot, or None/empty
            when it could not be generated.
        facial_emotion: Label shown for the facial prediction.
        speech_emotion: Label shown for the speech prediction.
        concordance: Agreement status; 'MATCH' selects the green banner, any
            other value the red one.

    Returns:
        The HTML fragment encoded as a base64 string, or None if building it
        fails (the error is logged).
    """
    # Imported locally because the name `html` is used for the markup below.
    from html import escape

    def _image_or_placeholder(png_b64):
        # A missing image would otherwise be rendered as the literal text
        # "None" inside the data URI and show up as a broken image.
        if png_b64:
            return (
                f'<img src="data:image/png;base64,{escape(str(png_b64))}" '
                'style="width:100%;border-radius:8px;">'
            )
        return '<p style="font-size:12px;color:#666;">Visualisation unavailable.</p>'

    try:
        # Escape everything interpolated into markup so a label can never
        # inject tags into the report.
        facial_label = escape(str(facial_emotion))
        speech_label = escape(str(speech_emotion))
        concordance_label = escape(str(concordance))

        # Use a soft status tint so the combined report communicates agreement at a glance.
        bg_color = '#d4edda' if concordance == 'MATCH' else '#f8d7da'
        html = f"""
        <div style="display:flex;gap:20px;padding:20px;background:#f5f5f5;border-radius:10px;">
            <div style="flex:1;">
                <h3>Facial GradCAM — {facial_label}</h3>
                {_image_or_placeholder(grad_cam_base64)}
                <p style="font-size:12px;color:#666;">Red/warm = regions that most influenced the {facial_label} prediction.</p>
            </div>
            <div style="flex:1;">
                <h3>Speech Saliency — {speech_label}</h3>
                {_image_or_placeholder(saliency_base64)}
                <p style="font-size:12px;color:#666;">Bright = time-frequency regions with strongest influence.</p>
            </div>
        </div>
        <div style="margin-top:20px;padding:15px;background:{bg_color};border-radius:8px;text-align:center;">
            <h4 style="margin:0;">Concordance: <strong>{concordance_label}</strong></h4>
        </div>
        """
        return base64.b64encode(html.encode()).decode()
    except Exception as e:
        print(f"[explainability] Combined visualisation failed: {e}")
        return None
requirements.txt ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Web Framework
2
+ fastapi>=0.104.0
3
+ uvicorn[standard]>=0.24.0
4
+ gunicorn>=21.2.0
5
+ python-multipart>=0.0.6
6
+
7
+ # Core ML Libraries
8
+ torch>=2.6.0
9
+ torchvision>=0.20.0
10
+ torchaudio>=2.6.0
11
+ transformers>=4.46.0
12
+ timm>=1.0.0
13
+
14
+ # Face Detection
15
+ facenet-pytorch>=2.5.3
16
+
17
+ # Audio Processing
18
+ librosa>=0.10.2
19
+ soundfile>=0.12.1
20
+ pydub>=0.25.1
21
+
22
+ # Computer Vision
23
+ opencv-python-headless>=4.10.0
24
+ pillow>=10.4.0
25
+
26
+ # Scientific Computing
27
+ numpy>=1.26.0
28
+ scipy>=1.14.0
29
+ scikit-learn>=1.5.0
30
+
31
+ # Visualization
32
+ matplotlib>=3.9.0
33
+ seaborn>=0.13.2
34
+
35
+ # Configuration
36
+ python-dotenv>=1.0.0
37
+
38
+ # Database
39
+ supabase>=2.0.0
40
+
41
+ # HuggingFace
42
+ huggingface_hub>=0.24.0
43
+ grad-cam
44
+
start.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
#!/bin/bash
# Launch the FastAPI app under gunicorn with a single uvicorn worker.
# Binds to $PORT when it is set, otherwise to 8080. The bind address is quoted
# so the expansion is passed to gunicorn as a single argument.
exec python -m gunicorn backend.main:app -w 1 -k uvicorn.workers.UvicornWorker --timeout 600 --bind "0.0.0.0:${PORT:-8080}"