Spaces:
Sleeping
Sleeping
deploy MMER FastAPI backend on HF Spaces
Browse files- Dockerfile +22 -0
- backend/__init__.py +1 -0
- backend/__pycache__/__init__.cpython-314.pyc +0 -0
- backend/__pycache__/main.cpython-314.pyc +0 -0
- backend/backend/__init__.py +1 -0
- backend/backend/__pycache__/__init__.cpython-314.pyc +0 -0
- backend/backend/__pycache__/main.cpython-314.pyc +0 -0
- backend/backend/main.py +914 -0
- backend/backend/services/__init__.py +1 -0
- backend/backend/services/__pycache__/__init__.cpython-314.pyc +0 -0
- backend/backend/services/__pycache__/explainability.cpython-314.pyc +0 -0
- backend/backend/services/data_loader.py +185 -0
- backend/backend/services/explainability.py +252 -0
- backend/main.py +914 -0
- backend/services/__init__.py +1 -0
- backend/services/__pycache__/__init__.cpython-314.pyc +0 -0
- backend/services/__pycache__/explainability.cpython-314.pyc +0 -0
- backend/services/data_loader.py +185 -0
- backend/services/explainability.py +252 -0
- requirements.txt +44 -0
- start.sh +2 -0
Dockerfile
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-slim
|
| 2 |
+
|
| 3 |
+
RUN apt-get update && apt-get install -y \
|
| 4 |
+
libxcb1 libxcb-render0 libxcb-shm0 libxcb-xfixes0 \
|
| 5 |
+
libglib2.0-0 libsm6 libxext6 libxrender-dev \
|
| 6 |
+
libgomp1 ffmpeg gcc \
|
| 7 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 8 |
+
|
| 9 |
+
RUN useradd -m -u 1000 user
|
| 10 |
+
USER user
|
| 11 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 12 |
+
|
| 13 |
+
WORKDIR /app
|
| 14 |
+
|
| 15 |
+
COPY --chown=user requirements.txt .
|
| 16 |
+
RUN pip install --no-cache-dir --timeout=300 --retries=5 -r requirements.txt
|
| 17 |
+
RUN pip uninstall -y opencv-python || true
|
| 18 |
+
RUN pip install --no-cache-dir --timeout=300 --force-reinstall opencv-python-headless>=4.10.0
|
| 19 |
+
|
| 20 |
+
COPY --chown=user . .
|
| 21 |
+
|
| 22 |
+
CMD ["python", "-m", "gunicorn", "backend.main:app", "-w", "1", "-k", "uvicorn.workers.UvicornWorker", "--timeout", "600", "--bind", "0.0.0.0:7860"]
|
backend/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Backend package for the FastAPI emotion recognition service."""
|
backend/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (249 Bytes). View file
|
|
|
backend/__pycache__/main.cpython-314.pyc
ADDED
|
Binary file (43.9 kB). View file
|
|
|
backend/backend/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Backend package for the FastAPI emotion recognition service."""
|
backend/backend/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (249 Bytes). View file
|
|
|
backend/backend/__pycache__/main.cpython-314.pyc
ADDED
|
Binary file (43.9 kB). View file
|
|
|
backend/backend/main.py
ADDED
|
@@ -0,0 +1,914 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI backend for multimodal (facial + speech) emotion inference."""
|
| 2 |
+
|
| 3 |
+
from fastapi import FastAPI, File, UploadFile
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
from fastapi.responses import JSONResponse
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
import librosa
|
| 10 |
+
import base64
|
| 11 |
+
from PIL import Image, ImageOps
|
| 12 |
+
from io import BytesIO
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoFeatureExtractor, AutoModelForAudioClassification
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
+
import tempfile
|
| 17 |
+
import os
|
| 18 |
+
import logging
|
| 19 |
+
from threading import Lock
|
| 20 |
+
from dotenv import load_dotenv
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from facenet_pytorch import MTCNN # type: ignore[import-not-found]
|
| 24 |
+
except Exception:
|
| 25 |
+
MTCNN = None
|
| 26 |
+
|
| 27 |
+
# Load environment variables
|
| 28 |
+
load_dotenv()
|
| 29 |
+
|
| 30 |
+
# Configure logging
|
| 31 |
+
logging.basicConfig(
|
| 32 |
+
level=logging.INFO,
|
| 33 |
+
format='[%(asctime)s] [%(levelname)s] %(message)s'
|
| 34 |
+
)
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
# Explainability helpers
|
| 38 |
+
from backend.services.explainability import generate_grad_cam, generate_audio_saliency
|
| 39 |
+
|
| 40 |
+
ENV = os.getenv("ENV", "development")
|
| 41 |
+
FRONTEND_URL = os.getenv(
|
| 42 |
+
"FRONTEND_URL",
|
| 43 |
+
os.getenv("REACT_APP_VERCEL_URL", "http://localhost:3000")
|
| 44 |
+
)
|
| 45 |
+
CORS_ORIGINS = os.getenv("CORS_ORIGINS", "")
|
| 46 |
+
USE_GPU = os.getenv("USE_GPU", "true").lower() == "true"
|
| 47 |
+
PRELOAD_MODELS = os.getenv("PRELOAD_MODELS", "false").lower() == "true"
|
| 48 |
+
ENABLE_FACE_ROTATION = os.getenv("ENABLE_FACE_ROTATION", "false").lower() == "true"
|
| 49 |
+
MAX_FACE_ROTATION_DEGREES = float(os.getenv("MAX_FACE_ROTATION_DEGREES", "8"))
|
| 50 |
+
HAAR_MIN_NEIGHBORS = int(os.getenv("HAAR_MIN_NEIGHBORS", "5"))
|
| 51 |
+
HAAR_MIN_SIZE = int(os.getenv("HAAR_MIN_SIZE", "40"))
|
| 52 |
+
|
| 53 |
+
app = FastAPI(title="Multi-Modal Emotion Recognition API", version="2.0.0")
|
| 54 |
+
|
| 55 |
+
# Configure CORS based on environment
|
| 56 |
+
if ENV == "production":
|
| 57 |
+
if CORS_ORIGINS.strip():
|
| 58 |
+
allowed_origins = [origin.strip() for origin in CORS_ORIGINS.split(",") if origin.strip()]
|
| 59 |
+
else:
|
| 60 |
+
allowed_origins = [FRONTEND_URL]
|
| 61 |
+
else:
|
| 62 |
+
allowed_origins = ["*"]
|
| 63 |
+
|
| 64 |
+
app.add_middleware(
|
| 65 |
+
CORSMiddleware,
|
| 66 |
+
allow_origins=allowed_origins,
|
| 67 |
+
allow_credentials=True,
|
| 68 |
+
allow_methods=["*"],
|
| 69 |
+
allow_headers=["*"],
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
logger.info(f"CORS enabled for: {allowed_origins}")
|
| 73 |
+
logger.info(
|
| 74 |
+
"Face detection config: rotation=%s max_rotation=%.1f haar_min_neighbors=%d haar_min_size=%d",
|
| 75 |
+
ENABLE_FACE_ROTATION,
|
| 76 |
+
MAX_FACE_ROTATION_DEGREES,
|
| 77 |
+
HAAR_MIN_NEIGHBORS,
|
| 78 |
+
HAAR_MIN_SIZE,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Runtime configuration
|
| 82 |
+
EMOTIONS_FACIAL = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
|
| 83 |
+
EMOTIONS_SPEECH = ['angry', 'calm', 'disgust', 'fearful', 'happy', 'neutral', 'sad', 'surprised']
|
| 84 |
+
DEVICE = torch.device('cuda' if (torch.cuda.is_available() and USE_GPU) else 'cpu')
|
| 85 |
+
MAX_SPEECH_INFER_SECONDS = int(os.getenv('MAX_SPEECH_INFER_SECONDS', '15'))
|
| 86 |
+
MAX_SPEECH_XAI_SECONDS = int(os.getenv('MAX_SPEECH_XAI_SECONDS', '8'))
|
| 87 |
+
CONCORDANCE_SCORE_MAP = {
|
| 88 |
+
'MATCH': 100,
|
| 89 |
+
'PARTIAL': 65,
|
| 90 |
+
'MISMATCH': 30,
|
| 91 |
+
'UNKNOWN': 0,
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
# In-memory model state
|
| 95 |
+
vit_model = None
|
| 96 |
+
facial_processor = None
|
| 97 |
+
speech_model = None
|
| 98 |
+
speech_processor = None
|
| 99 |
+
facial_loaded = False
|
| 100 |
+
speech_loaded = False
|
| 101 |
+
|
| 102 |
+
_facial_model_lock = Lock()
|
| 103 |
+
_speech_model_lock = Lock()
|
| 104 |
+
|
| 105 |
+
# Paths — download from HuggingFace Hub
|
| 106 |
+
logger.info("Resolving model paths from HuggingFace Hub...")
|
| 107 |
+
FACIAL_MODEL_PATH = hf_hub_download(
|
| 108 |
+
repo_id="Nishvaraj/emotion-models",
|
| 109 |
+
filename="vit_emotion_model.pt"
|
| 110 |
+
)
|
| 111 |
+
SPEECH_MODEL_PATH = hf_hub_download(
|
| 112 |
+
repo_id="Nishvaraj/emotion-models",
|
| 113 |
+
filename="hubert_emotion_model.pt"
|
| 114 |
+
)
|
| 115 |
+
logger.info(f"Facial model path: {FACIAL_MODEL_PATH}")
|
| 116 |
+
logger.info(f"Speech model path: {SPEECH_MODEL_PATH}")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _upload_suffix(filename: str, default_suffix: str) -> str:
|
| 120 |
+
# Preserve the original extension when the browser provides one, otherwise fall back to a safe default.
|
| 121 |
+
suffix = Path(filename or '').suffix.lower()
|
| 122 |
+
return suffix if suffix else default_suffix
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _calculate_concordance(facial_emotion, speech_emotion, facial_confidence, speech_confidence):
|
| 126 |
+
# Match/partial/mismatch is derived from whether both models agree and how confident they are.
|
| 127 |
+
if facial_emotion == speech_emotion:
|
| 128 |
+
# When the modalities agree, the average confidence controls the concordance band.
|
| 129 |
+
score = (facial_confidence + speech_confidence) / 2
|
| 130 |
+
if score > 0.7:
|
| 131 |
+
concordance = "MATCH"
|
| 132 |
+
elif score >= 0.4:
|
| 133 |
+
concordance = "PARTIAL"
|
| 134 |
+
else:
|
| 135 |
+
concordance = "MISMATCH"
|
| 136 |
+
else:
|
| 137 |
+
# Different emotions can never be a full match, so we score by how close the confidences are.
|
| 138 |
+
score = 1 - abs(facial_confidence - speech_confidence)
|
| 139 |
+
if score >= 0.5:
|
| 140 |
+
concordance = "PARTIAL"
|
| 141 |
+
else:
|
| 142 |
+
concordance = "MISMATCH"
|
| 143 |
+
|
| 144 |
+
concordance_score = round(score * 100)
|
| 145 |
+
return concordance, concordance_score
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
FACE_CASCADE = cv2.CascadeClassifier(
|
| 149 |
+
cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
|
| 150 |
+
)
|
| 151 |
+
MTCNN_DETECTOR = MTCNN(keep_all=False, device=DEVICE) if MTCNN is not None else None
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _encode_image_base64(image_array: np.ndarray) -> str:
|
| 155 |
+
image_pil = Image.fromarray(image_array.astype(np.uint8))
|
| 156 |
+
buf = BytesIO()
|
| 157 |
+
image_pil.save(buf, format='PNG')
|
| 158 |
+
return base64.b64encode(buf.getvalue()).decode()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _detect_primary_face(image: Image.Image):
|
| 162 |
+
# Prefer MTCNN when available because it gives stronger boxes and landmark points.
|
| 163 |
+
if MTCNN_DETECTOR is not None:
|
| 164 |
+
try:
|
| 165 |
+
boxes, probs, points = MTCNN_DETECTOR.detect(image, landmarks=True)
|
| 166 |
+
if boxes is not None and len(boxes) > 0:
|
| 167 |
+
# Use the highest-probability detection when multiple faces appear.
|
| 168 |
+
best_idx = int(np.argmax(probs)) if probs is not None else 0
|
| 169 |
+
x1, y1, x2, y2 = boxes[best_idx]
|
| 170 |
+
# Convert from [x1,y1,x2,y2] to [x,y,w,h]
|
| 171 |
+
x, y, w, h = int(x1), int(y1), int(x2 - x1), int(y2 - y1)
|
| 172 |
+
return (x, y, w, h), (points[best_idx] if points is not None else None)
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logger.debug(f"MTCNN face detection fallback: {e}")
|
| 175 |
+
|
| 176 |
+
# Haar cascade is the fallback path so the app still works without facenet-pytorch.
|
| 177 |
+
img_array = np.array(image)
|
| 178 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 179 |
+
faces = FACE_CASCADE.detectMultiScale(
|
| 180 |
+
gray,
|
| 181 |
+
scaleFactor=1.1,
|
| 182 |
+
minNeighbors=HAAR_MIN_NEIGHBORS,
|
| 183 |
+
minSize=(HAAR_MIN_SIZE, HAAR_MIN_SIZE)
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
if faces is None or len(faces) == 0:
|
| 187 |
+
return None, None
|
| 188 |
+
best_face = max(faces, key=lambda b: b[2] * b[3])
|
| 189 |
+
return tuple(int(v) for v in best_face), None
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _rotate_image_to_level(image: Image.Image, points) -> Image.Image:
|
| 193 |
+
if not ENABLE_FACE_ROTATION:
|
| 194 |
+
return image
|
| 195 |
+
|
| 196 |
+
if points is None:
|
| 197 |
+
return image
|
| 198 |
+
|
| 199 |
+
try:
|
| 200 |
+
# Estimate head tilt from the eye landmarks and keep the correction bounded.
|
| 201 |
+
left_eye, right_eye = points[0], points[1]
|
| 202 |
+
angle = np.degrees(np.arctan2(right_eye[1] - left_eye[1], right_eye[0] - left_eye[0]))
|
| 203 |
+
if abs(angle) < 1.0:
|
| 204 |
+
return image
|
| 205 |
+
if abs(angle) > MAX_FACE_ROTATION_DEGREES:
|
| 206 |
+
logger.debug("Skipping face rotation due to large angle: %.2f", angle)
|
| 207 |
+
return image
|
| 208 |
+
center_x = image.width / 2
|
| 209 |
+
center_y = image.height / 2
|
| 210 |
+
return image.rotate(-angle, resample=Image.Resampling.BICUBIC, expand=True, center=(center_x, center_y), fillcolor=(0, 0, 0))
|
| 211 |
+
except Exception:
|
| 212 |
+
return image
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _crop_face_with_margin(image_array: np.ndarray, face_box, margin_ratio: float = 0.12):
|
| 216 |
+
# Expand the detected face slightly so the classifier keeps some surrounding context.
|
| 217 |
+
x, y, w, h = [int(v) for v in face_box]
|
| 218 |
+
h_img, w_img = image_array.shape[:2]
|
| 219 |
+
mx = int(w * margin_ratio)
|
| 220 |
+
my = int(h * margin_ratio)
|
| 221 |
+
|
| 222 |
+
x1 = max(0, x - mx)
|
| 223 |
+
y1 = max(0, y - my)
|
| 224 |
+
x2 = min(w_img, x + w + mx)
|
| 225 |
+
y2 = min(h_img, y + h + my)
|
| 226 |
+
|
| 227 |
+
return image_array[y1:y2, x1:x2], (x1, y1, x2 - x1, y2 - y1)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _shrink_box(face_box, shrink_ratio: float = 0.12):
|
| 231 |
+
# Draw a tighter outline for annotation so the face box looks cleaner on the preview image.
|
| 232 |
+
x, y, w, h = [int(v) for v in face_box]
|
| 233 |
+
dx = int(w * shrink_ratio / 2)
|
| 234 |
+
dy = int(h * shrink_ratio / 2)
|
| 235 |
+
x1 = x + dx
|
| 236 |
+
y1 = y + dy
|
| 237 |
+
width = max(1, w - (dx * 2))
|
| 238 |
+
height = max(1, h - (dy * 2))
|
| 239 |
+
return x1, y1, width, height
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def _trim_audio_window(audio: np.ndarray, sr: int, max_seconds: int) -> np.ndarray:
|
| 243 |
+
# Long recordings are centered and clipped so inference stays fast and consistent.
|
| 244 |
+
if audio is None or sr <= 0:
|
| 245 |
+
return audio
|
| 246 |
+
max_len = int(sr * max_seconds)
|
| 247 |
+
if max_len <= 0 or len(audio) <= max_len:
|
| 248 |
+
return audio
|
| 249 |
+
start = (len(audio) - max_len) // 2
|
| 250 |
+
end = start + max_len
|
| 251 |
+
return audio[start:end]
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
logger.info(f"Device: {DEVICE}")
|
| 255 |
+
logger.info(f"Environment: {ENV}")
|
| 256 |
+
|
| 257 |
+
# ========== MODEL LOADING ==========
|
| 258 |
+
|
| 259 |
+
def load_facial_model():
|
| 260 |
+
"""Load ViT model for facial emotion"""
|
| 261 |
+
global vit_model, facial_processor, facial_loaded
|
| 262 |
+
if vit_model is not None and facial_processor is not None:
|
| 263 |
+
facial_loaded = True
|
| 264 |
+
return True
|
| 265 |
+
|
| 266 |
+
with _facial_model_lock:
|
| 267 |
+
if vit_model is not None and facial_processor is not None:
|
| 268 |
+
facial_loaded = True
|
| 269 |
+
return True
|
| 270 |
+
|
| 271 |
+
try:
|
| 272 |
+
logger.info("Loading Facial Emotion Model (ViT)...")
|
| 273 |
+
# Keep the pretrained ViT backbone but swap in the emotion-class head size.
|
| 274 |
+
facial_processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
|
| 275 |
+
vit_model = AutoModelForImageClassification.from_pretrained(
|
| 276 |
+
'google/vit-base-patch16-224-in21k',
|
| 277 |
+
num_labels=len(EMOTIONS_FACIAL),
|
| 278 |
+
ignore_mismatched_sizes=True,
|
| 279 |
+
attn_implementation='eager'
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# Load either a full checkpoint or a plain state_dict depending on how the file was saved.
|
| 283 |
+
checkpoint = torch.load(FACIAL_MODEL_PATH, map_location=DEVICE)
|
| 284 |
+
if 'model_state_dict' in checkpoint:
|
| 285 |
+
vit_model.load_state_dict(checkpoint['model_state_dict'])
|
| 286 |
+
else:
|
| 287 |
+
vit_model.load_state_dict(checkpoint)
|
| 288 |
+
logger.info("✓ Loaded ViT checkpoint")
|
| 289 |
+
|
| 290 |
+
vit_model = vit_model.to(DEVICE)
|
| 291 |
+
vit_model.eval()
|
| 292 |
+
facial_loaded = True
|
| 293 |
+
logger.info("✓ Facial model ready")
|
| 294 |
+
return True
|
| 295 |
+
except Exception as e:
|
| 296 |
+
facial_loaded = False
|
| 297 |
+
logger.error(f"❌ Error loading facial model: {e}")
|
| 298 |
+
return False
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def load_speech_model():
|
| 302 |
+
"""Load HuBERT model for speech emotion"""
|
| 303 |
+
global speech_model, speech_processor, speech_loaded
|
| 304 |
+
if speech_model is not None and speech_processor is not None:
|
| 305 |
+
speech_loaded = True
|
| 306 |
+
return True
|
| 307 |
+
|
| 308 |
+
with _speech_model_lock:
|
| 309 |
+
if speech_model is not None and speech_processor is not None:
|
| 310 |
+
speech_loaded = True
|
| 311 |
+
return True
|
| 312 |
+
|
| 313 |
+
try:
|
| 314 |
+
logger.info("Loading Speech Emotion Model (HuBERT)...")
|
| 315 |
+
# Match the pretrained audio backbone to the project-specific emotion label set.
|
| 316 |
+
speech_processor = AutoFeatureExtractor.from_pretrained('facebook/hubert-large-ls960-ft')
|
| 317 |
+
speech_model = AutoModelForAudioClassification.from_pretrained(
|
| 318 |
+
'facebook/hubert-large-ls960-ft',
|
| 319 |
+
num_labels=len(EMOTIONS_SPEECH),
|
| 320 |
+
ignore_mismatched_sizes=True
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# Support both checkpoint formats used across training experiments.
|
| 324 |
+
checkpoint = torch.load(SPEECH_MODEL_PATH, map_location=DEVICE)
|
| 325 |
+
if 'model_state_dict' in checkpoint:
|
| 326 |
+
speech_model.load_state_dict(checkpoint['model_state_dict'])
|
| 327 |
+
else:
|
| 328 |
+
speech_model.load_state_dict(checkpoint)
|
| 329 |
+
logger.info("✓ Loaded HuBERT checkpoint")
|
| 330 |
+
|
| 331 |
+
speech_model = speech_model.to(DEVICE)
|
| 332 |
+
speech_model.eval()
|
| 333 |
+
speech_loaded = True
|
| 334 |
+
logger.info("✓ Speech model ready")
|
| 335 |
+
return True
|
| 336 |
+
except Exception as e:
|
| 337 |
+
speech_loaded = False
|
| 338 |
+
logger.error(f"❌ Error loading speech model: {e}")
|
| 339 |
+
return False
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def ensure_facial_model_loaded() -> bool:
|
| 343 |
+
if vit_model is not None and facial_processor is not None:
|
| 344 |
+
return True
|
| 345 |
+
return load_facial_model()
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def ensure_speech_model_loaded() -> bool:
|
| 349 |
+
if speech_model is not None and speech_processor is not None:
|
| 350 |
+
return True
|
| 351 |
+
return load_speech_model()
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
# Optional eager loading for environments that prefer warm startup.
|
| 355 |
+
if PRELOAD_MODELS:
|
| 356 |
+
facial_loaded = load_facial_model()
|
| 357 |
+
speech_loaded = load_speech_model()
|
| 358 |
+
|
| 359 |
+
# ========== VIDEO PROCESSOR ==========
|
| 360 |
+
|
| 361 |
+
class VideoProcessor:
|
| 362 |
+
@staticmethod
|
| 363 |
+
def extract_frames_and_audio(video_path: str, fps_sample: int = 5):
|
| 364 |
+
"""Extract frames and audio from video"""
|
| 365 |
+
frames = []
|
| 366 |
+
cap = cv2.VideoCapture(video_path)
|
| 367 |
+
|
| 368 |
+
if not cap.isOpened():
|
| 369 |
+
raise ValueError(f"Cannot open video: {video_path}")
|
| 370 |
+
|
| 371 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 372 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 373 |
+
if fps <= 0 or fps > 120:
|
| 374 |
+
fps = 30.0
|
| 375 |
+
|
| 376 |
+
frame_count = 0
|
| 377 |
+
while cap.isOpened():
|
| 378 |
+
ret, frame = cap.read()
|
| 379 |
+
if not ret:
|
| 380 |
+
break
|
| 381 |
+
|
| 382 |
+
if frame_count % fps_sample == 0:
|
| 383 |
+
# Sample every Nth frame so we analyze representative facial expressions without processing the full video.
|
| 384 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 385 |
+
frames.append(Image.fromarray(frame_rgb))
|
| 386 |
+
|
| 387 |
+
frame_count += 1
|
| 388 |
+
|
| 389 |
+
cap.release()
|
| 390 |
+
|
| 391 |
+
# librosa reads the audio track directly from the same file, giving us a single mono stream for speech inference.
|
| 392 |
+
audio, sr = librosa.load(video_path, sr=16000, mono=True)
|
| 393 |
+
|
| 394 |
+
return frames, audio, sr, fps
|
| 395 |
+
|
| 396 |
+
# ========== PREDICTION FUNCTIONS ==========
|
| 397 |
+
|
| 398 |
+
def predict_facial_emotion(image: Image.Image, generate_explainability: bool = False):
|
| 399 |
+
"""Predict emotion from image"""
|
| 400 |
+
try:
|
| 401 |
+
if not ensure_facial_model_loaded():
|
| 402 |
+
return None
|
| 403 |
+
|
| 404 |
+
# Normalize EXIF orientation first so mobile uploads and camera captures behave consistently.
|
| 405 |
+
image = ImageOps.exif_transpose(image).convert('RGB')
|
| 406 |
+
|
| 407 |
+
# Detect the most likely face before deciding whether to crop or rotate the input.
|
| 408 |
+
detected = _detect_primary_face(image)
|
| 409 |
+
face_box, face_points = detected if isinstance(detected, tuple) else (None, None)
|
| 410 |
+
|
| 411 |
+
# If we have eye landmarks, try a small rotation pass to correct head tilt.
|
| 412 |
+
rotated_image = _rotate_image_to_level(image, face_points)
|
| 413 |
+
if rotated_image is not image:
|
| 414 |
+
rotated_detected = _detect_primary_face(rotated_image)
|
| 415 |
+
if isinstance(rotated_detected, tuple):
|
| 416 |
+
rotated_box, rotated_points = rotated_detected
|
| 417 |
+
if rotated_box is not None:
|
| 418 |
+
image = rotated_image
|
| 419 |
+
face_box = rotated_box
|
| 420 |
+
face_points = rotated_points
|
| 421 |
+
|
| 422 |
+
input_array = np.array(image)
|
| 423 |
+
|
| 424 |
+
model_image = image
|
| 425 |
+
|
| 426 |
+
# Crop to the detected face when possible so the classifier sees the most relevant region.
|
| 427 |
+
if face_box is not None:
|
| 428 |
+
face_crop, _ = _crop_face_with_margin(input_array, face_box)
|
| 429 |
+
if face_crop.size > 0:
|
| 430 |
+
model_image = Image.fromarray(face_crop)
|
| 431 |
+
|
| 432 |
+
# Draw the face box on the preview image to make the detection step visible to the user.
|
| 433 |
+
annotated = input_array.copy()
|
| 434 |
+
if face_box is not None:
|
| 435 |
+
x, y, w, h = _shrink_box(face_box, shrink_ratio=0.08)
|
| 436 |
+
cv2.rectangle(annotated, (x, y), (x + w, y + h), (255, 128, 0), 2)
|
| 437 |
+
cv2.putText(
|
| 438 |
+
annotated,
|
| 439 |
+
'Face detected',
|
| 440 |
+
(x, max(20, y - 8)),
|
| 441 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 442 |
+
0.6,
|
| 443 |
+
(255, 128, 0),
|
| 444 |
+
2,
|
| 445 |
+
cv2.LINE_AA
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
inputs = facial_processor(model_image, return_tensors='pt').to(DEVICE)
|
| 449 |
+
with torch.no_grad():
|
| 450 |
+
outputs = vit_model(**inputs)
|
| 451 |
+
logits = outputs.logits.cpu().numpy()[0]
|
| 452 |
+
# Convert raw logits into probabilities for easier interpretation in the UI.
|
| 453 |
+
probs = torch.softmax(torch.from_numpy(logits), dim=0).numpy()
|
| 454 |
+
|
| 455 |
+
top_idx = np.argmax(probs)
|
| 456 |
+
result = {
|
| 457 |
+
"emotion": EMOTIONS_FACIAL[top_idx],
|
| 458 |
+
"confidence": float(probs[top_idx]),
|
| 459 |
+
"probabilities": {e: float(p) for e, p in zip(EMOTIONS_FACIAL, probs)},
|
| 460 |
+
"face_detected": face_box is not None,
|
| 461 |
+
"annotated_image": _encode_image_base64(annotated)
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
if face_box is not None:
|
| 465 |
+
x, y, w, h = [int(v) for v in face_box]
|
| 466 |
+
result["face_box"] = {"x": x, "y": y, "width": w, "height": h}
|
| 467 |
+
|
| 468 |
+
if generate_explainability:
|
| 469 |
+
# Explainability is optional because Grad-CAM adds compute cost.
|
| 470 |
+
result["explainability_status"] = {
|
| 471 |
+
"requested": True,
|
| 472 |
+
"generated": False,
|
| 473 |
+
"error": None
|
| 474 |
+
}
|
| 475 |
+
try:
|
| 476 |
+
original_base64, heatmap_base64 = generate_grad_cam(
|
| 477 |
+
model_image,
|
| 478 |
+
vit_model,
|
| 479 |
+
facial_processor,
|
| 480 |
+
top_idx,
|
| 481 |
+
EMOTIONS_FACIAL,
|
| 482 |
+
DEVICE
|
| 483 |
+
)
|
| 484 |
+
if original_base64:
|
| 485 |
+
result["original_image"] = original_base64
|
| 486 |
+
if heatmap_base64:
|
| 487 |
+
result["grad_cam"] = heatmap_base64
|
| 488 |
+
result["explainability_status"]["generated"] = True
|
| 489 |
+
else:
|
| 490 |
+
result["explainability_status"]["error"] = "Grad-CAM map returned empty output"
|
| 491 |
+
except Exception as e:
|
| 492 |
+
logger.warning(f"Could not generate Grad-CAM: {e}")
|
| 493 |
+
result["explainability_status"]["error"] = str(e)
|
| 494 |
+
|
| 495 |
+
return result
|
| 496 |
+
except Exception as e:
|
| 497 |
+
logger.error(f"Error predicting facial emotion: {e}")
|
| 498 |
+
return None
|
| 499 |
+
|
| 500 |
+
def predict_speech_emotion(audio: np.ndarray, sr: int = 16000, generate_explainability: bool = False):
|
| 501 |
+
"""Predict emotion from audio"""
|
| 502 |
+
try:
|
| 503 |
+
if not ensure_speech_model_loaded():
|
| 504 |
+
return None
|
| 505 |
+
|
| 506 |
+
if sr != 16000:
|
| 507 |
+
# Resample every input to the model's expected sampling rate.
|
| 508 |
+
audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
|
| 509 |
+
|
| 510 |
+
# Keep inference fast and stable for long recordings.
|
| 511 |
+
audio_for_infer = _trim_audio_window(audio, 16000, MAX_SPEECH_INFER_SECONDS)
|
| 512 |
+
|
| 513 |
+
inputs = speech_processor(audio_for_infer, sampling_rate=16000, return_tensors="pt", padding=True)
|
| 514 |
+
with torch.no_grad():
|
| 515 |
+
outputs = speech_model(inputs['input_values'].to(DEVICE))
|
| 516 |
+
logits = outputs.logits.cpu().numpy()[0]
|
| 517 |
+
# Softmax keeps the output distribution easy to display and compare.
|
| 518 |
+
probs = np.exp(logits) / np.sum(np.exp(logits))
|
| 519 |
+
|
| 520 |
+
top_idx = np.argmax(probs)
|
| 521 |
+
result = {
|
| 522 |
+
"emotion": EMOTIONS_SPEECH[top_idx],
|
| 523 |
+
"confidence": float(probs[top_idx]),
|
| 524 |
+
"probabilities": {e: float(p) for e, p in zip(EMOTIONS_SPEECH, probs)}
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
if generate_explainability:
|
| 528 |
+
# Saliency is computed on a shorter slice to avoid long XAI runs on large clips.
|
| 529 |
+
result["explainability_status"] = {
|
| 530 |
+
"requested": True,
|
| 531 |
+
"generated": False,
|
| 532 |
+
"error": None
|
| 533 |
+
}
|
| 534 |
+
try:
|
| 535 |
+
# Saliency on a shorter centered chunk avoids multi-minute stalls.
|
| 536 |
+
audio_for_xai = _trim_audio_window(audio_for_infer, 16000, MAX_SPEECH_XAI_SECONDS)
|
| 537 |
+
spec_base64, saliency_base64 = generate_audio_saliency(
|
| 538 |
+
audio_for_xai,
|
| 539 |
+
speech_model,
|
| 540 |
+
speech_processor,
|
| 541 |
+
top_idx,
|
| 542 |
+
EMOTIONS_SPEECH,
|
| 543 |
+
DEVICE,
|
| 544 |
+
sr=16000
|
| 545 |
+
)
|
| 546 |
+
if spec_base64:
|
| 547 |
+
result["waveform"] = spec_base64
|
| 548 |
+
if saliency_base64:
|
| 549 |
+
result["saliency"] = saliency_base64
|
| 550 |
+
result["explainability_status"]["generated"] = True
|
| 551 |
+
else:
|
| 552 |
+
result["explainability_status"]["error"] = "Audio saliency map returned empty output"
|
| 553 |
+
except Exception as e:
|
| 554 |
+
logger.warning(f"Could not generate audio saliency: {e}")
|
| 555 |
+
result["explainability_status"]["error"] = str(e)
|
| 556 |
+
|
| 557 |
+
return result
|
| 558 |
+
except Exception as e:
|
| 559 |
+
logger.error(f"Error predicting speech emotion: {e}")
|
| 560 |
+
return None
|
| 561 |
+
|
| 562 |
+
# ========== API ENDPOINTS ==========
|
| 563 |
+
|
| 564 |
+
@app.get("/")
|
| 565 |
+
async def root():
|
| 566 |
+
return {"message": "Multi-Modal Emotion Recognition API v2.0", "status": "active"}
|
| 567 |
+
|
| 568 |
+
@app.get("/health")
|
| 569 |
+
async def health():
|
| 570 |
+
facial_ready = vit_model is not None and facial_processor is not None
|
| 571 |
+
speech_ready = speech_model is not None and speech_processor is not None
|
| 572 |
+
return {
|
| 573 |
+
"status": "healthy",
|
| 574 |
+
"facial_model": facial_ready,
|
| 575 |
+
"speech_model": speech_ready,
|
| 576 |
+
"lazy_loading": not PRELOAD_MODELS,
|
| 577 |
+
"device": str(DEVICE)
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
@app.post("/api/predict/facial")
|
| 581 |
+
async def predict_facial(file: UploadFile = File(...), explain: bool = False):
|
| 582 |
+
"""Predict emotion from image"""
|
| 583 |
+
try:
|
| 584 |
+
logger.info(f"Received file: {file.filename}, content_type: {file.content_type}")
|
| 585 |
+
contents = await file.read()
|
| 586 |
+
logger.info(f"File size: {len(contents)} bytes")
|
| 587 |
+
if len(contents) == 0:
|
| 588 |
+
return JSONResponse(status_code=400, content={"error": "Empty file received"})
|
| 589 |
+
image = ImageOps.exif_transpose(Image.open(BytesIO(contents))).convert('RGB')
|
| 590 |
+
result = predict_facial_emotion(image, generate_explainability=explain)
|
| 591 |
+
return {"success": True, **result} if result else {"success": False, "error": "Prediction failed"}
|
| 592 |
+
except Exception as e:
|
| 593 |
+
logger.error(f"Error in predict_facial: {e}", exc_info=True)
|
| 594 |
+
return JSONResponse(status_code=400, content={"error": str(e)})
|
| 595 |
+
|
| 596 |
+
@app.post("/api/predict/speech")
|
| 597 |
+
async def predict_speech(file: UploadFile = File(...), explain: bool = False):
|
| 598 |
+
"""Predict emotion from audio"""
|
| 599 |
+
try:
|
| 600 |
+
contents = await file.read()
|
| 601 |
+
suffix = _upload_suffix(file.filename, '.wav')
|
| 602 |
+
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
| 603 |
+
tmp.write(contents)
|
| 604 |
+
tmp_path = tmp.name
|
| 605 |
+
|
| 606 |
+
try:
|
| 607 |
+
audio, sr = librosa.load(tmp_path, sr=16000)
|
| 608 |
+
result = predict_speech_emotion(audio, sr, generate_explainability=explain)
|
| 609 |
+
return {"success": True, **result} if result else {"success": False, "error": "Prediction failed"}
|
| 610 |
+
finally:
|
| 611 |
+
os.unlink(tmp_path)
|
| 612 |
+
except Exception as e:
|
| 613 |
+
return JSONResponse(status_code=400, content={"error": str(e)})
|
| 614 |
+
|
| 615 |
+
@app.post("/api/predict/combined")
|
| 616 |
+
async def predict_combined(image_file: UploadFile = File(...), audio_file: UploadFile = File(...), explain: bool = False):
|
| 617 |
+
"""Predict emotions from both image and audio, then compare results"""
|
| 618 |
+
try:
|
| 619 |
+
image_contents = await image_file.read()
|
| 620 |
+
image = ImageOps.exif_transpose(Image.open(BytesIO(image_contents))).convert('RGB')
|
| 621 |
+
facial_result = predict_facial_emotion(image, generate_explainability=explain)
|
| 622 |
+
|
| 623 |
+
audio_contents = await audio_file.read()
|
| 624 |
+
audio_suffix = _upload_suffix(audio_file.filename, '.wav')
|
| 625 |
+
with tempfile.NamedTemporaryFile(suffix=audio_suffix, delete=False) as tmp:
|
| 626 |
+
tmp.write(audio_contents)
|
| 627 |
+
tmp_path = tmp.name
|
| 628 |
+
|
| 629 |
+
try:
|
| 630 |
+
audio, sr = librosa.load(tmp_path, sr=16000)
|
| 631 |
+
speech_result = predict_speech_emotion(audio, sr, generate_explainability=explain)
|
| 632 |
+
finally:
|
| 633 |
+
os.unlink(tmp_path)
|
| 634 |
+
|
| 635 |
+
facial_emotion = facial_result["emotion"] if facial_result else None
|
| 636 |
+
facial_confidence = facial_result["confidence"] if facial_result else 0.0
|
| 637 |
+
|
| 638 |
+
speech_emotion = speech_result["emotion"] if speech_result else None
|
| 639 |
+
speech_confidence = speech_result["confidence"] if speech_result else 0.0
|
| 640 |
+
|
| 641 |
+
concordance, concordance_score = _calculate_concordance(
|
| 642 |
+
facial_emotion,
|
| 643 |
+
speech_emotion,
|
| 644 |
+
facial_confidence,
|
| 645 |
+
speech_confidence,
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
# The combined label should prefer the more confident modality when both are present.
|
| 649 |
+
combined_emotion = None
|
| 650 |
+
combined_confidence = 0.0
|
| 651 |
+
|
| 652 |
+
if facial_emotion and speech_emotion:
|
| 653 |
+
if facial_confidence > speech_confidence:
|
| 654 |
+
combined_emotion = facial_emotion
|
| 655 |
+
combined_confidence = facial_confidence
|
| 656 |
+
else:
|
| 657 |
+
combined_emotion = speech_emotion
|
| 658 |
+
combined_confidence = speech_confidence
|
| 659 |
+
elif facial_emotion:
|
| 660 |
+
combined_emotion = facial_emotion
|
| 661 |
+
combined_confidence = facial_confidence
|
| 662 |
+
elif speech_emotion:
|
| 663 |
+
combined_emotion = speech_emotion
|
| 664 |
+
combined_confidence = speech_confidence
|
| 665 |
+
|
| 666 |
+
response = {
|
| 667 |
+
"success": True,
|
| 668 |
+
"facial_emotion": {
|
| 669 |
+
"emotion": facial_emotion or "unknown",
|
| 670 |
+
"confidence": float(facial_confidence),
|
| 671 |
+
"probabilities": facial_result["probabilities"] if facial_result else {},
|
| 672 |
+
"face_detected": facial_result.get("face_detected", False) if facial_result else False,
|
| 673 |
+
"face_box": facial_result.get("face_box") if facial_result else None,
|
| 674 |
+
"annotated_image": facial_result.get("annotated_image") if facial_result else None
|
| 675 |
+
},
|
| 676 |
+
"speech_emotion": {
|
| 677 |
+
"emotion": speech_emotion or "unknown",
|
| 678 |
+
"confidence": float(speech_confidence),
|
| 679 |
+
"probabilities": speech_result["probabilities"] if speech_result else {}
|
| 680 |
+
},
|
| 681 |
+
"combined_emotion": combined_emotion or "unknown",
|
| 682 |
+
"combined_confidence": float(combined_confidence),
|
| 683 |
+
"concordance": concordance,
|
| 684 |
+
"concordance_score": concordance_score,
|
| 685 |
+
"analysis": {
|
| 686 |
+
"match": concordance == "MATCH",
|
| 687 |
+
"agreement_details": f"Face: {facial_emotion} (conf: {facial_confidence:.2f}) | Voice: {speech_emotion} (conf: {speech_confidence:.2f})"
|
| 688 |
+
}
|
| 689 |
+
}
|
| 690 |
+
|
| 691 |
+
if explain:
|
| 692 |
+
# Keep the response shape stable even when one modality fails to generate XAI output.
|
| 693 |
+
explainability = {}
|
| 694 |
+
errors = []
|
| 695 |
+
|
| 696 |
+
facial_status = (facial_result or {}).get("explainability_status") or {
|
| 697 |
+
"requested": True,
|
| 698 |
+
"generated": False,
|
| 699 |
+
"error": "Facial explainability unavailable"
|
| 700 |
+
}
|
| 701 |
+
speech_status = (speech_result or {}).get("explainability_status") or {
|
| 702 |
+
"requested": True,
|
| 703 |
+
"generated": False,
|
| 704 |
+
"error": "Speech explainability unavailable"
|
| 705 |
+
}
|
| 706 |
+
|
| 707 |
+
if facial_result and facial_result.get("grad_cam"):
|
| 708 |
+
explainability["grad_cam"] = facial_result.get("grad_cam")
|
| 709 |
+
elif facial_status.get("error"):
|
| 710 |
+
errors.append(f"Facial: {facial_status.get('error')}")
|
| 711 |
+
|
| 712 |
+
if speech_result and speech_result.get("saliency"):
|
| 713 |
+
explainability["saliency"] = speech_result.get("saliency")
|
| 714 |
+
elif speech_status.get("error"):
|
| 715 |
+
errors.append(f"Speech: {speech_status.get('error')}")
|
| 716 |
+
|
| 717 |
+
if speech_result and speech_result.get("waveform"):
|
| 718 |
+
explainability["waveform"] = speech_result.get("waveform")
|
| 719 |
+
|
| 720 |
+
response["explainability_status"] = {
|
| 721 |
+
"requested": True,
|
| 722 |
+
"generated": bool(explainability),
|
| 723 |
+
"facial": facial_status,
|
| 724 |
+
"speech": speech_status,
|
| 725 |
+
"errors": errors
|
| 726 |
+
}
|
| 727 |
+
|
| 728 |
+
if explainability:
|
| 729 |
+
response["explainability"] = explainability
|
| 730 |
+
|
| 731 |
+
return response
|
| 732 |
+
except Exception as e:
|
| 733 |
+
return JSONResponse(status_code=400, content={"error": str(e)})
|
| 734 |
+
|
| 735 |
+
@app.post("/api/predict/video")
|
| 736 |
+
async def predict_video_emotion(file: UploadFile = File(...), explain: bool = False):
|
| 737 |
+
"""Predict emotions from video (facial + speech)"""
|
| 738 |
+
try:
|
| 739 |
+
video_suffix = _upload_suffix(file.filename, '.mp4')
|
| 740 |
+
with tempfile.NamedTemporaryFile(suffix=video_suffix, delete=False) as tmp:
|
| 741 |
+
contents = await file.read()
|
| 742 |
+
tmp.write(contents)
|
| 743 |
+
tmp_path = tmp.name
|
| 744 |
+
|
| 745 |
+
try:
|
| 746 |
+
processor = VideoProcessor()
|
| 747 |
+
frames, audio, sr, fps = processor.extract_frames_and_audio(tmp_path, fps_sample=5)
|
| 748 |
+
|
| 749 |
+
facial_results = []
|
| 750 |
+
for frame in frames[:10]:
|
| 751 |
+
result = predict_facial_emotion(frame)
|
| 752 |
+
if result:
|
| 753 |
+
facial_results.append(result)
|
| 754 |
+
|
| 755 |
+
if facial_results:
|
| 756 |
+
facial_emotions = [r["emotion"] for r in facial_results]
|
| 757 |
+
facial_confidence = np.mean([r["confidence"] for r in facial_results])
|
| 758 |
+
facial_emotion = max(set(facial_emotions), key=facial_emotions.count)
|
| 759 |
+
facial_probs = {}
|
| 760 |
+
for emotion in EMOTIONS_FACIAL:
|
| 761 |
+
facial_probs[emotion] = float(np.mean([r["probabilities"].get(emotion, 0) for r in facial_results]))
|
| 762 |
+
else:
|
| 763 |
+
facial_emotion = "unknown"
|
| 764 |
+
facial_confidence = 0.0
|
| 765 |
+
facial_probs = {e: 0.0 for e in EMOTIONS_FACIAL}
|
| 766 |
+
|
| 767 |
+
speech_result = predict_speech_emotion(audio, sr)
|
| 768 |
+
speech_emotion = speech_result["emotion"] if speech_result else "unknown"
|
| 769 |
+
speech_confidence = float(speech_result["confidence"]) if speech_result else 0.0
|
| 770 |
+
concordance, concordance_score = _calculate_concordance(
|
| 771 |
+
facial_emotion,
|
| 772 |
+
speech_emotion,
|
| 773 |
+
facial_confidence,
|
| 774 |
+
speech_confidence,
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
response = {
|
| 778 |
+
"success": True,
|
| 779 |
+
"facial_emotion": {
|
| 780 |
+
"emotion": facial_emotion,
|
| 781 |
+
"confidence": float(facial_confidence),
|
| 782 |
+
"frames_analyzed": len(facial_results),
|
| 783 |
+
"probabilities": facial_probs
|
| 784 |
+
},
|
| 785 |
+
"speech_emotion": {
|
| 786 |
+
"emotion": speech_emotion,
|
| 787 |
+
"confidence": speech_confidence,
|
| 788 |
+
"probabilities": speech_result["probabilities"] if speech_result else {e: 0.0 for e in EMOTIONS_SPEECH}
|
| 789 |
+
},
|
| 790 |
+
"combined_emotion": facial_emotion if facial_confidence > 0.5 else (speech_result["emotion"] if speech_result else "unknown"),
|
| 791 |
+
"concordance": concordance,
|
| 792 |
+
"concordance_score": concordance_score,
|
| 793 |
+
"video_duration": float(len(audio) / sr),
|
| 794 |
+
"frames_processed": len(frames),
|
| 795 |
+
"fps": float(fps)
|
| 796 |
+
}
|
| 797 |
+
|
| 798 |
+
if explain:
|
| 799 |
+
explainability = {}
|
| 800 |
+
errors = []
|
| 801 |
+
|
| 802 |
+
facial_exp_status = {"requested": True, "generated": False, "error": None}
|
| 803 |
+
speech_exp_status = {"requested": True, "generated": False, "error": None}
|
| 804 |
+
|
| 805 |
+
if frames and facial_emotion != "unknown":
|
| 806 |
+
try:
|
| 807 |
+
# Run GradCAM on the best frame that predicted the aggregated facial_emotion
|
| 808 |
+
best_frame = None
|
| 809 |
+
best_result = None
|
| 810 |
+
best_conf = 0
|
| 811 |
+
for frame in frames[:10]:
|
| 812 |
+
r = predict_facial_emotion(frame)
|
| 813 |
+
# Find the frame that predicted the aggregated emotion with highest confidence
|
| 814 |
+
if r and r.get("emotion") == facial_emotion and r.get("confidence", 0) > best_conf:
|
| 815 |
+
best_conf = r["confidence"]
|
| 816 |
+
best_frame = frame
|
| 817 |
+
best_result = r
|
| 818 |
+
|
| 819 |
+
# If no frame predicted the aggregated emotion, use the first frame
|
| 820 |
+
if best_frame is None and frames:
|
| 821 |
+
best_frame = frames[0]
|
| 822 |
+
best_result = predict_facial_emotion(best_frame)
|
| 823 |
+
|
| 824 |
+
if best_frame is not None:
|
| 825 |
+
top_idx = EMOTIONS_FACIAL.index(facial_emotion) \
|
| 826 |
+
if facial_emotion in EMOTIONS_FACIAL else 0
|
| 827 |
+
# Crop face before passing to GradCAM
|
| 828 |
+
face_box, _ = _detect_primary_face(best_frame)
|
| 829 |
+
if face_box is not None:
|
| 830 |
+
frame_array = np.array(best_frame)
|
| 831 |
+
face_crop_array, _ = _crop_face_with_margin(frame_array, face_box)
|
| 832 |
+
gradcam_input = Image.fromarray(face_crop_array) if face_crop_array.size > 0 else best_frame
|
| 833 |
+
else:
|
| 834 |
+
gradcam_input = best_frame
|
| 835 |
+
orig_b64, heatmap_b64 = generate_grad_cam(
|
| 836 |
+
gradcam_input, vit_model, facial_processor,
|
| 837 |
+
top_idx, EMOTIONS_FACIAL, DEVICE
|
| 838 |
+
)
|
| 839 |
+
if heatmap_b64:
|
| 840 |
+
explainability["grad_cam"] = heatmap_b64
|
| 841 |
+
facial_exp_status["generated"] = True
|
| 842 |
+
else:
|
| 843 |
+
facial_exp_status["error"] = "GradCAM returned empty output"
|
| 844 |
+
except Exception as e:
|
| 845 |
+
facial_exp_status["error"] = str(e)
|
| 846 |
+
else:
|
| 847 |
+
facial_exp_status["error"] = "No valid frame prediction found for facial explainability"
|
| 848 |
+
|
| 849 |
+
if speech_result and speech_emotion != "unknown":
|
| 850 |
+
try:
|
| 851 |
+
top_idx = EMOTIONS_SPEECH.index(speech_emotion) \
|
| 852 |
+
if speech_emotion in EMOTIONS_SPEECH else 0
|
| 853 |
+
audio_for_xai = _trim_audio_window(audio, sr, max_seconds=MAX_SPEECH_XAI_SECONDS)
|
| 854 |
+
spec_b64, saliency_b64 = generate_audio_saliency(
|
| 855 |
+
audio_for_xai,
|
| 856 |
+
speech_model,
|
| 857 |
+
speech_processor,
|
| 858 |
+
top_idx,
|
| 859 |
+
EMOTIONS_SPEECH,
|
| 860 |
+
DEVICE,
|
| 861 |
+
sr=16000
|
| 862 |
+
)
|
| 863 |
+
if spec_b64:
|
| 864 |
+
explainability["waveform"] = spec_b64
|
| 865 |
+
if saliency_b64:
|
| 866 |
+
explainability["saliency"] = saliency_b64
|
| 867 |
+
speech_exp_status["generated"] = True
|
| 868 |
+
else:
|
| 869 |
+
speech_exp_status["error"] = "Audio saliency map returned empty output"
|
| 870 |
+
except Exception as e:
|
| 871 |
+
speech_exp_status["error"] = str(e)
|
| 872 |
+
else:
|
| 873 |
+
speech_exp_status["error"] = "No valid audio prediction found for explainability"
|
| 874 |
+
|
| 875 |
+
if facial_exp_status.get("error"):
|
| 876 |
+
errors.append(f"Facial: {facial_exp_status.get('error')}")
|
| 877 |
+
if speech_exp_status.get("error"):
|
| 878 |
+
errors.append(f"Speech: {speech_exp_status.get('error')}")
|
| 879 |
+
|
| 880 |
+
response["explainability_status"] = {
|
| 881 |
+
"requested": True,
|
| 882 |
+
"generated": bool(explainability),
|
| 883 |
+
"facial": facial_exp_status,
|
| 884 |
+
"speech": speech_exp_status,
|
| 885 |
+
"errors": errors
|
| 886 |
+
}
|
| 887 |
+
|
| 888 |
+
if explainability:
|
| 889 |
+
response["explainability"] = explainability
|
| 890 |
+
|
| 891 |
+
return response
|
| 892 |
+
finally:
|
| 893 |
+
os.unlink(tmp_path)
|
| 894 |
+
except Exception as e:
|
| 895 |
+
return JSONResponse(status_code=400, content={"error": str(e)})
|
| 896 |
+
|
| 897 |
+
@app.get("/api/emotions/facial")
|
| 898 |
+
async def get_facial_emotions():
|
| 899 |
+
return {"emotions": EMOTIONS_FACIAL}
|
| 900 |
+
|
| 901 |
+
@app.get("/api/emotions/speech")
|
| 902 |
+
async def get_speech_emotions():
|
| 903 |
+
return {"emotions": EMOTIONS_SPEECH}
|
| 904 |
+
|
| 905 |
+
@app.get("/api/models/status")
|
| 906 |
+
async def get_models_status():
|
| 907 |
+
facial_ready = vit_model is not None and facial_processor is not None
|
| 908 |
+
speech_ready = speech_model is not None and speech_processor is not None
|
| 909 |
+
return {
|
| 910 |
+
"facial": {"loaded": facial_ready, "accuracy": 0.7129, "emotions": len(EMOTIONS_FACIAL)},
|
| 911 |
+
"speech": {"loaded": speech_ready, "accuracy": 0.8750, "emotions": len(EMOTIONS_SPEECH)},
|
| 912 |
+
"lazy_loading": not PRELOAD_MODELS,
|
| 913 |
+
"device": str(DEVICE)
|
| 914 |
+
}
|
backend/backend/services/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Service-layer utilities for backend inference and explainability."""
|
backend/backend/services/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (263 Bytes). View file
|
|
|
backend/backend/services/__pycache__/explainability.cpython-314.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
backend/backend/services/data_loader.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset loaders used by the training and experimentation workflows."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
+
import librosa
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# ==================== FACIAL DATASET ====================
|
| 13 |
+
class FER2013Dataset(Dataset):
|
| 14 |
+
"""FER2013 facial emotion dataset loader."""
|
| 15 |
+
|
| 16 |
+
def __init__(self, root_dir: str, split: str = "train", transform=None):
|
| 17 |
+
"""
|
| 18 |
+
Initialize FER2013 dataset.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
root_dir: Root directory containing 'train' and 'test' folders
|
| 22 |
+
split: 'train' or 'test'
|
| 23 |
+
transform: Torchvision transforms to apply
|
| 24 |
+
"""
|
| 25 |
+
self.root_dir = root_dir
|
| 26 |
+
self.split = split
|
| 27 |
+
self.transform = transform
|
| 28 |
+
self.emotions = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
|
| 29 |
+
self.emotion2idx = {e: i for i, e in enumerate(self.emotions)}
|
| 30 |
+
|
| 31 |
+
self.samples = []
|
| 32 |
+
self._load_samples()
|
| 33 |
+
|
| 34 |
+
def _load_samples(self):
|
| 35 |
+
"""Load all image paths and labels."""
|
| 36 |
+
split_dir = os.path.join(self.root_dir, self.split)
|
| 37 |
+
|
| 38 |
+
for emotion in self.emotions:
|
| 39 |
+
emotion_dir = os.path.join(split_dir, emotion)
|
| 40 |
+
if not os.path.exists(emotion_dir):
|
| 41 |
+
continue
|
| 42 |
+
|
| 43 |
+
for img_file in os.listdir(emotion_dir):
|
| 44 |
+
if img_file.endswith(('.jpg', '.jpeg', '.png')):
|
| 45 |
+
img_path = os.path.join(emotion_dir, img_file)
|
| 46 |
+
self.samples.append((img_path, self.emotion2idx[emotion]))
|
| 47 |
+
|
| 48 |
+
def __len__(self):
|
| 49 |
+
return len(self.samples)
|
| 50 |
+
|
| 51 |
+
def __getitem__(self, idx):
|
| 52 |
+
img_path, label = self.samples[idx]
|
| 53 |
+
|
| 54 |
+
# Load image
|
| 55 |
+
image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
|
| 56 |
+
if image is None:
|
| 57 |
+
return torch.zeros(3, 224, 224), torch.tensor(label, dtype=torch.long)
|
| 58 |
+
|
| 59 |
+
# Convert to RGB (3 channels)
|
| 60 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
| 61 |
+
|
| 62 |
+
if self.transform:
|
| 63 |
+
image = self.transform(image)
|
| 64 |
+
else:
|
| 65 |
+
# Default transform
|
| 66 |
+
image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0
|
| 67 |
+
|
| 68 |
+
return image, torch.tensor(label, dtype=torch.long)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ==================== AUDIO DATASET ====================
|
| 72 |
+
class RAVDESSDataset(Dataset):
|
| 73 |
+
"""RAVDESS audio emotion dataset loader."""
|
| 74 |
+
|
| 75 |
+
def __init__(self, root_dir: str, n_mfcc: int = 13, target_sr: int = 22050):
|
| 76 |
+
"""
|
| 77 |
+
Initialize RAVDESS dataset.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
root_dir: Root directory containing audio files
|
| 81 |
+
n_mfcc: Number of MFCCs to extract
|
| 82 |
+
target_sr: Target sampling rate
|
| 83 |
+
"""
|
| 84 |
+
self.root_dir = root_dir
|
| 85 |
+
self.n_mfcc = n_mfcc
|
| 86 |
+
self.target_sr = target_sr
|
| 87 |
+
self.emotion_map = {
|
| 88 |
+
'01': 'neutral',
|
| 89 |
+
'02': 'calm',
|
| 90 |
+
'03': 'happy',
|
| 91 |
+
'04': 'sad',
|
| 92 |
+
'05': 'angry',
|
| 93 |
+
'06': 'fear',
|
| 94 |
+
'07': 'disgust',
|
| 95 |
+
'08': 'surprise'
|
| 96 |
+
}
|
| 97 |
+
self.emotion2idx = {v: i for i, v in enumerate(set(self.emotion_map.values()))}
|
| 98 |
+
|
| 99 |
+
self.samples = []
|
| 100 |
+
self._load_samples()
|
| 101 |
+
|
| 102 |
+
def _load_samples(self):
|
| 103 |
+
"""Load all audio file paths and labels."""
|
| 104 |
+
for file in os.listdir(self.root_dir):
|
| 105 |
+
if file.endswith('.wav'):
|
| 106 |
+
emotion_code = file.split('-')[2]
|
| 107 |
+
if emotion_code in self.emotion_map:
|
| 108 |
+
emotion = self.emotion_map[emotion_code]
|
| 109 |
+
audio_path = os.path.join(self.root_dir, file)
|
| 110 |
+
self.samples.append((audio_path, self.emotion2idx[emotion]))
|
| 111 |
+
|
| 112 |
+
def __len__(self):
|
| 113 |
+
return len(self.samples)
|
| 114 |
+
|
| 115 |
+
def __getitem__(self, idx):
|
| 116 |
+
audio_path, label = self.samples[idx]
|
| 117 |
+
|
| 118 |
+
try:
|
| 119 |
+
y, sr = librosa.load(audio_path, sr=self.target_sr, mono=True)
|
| 120 |
+
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=self.n_mfcc)
|
| 121 |
+
|
| 122 |
+
# Normalize MFCC
|
| 123 |
+
mfcc = (mfcc - mfcc.mean()) / (mfcc.std() + 1e-8)
|
| 124 |
+
|
| 125 |
+
# Pad or truncate to fixed size (100 time steps)
|
| 126 |
+
if mfcc.shape[1] < 100:
|
| 127 |
+
mfcc = np.pad(mfcc, ((0, 0), (0, 100 - mfcc.shape[1])), mode='constant')
|
| 128 |
+
else:
|
| 129 |
+
mfcc = mfcc[:, :100]
|
| 130 |
+
|
| 131 |
+
return torch.from_numpy(mfcc).float(), torch.tensor(label, dtype=torch.long)
|
| 132 |
+
except Exception as e:
|
| 133 |
+
print(f"Error loading {audio_path}: {e}")
|
| 134 |
+
return torch.zeros(self.n_mfcc, 100), torch.tensor(label, dtype=torch.long)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ==================== DATALOADER FACTORY ====================
|
| 138 |
+
def create_dataloaders(
|
| 139 |
+
fer2013_dir: str = None,
|
| 140 |
+
ravdess_dir: str = None,
|
| 141 |
+
batch_size: int = 32,
|
| 142 |
+
num_workers: int = 0,
|
| 143 |
+
img_size: int = 224
|
| 144 |
+
) -> dict:
|
| 145 |
+
"""
|
| 146 |
+
Create dataloaders for FER2013 and RAVDESS datasets.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
fer2013_dir: Path to FER2013 dataset root
|
| 150 |
+
ravdess_dir: Path to RAVDESS dataset root
|
| 151 |
+
batch_size: Batch size for training
|
| 152 |
+
num_workers: Number of workers for data loading
|
| 153 |
+
img_size: Image size for FER2013
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
Dictionary with dataloaders for each dataset
|
| 157 |
+
"""
|
| 158 |
+
transform = transforms.Compose([
|
| 159 |
+
transforms.ToPILImage(),
|
| 160 |
+
transforms.Resize((img_size, img_size)),
|
| 161 |
+
transforms.ToTensor(),
|
| 162 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 163 |
+
std=[0.229, 0.224, 0.225])
|
| 164 |
+
])
|
| 165 |
+
|
| 166 |
+
dataloaders = {}
|
| 167 |
+
|
| 168 |
+
if fer2013_dir and os.path.exists(fer2013_dir):
|
| 169 |
+
train_dataset = FER2013Dataset(fer2013_dir, split='train', transform=transform)
|
| 170 |
+
test_dataset = FER2013Dataset(fer2013_dir, split='test', transform=transform)
|
| 171 |
+
|
| 172 |
+
dataloaders['fer2013_train'] = DataLoader(
|
| 173 |
+
train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
|
| 174 |
+
)
|
| 175 |
+
dataloaders['fer2013_test'] = DataLoader(
|
| 176 |
+
test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
if ravdess_dir and os.path.exists(ravdess_dir):
|
| 180 |
+
audio_dataset = RAVDESSDataset(ravdess_dir)
|
| 181 |
+
dataloaders['ravdess'] = DataLoader(
|
| 182 |
+
audio_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
return dataloaders
|
backend/backend/services/explainability.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Explainability utilities for multimodal emotion recognition outputs."""
|
| 2 |
+
import os
|
| 3 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "0"
|
| 4 |
+
os.environ["QT_QPA_PLATFORM"] = "offscreen"
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
import librosa
|
| 10 |
+
import librosa.display
|
| 11 |
+
import matplotlib
|
| 12 |
+
matplotlib.use('Agg')
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from io import BytesIO
|
| 16 |
+
import base64
|
| 17 |
+
|
| 18 |
+
from pytorch_grad_cam import GradCAM, EigenCAM
|
| 19 |
+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
| 20 |
+
from pytorch_grad_cam.utils.reshape_transforms import vit_reshape_transform
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ==================== MODEL WRAPPER ====================
|
| 24 |
+
class ViTLogitsWrapper(nn.Module):
|
| 25 |
+
def __init__(self, model):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.model = model
|
| 28 |
+
|
| 29 |
+
def forward(self, x):
|
| 30 |
+
# Grad-CAM expects a standard forward() that returns logits for the selected class.
|
| 31 |
+
return self.model(pixel_values=x).logits
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ==================== FACIAL EXPLAINABILITY ====================
|
| 35 |
+
def generate_grad_cam(image, model, processor, emotion_idx, emotions_list, device):
|
| 36 |
+
try:
|
| 37 |
+
img_rgb = np.array(image.convert('RGB'))
|
| 38 |
+
h, w = img_rgb.shape[:2]
|
| 39 |
+
img_pil = Image.fromarray(img_rgb)
|
| 40 |
+
|
| 41 |
+
inputs = processor(img_pil, return_tensors='pt').to(device)
|
| 42 |
+
input_tensor = inputs['pixel_values']
|
| 43 |
+
|
| 44 |
+
wrapped_model = ViTLogitsWrapper(model)
|
| 45 |
+
wrapped_model.eval()
|
| 46 |
+
|
| 47 |
+
# Try multiple layers because the last block can become too saturated for a usable heatmap.
|
| 48 |
+
layers_to_try = [
|
| 49 |
+
model.vit.encoder.layer[-1].layernorm_after,
|
| 50 |
+
model.vit.encoder.layer[-2].layernorm_after,
|
| 51 |
+
model.vit.encoder.layer[-3].layernorm_after,
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
cam_map = None
|
| 55 |
+
method_used = None
|
| 56 |
+
|
| 57 |
+
for i, layer in enumerate(layers_to_try):
|
| 58 |
+
try:
|
| 59 |
+
cam = GradCAM(
|
| 60 |
+
model=wrapped_model,
|
| 61 |
+
target_layers=[layer],
|
| 62 |
+
reshape_transform=vit_reshape_transform,
|
| 63 |
+
)
|
| 64 |
+
targets = [ClassifierOutputTarget(emotion_idx)]
|
| 65 |
+
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
|
| 66 |
+
result = grayscale_cam[0]
|
| 67 |
+
|
| 68 |
+
# Reject degenerate maps so the UI never shows a blank explanation as if it were valid.
|
| 69 |
+
if result.max() > 0.01:
|
| 70 |
+
cam_map = result
|
| 71 |
+
method_used = f"GradCAM (encoder block {12 - (i+1)})"
|
| 72 |
+
break
|
| 73 |
+
else:
|
| 74 |
+
print(f"[explainability] layer[-{i+1}] all zeros, trying next")
|
| 75 |
+
|
| 76 |
+
except Exception as e:
|
| 77 |
+
print(f"[explainability] GradCAM layer[-{i+1}] failed: {e}")
|
| 78 |
+
|
| 79 |
+
# Final fallback: EigenCAM gives a stable PCA-based map when gradients are unhelpful.
|
| 80 |
+
if cam_map is None:
|
| 81 |
+
print("[explainability] All GradCAM layers zero, using EigenCAM")
|
| 82 |
+
try:
|
| 83 |
+
eigen = EigenCAM(
|
| 84 |
+
model=wrapped_model,
|
| 85 |
+
target_layers=[model.vit.encoder.layer[-1].layernorm_after],
|
| 86 |
+
reshape_transform=vit_reshape_transform,
|
| 87 |
+
)
|
| 88 |
+
grayscale_cam = eigen(input_tensor=input_tensor)
|
| 89 |
+
cam_map = grayscale_cam[0]
|
| 90 |
+
method_used = "EigenCAM"
|
| 91 |
+
except Exception as e:
|
| 92 |
+
print(f"[explainability] EigenCAM failed: {e}")
|
| 93 |
+
return None, None
|
| 94 |
+
|
| 95 |
+
print(f"[explainability] {method_used} — min={cam_map.min():.3f}, max={cam_map.max():.3f}")
|
| 96 |
+
|
| 97 |
+
# Upscale and smooth the heatmap so it overlays cleanly on the source image.
|
| 98 |
+
cam_resized = cv2.resize(cam_map.astype(np.float32), (w, h), interpolation=cv2.INTER_CUBIC)
|
| 99 |
+
cam_resized = cv2.GaussianBlur(cam_resized, (13, 13), 0)
|
| 100 |
+
|
| 101 |
+
c_min, c_max = cam_resized.min(), cam_resized.max()
|
| 102 |
+
if c_max > c_min:
|
| 103 |
+
cam_resized = (cam_resized - c_min) / (c_max - c_min)
|
| 104 |
+
|
| 105 |
+
# Build the colored overlay and blend only the most salient regions.
|
| 106 |
+
cam_uint8 = np.uint8(255 * cam_resized)
|
| 107 |
+
heatmap_bgr = cv2.applyColorMap(cam_uint8, cv2.COLORMAP_JET)
|
| 108 |
+
heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
|
| 109 |
+
|
| 110 |
+
threshold = np.percentile(cam_resized, 70)
|
| 111 |
+
blend_mask = (cam_resized > threshold).astype(np.float32)
|
| 112 |
+
blend_mask = cv2.GaussianBlur(blend_mask, (31, 31), 0)[..., None]
|
| 113 |
+
|
| 114 |
+
blended = (
|
| 115 |
+
(1 - blend_mask * 0.65) * img_rgb.astype(np.float32)
|
| 116 |
+
+ blend_mask * 0.65 * heatmap_rgb.astype(np.float32)
|
| 117 |
+
).clip(0, 255).astype(np.uint8)
|
| 118 |
+
|
| 119 |
+
orig_buf = BytesIO()
|
| 120 |
+
Image.fromarray(img_rgb).save(orig_buf, format='PNG')
|
| 121 |
+
orig_b64 = base64.b64encode(orig_buf.getvalue()).decode()
|
| 122 |
+
|
| 123 |
+
blend_buf = BytesIO()
|
| 124 |
+
Image.fromarray(blended).save(blend_buf, format='PNG')
|
| 125 |
+
blend_b64 = base64.b64encode(blend_buf.getvalue()).decode()
|
| 126 |
+
|
| 127 |
+
return orig_b64, blend_b64
|
| 128 |
+
|
| 129 |
+
except Exception as e:
|
| 130 |
+
print(f"[explainability] GradCAM generation failed: {e}")
|
| 131 |
+
return None, None
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# ==================== AUDIO EXPLAINABILITY ====================
|
| 135 |
+
def generate_audio_saliency(audio, model, processor, emotion_idx, emotions_list, device, sr=16000):
|
| 136 |
+
try:
|
| 137 |
+
if audio is None or len(audio) == 0:
|
| 138 |
+
raise ValueError("Audio input is empty")
|
| 139 |
+
|
| 140 |
+
# Sanitize the audio before passing it into the speech backbone.
|
| 141 |
+
audio = np.asarray(audio, dtype=np.float32)
|
| 142 |
+
audio = np.nan_to_num(audio)
|
| 143 |
+
|
| 144 |
+
inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
|
| 145 |
+
input_values = inputs['input_values'].to(device)
|
| 146 |
+
|
| 147 |
+
input_values.requires_grad = True
|
| 148 |
+
model.zero_grad()
|
| 149 |
+
|
| 150 |
+
outputs = model(input_values)
|
| 151 |
+
score = outputs.logits[0, emotion_idx]
|
| 152 |
+
score.backward()
|
| 153 |
+
|
| 154 |
+
if input_values.grad is None:
|
| 155 |
+
raise RuntimeError("No gradients captured")
|
| 156 |
+
|
| 157 |
+
saliency = torch.abs(input_values.grad).cpu().detach().numpy()
|
| 158 |
+
|
| 159 |
+
if saliency.ndim == 3:
|
| 160 |
+
saliency = np.mean(saliency, axis=1)[0]
|
| 161 |
+
elif saliency.ndim == 2:
|
| 162 |
+
saliency = saliency[0]
|
| 163 |
+
saliency = saliency.reshape(-1).astype(np.float32)
|
| 164 |
+
|
| 165 |
+
if saliency.size > 11:
|
| 166 |
+
# Smooth the gradient spikes so the curve is readable in the plot.
|
| 167 |
+
kernel = np.ones(11, dtype=np.float32) / 11.0
|
| 168 |
+
saliency = np.convolve(saliency, kernel, mode='same')
|
| 169 |
+
|
| 170 |
+
s_min, s_max = saliency.min(), saliency.max()
|
| 171 |
+
if s_max > s_min:
|
| 172 |
+
saliency = (saliency - s_min) / (s_max - s_min)
|
| 173 |
+
else:
|
| 174 |
+
saliency = np.zeros_like(saliency)
|
| 175 |
+
|
| 176 |
+
# Build both the spectrogram and the saliency overlay for a side-by-side explanation.
|
| 177 |
+
S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
|
| 178 |
+
S_db = librosa.power_to_db(S, ref=np.max)
|
| 179 |
+
|
| 180 |
+
fig1, ax1 = plt.subplots(figsize=(10, 4), dpi=100)
|
| 181 |
+
img1 = librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='mel', ax=ax1)
|
| 182 |
+
ax1.set_title(f'Audio Spectrogram — {emotions_list[emotion_idx]}')
|
| 183 |
+
fig1.colorbar(img1, ax=ax1, format='%+2.0f dB')
|
| 184 |
+
spec_buf = BytesIO()
|
| 185 |
+
fig1.savefig(spec_buf, format='PNG', bbox_inches='tight', dpi=100)
|
| 186 |
+
spec_b64 = base64.b64encode(spec_buf.getvalue()).decode()
|
| 187 |
+
plt.close(fig1)
|
| 188 |
+
|
| 189 |
+
fig2, (ax2, ax3) = plt.subplots(2, 1, figsize=(10, 5.5), dpi=100,
|
| 190 |
+
gridspec_kw={'height_ratios': [3, 1]}, sharex=False)
|
| 191 |
+
|
| 192 |
+
# Normalize the spectrogram so the saliency colors stay visible across different recordings.
|
| 193 |
+
S_norm = (S_db - S_db.min()) / max(S_db.max() - S_db.min(), 1e-8)
|
| 194 |
+
sal_resized = np.interp(np.linspace(0, 1, S_db.shape[1]),
|
| 195 |
+
np.linspace(0, 1, saliency.shape[0]), saliency)
|
| 196 |
+
sal_map = np.tile(sal_resized, (S_db.shape[0], 1))
|
| 197 |
+
|
| 198 |
+
ax2.imshow(S_norm, aspect='auto', origin='lower', cmap='viridis', interpolation='bilinear')
|
| 199 |
+
ax2.imshow(sal_map, aspect='auto', origin='lower', cmap='magma', alpha=0.6, interpolation='bilinear')
|
| 200 |
+
ax2.set_title(f'Audio Saliency — {emotions_list[emotion_idx]} (bright = important)')
|
| 201 |
+
ax2.set_ylabel('Mel Frequency')
|
| 202 |
+
|
| 203 |
+
# Highlight the strongest time steps to give the user a clear peak view.
|
| 204 |
+
peak_thr = np.percentile(sal_resized, 85)
|
| 205 |
+
x = np.arange(len(sal_resized))
|
| 206 |
+
ax3.plot(x, sal_resized, color='#f97316', linewidth=1.5)
|
| 207 |
+
ax3.fill_between(x, 0, sal_resized, where=sal_resized >= peak_thr, color='#ef4444', alpha=0.4)
|
| 208 |
+
ax3.axhline(peak_thr, color='#ef4444', linestyle='--', linewidth=1, alpha=0.8)
|
| 209 |
+
ax3.set_ylim(0, 1.05)
|
| 210 |
+
ax3.set_ylabel('Saliency')
|
| 211 |
+
ax3.set_xlabel('Time steps')
|
| 212 |
+
ax3.grid(alpha=0.2)
|
| 213 |
+
|
| 214 |
+
sal_buf = BytesIO()
|
| 215 |
+
fig2.tight_layout()
|
| 216 |
+
fig2.savefig(sal_buf, format='PNG', bbox_inches='tight', dpi=100)
|
| 217 |
+
sal_b64 = base64.b64encode(sal_buf.getvalue()).decode()
|
| 218 |
+
plt.close(fig2)
|
| 219 |
+
|
| 220 |
+
return spec_b64, sal_b64
|
| 221 |
+
|
| 222 |
+
except Exception as e:
|
| 223 |
+
print(f"[explainability] Audio saliency failed: {e}")
|
| 224 |
+
return None, None
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# ==================== COMBINED VISUALIZATION ====================
|
| 228 |
+
def create_combined_visualization(grad_cam_base64, saliency_base64, facial_emotion, speech_emotion, concordance):
|
| 229 |
+
try:
|
| 230 |
+
# Use a soft status tint so the combined report communicates agreement at a glance.
|
| 231 |
+
bg_color = '#d4edda' if concordance == 'MATCH' else '#f8d7da'
|
| 232 |
+
html = f"""
|
| 233 |
+
<div style="display:flex;gap:20px;padding:20px;background:#f5f5f5;border-radius:10px;">
|
| 234 |
+
<div style="flex:1;">
|
| 235 |
+
<h3>Facial GradCAM — {facial_emotion}</h3>
|
| 236 |
+
<img src="data:image/png;base64,{grad_cam_base64}" style="width:100%;border-radius:8px;">
|
| 237 |
+
<p style="font-size:12px;color:#666;">Red/warm = regions that most influenced the {facial_emotion} prediction.</p>
|
| 238 |
+
</div>
|
| 239 |
+
<div style="flex:1;">
|
| 240 |
+
<h3>Speech Saliency — {speech_emotion}</h3>
|
| 241 |
+
<img src="data:image/png;base64,{saliency_base64}" style="width:100%;border-radius:8px;">
|
| 242 |
+
<p style="font-size:12px;color:#666;">Bright = time-frequency regions with strongest influence.</p>
|
| 243 |
+
</div>
|
| 244 |
+
</div>
|
| 245 |
+
<div style="margin-top:20px;padding:15px;background:{bg_color};border-radius:8px;text-align:center;">
|
| 246 |
+
<h4 style="margin:0;">Concordance: <strong>{concordance}</strong></h4>
|
| 247 |
+
</div>
|
| 248 |
+
"""
|
| 249 |
+
return base64.b64encode(html.encode()).decode()
|
| 250 |
+
except Exception as e:
|
| 251 |
+
print(f"[explainability] Combined visualisation failed: {e}")
|
| 252 |
+
return None
|
backend/main.py
ADDED
|
@@ -0,0 +1,914 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI backend for multimodal (facial + speech) emotion inference."""
|
| 2 |
+
|
| 3 |
+
from fastapi import FastAPI, File, UploadFile
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
from fastapi.responses import JSONResponse
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
import librosa
|
| 10 |
+
import base64
|
| 11 |
+
from PIL import Image, ImageOps
|
| 12 |
+
from io import BytesIO
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoFeatureExtractor, AutoModelForAudioClassification
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
+
import tempfile
|
| 17 |
+
import os
|
| 18 |
+
import logging
|
| 19 |
+
from threading import Lock
|
| 20 |
+
from dotenv import load_dotenv
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from facenet_pytorch import MTCNN # type: ignore[import-not-found]
|
| 24 |
+
except Exception:
|
| 25 |
+
MTCNN = None
|
| 26 |
+
|
| 27 |
+
# Load environment variables
|
| 28 |
+
load_dotenv()
|
| 29 |
+
|
| 30 |
+
# Configure logging
|
| 31 |
+
logging.basicConfig(
|
| 32 |
+
level=logging.INFO,
|
| 33 |
+
format='[%(asctime)s] [%(levelname)s] %(message)s'
|
| 34 |
+
)
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
# Explainability helpers
|
| 38 |
+
from backend.services.explainability import generate_grad_cam, generate_audio_saliency
|
| 39 |
+
|
| 40 |
+
ENV = os.getenv("ENV", "development")
|
| 41 |
+
FRONTEND_URL = os.getenv(
|
| 42 |
+
"FRONTEND_URL",
|
| 43 |
+
os.getenv("REACT_APP_VERCEL_URL", "http://localhost:3000")
|
| 44 |
+
)
|
| 45 |
+
CORS_ORIGINS = os.getenv("CORS_ORIGINS", "")
|
| 46 |
+
USE_GPU = os.getenv("USE_GPU", "true").lower() == "true"
|
| 47 |
+
PRELOAD_MODELS = os.getenv("PRELOAD_MODELS", "false").lower() == "true"
|
| 48 |
+
ENABLE_FACE_ROTATION = os.getenv("ENABLE_FACE_ROTATION", "false").lower() == "true"
|
| 49 |
+
MAX_FACE_ROTATION_DEGREES = float(os.getenv("MAX_FACE_ROTATION_DEGREES", "8"))
|
| 50 |
+
HAAR_MIN_NEIGHBORS = int(os.getenv("HAAR_MIN_NEIGHBORS", "5"))
|
| 51 |
+
HAAR_MIN_SIZE = int(os.getenv("HAAR_MIN_SIZE", "40"))
|
| 52 |
+
|
| 53 |
+
app = FastAPI(title="Multi-Modal Emotion Recognition API", version="2.0.0")
|
| 54 |
+
|
| 55 |
+
# Configure CORS based on environment
|
| 56 |
+
if ENV == "production":
|
| 57 |
+
if CORS_ORIGINS.strip():
|
| 58 |
+
allowed_origins = [origin.strip() for origin in CORS_ORIGINS.split(",") if origin.strip()]
|
| 59 |
+
else:
|
| 60 |
+
allowed_origins = [FRONTEND_URL]
|
| 61 |
+
else:
|
| 62 |
+
allowed_origins = ["*"]
|
| 63 |
+
|
| 64 |
+
app.add_middleware(
|
| 65 |
+
CORSMiddleware,
|
| 66 |
+
allow_origins=allowed_origins,
|
| 67 |
+
allow_credentials=True,
|
| 68 |
+
allow_methods=["*"],
|
| 69 |
+
allow_headers=["*"],
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
logger.info(f"CORS enabled for: {allowed_origins}")
|
| 73 |
+
logger.info(
|
| 74 |
+
"Face detection config: rotation=%s max_rotation=%.1f haar_min_neighbors=%d haar_min_size=%d",
|
| 75 |
+
ENABLE_FACE_ROTATION,
|
| 76 |
+
MAX_FACE_ROTATION_DEGREES,
|
| 77 |
+
HAAR_MIN_NEIGHBORS,
|
| 78 |
+
HAAR_MIN_SIZE,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Runtime configuration
|
| 82 |
+
EMOTIONS_FACIAL = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
|
| 83 |
+
EMOTIONS_SPEECH = ['angry', 'calm', 'disgust', 'fearful', 'happy', 'neutral', 'sad', 'surprised']
|
| 84 |
+
DEVICE = torch.device('cuda' if (torch.cuda.is_available() and USE_GPU) else 'cpu')
|
| 85 |
+
MAX_SPEECH_INFER_SECONDS = int(os.getenv('MAX_SPEECH_INFER_SECONDS', '15'))
|
| 86 |
+
MAX_SPEECH_XAI_SECONDS = int(os.getenv('MAX_SPEECH_XAI_SECONDS', '8'))
|
| 87 |
+
CONCORDANCE_SCORE_MAP = {
|
| 88 |
+
'MATCH': 100,
|
| 89 |
+
'PARTIAL': 65,
|
| 90 |
+
'MISMATCH': 30,
|
| 91 |
+
'UNKNOWN': 0,
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
# In-memory model state
|
| 95 |
+
vit_model = None
|
| 96 |
+
facial_processor = None
|
| 97 |
+
speech_model = None
|
| 98 |
+
speech_processor = None
|
| 99 |
+
facial_loaded = False
|
| 100 |
+
speech_loaded = False
|
| 101 |
+
|
| 102 |
+
_facial_model_lock = Lock()
|
| 103 |
+
_speech_model_lock = Lock()
|
| 104 |
+
|
| 105 |
+
# Paths — download from HuggingFace Hub
|
| 106 |
+
logger.info("Resolving model paths from HuggingFace Hub...")
|
| 107 |
+
FACIAL_MODEL_PATH = hf_hub_download(
|
| 108 |
+
repo_id="Nishvaraj/emotion-models",
|
| 109 |
+
filename="vit_emotion_model.pt"
|
| 110 |
+
)
|
| 111 |
+
SPEECH_MODEL_PATH = hf_hub_download(
|
| 112 |
+
repo_id="Nishvaraj/emotion-models",
|
| 113 |
+
filename="hubert_emotion_model.pt"
|
| 114 |
+
)
|
| 115 |
+
logger.info(f"Facial model path: {FACIAL_MODEL_PATH}")
|
| 116 |
+
logger.info(f"Speech model path: {SPEECH_MODEL_PATH}")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _upload_suffix(filename: str, default_suffix: str) -> str:
|
| 120 |
+
# Preserve the original extension when the browser provides one, otherwise fall back to a safe default.
|
| 121 |
+
suffix = Path(filename or '').suffix.lower()
|
| 122 |
+
return suffix if suffix else default_suffix
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _calculate_concordance(facial_emotion, speech_emotion, facial_confidence, speech_confidence):
|
| 126 |
+
# Match/partial/mismatch is derived from whether both models agree and how confident they are.
|
| 127 |
+
if facial_emotion == speech_emotion:
|
| 128 |
+
# When the modalities agree, the average confidence controls the concordance band.
|
| 129 |
+
score = (facial_confidence + speech_confidence) / 2
|
| 130 |
+
if score > 0.7:
|
| 131 |
+
concordance = "MATCH"
|
| 132 |
+
elif score >= 0.4:
|
| 133 |
+
concordance = "PARTIAL"
|
| 134 |
+
else:
|
| 135 |
+
concordance = "MISMATCH"
|
| 136 |
+
else:
|
| 137 |
+
# Different emotions can never be a full match, so we score by how close the confidences are.
|
| 138 |
+
score = 1 - abs(facial_confidence - speech_confidence)
|
| 139 |
+
if score >= 0.5:
|
| 140 |
+
concordance = "PARTIAL"
|
| 141 |
+
else:
|
| 142 |
+
concordance = "MISMATCH"
|
| 143 |
+
|
| 144 |
+
concordance_score = round(score * 100)
|
| 145 |
+
return concordance, concordance_score
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
FACE_CASCADE = cv2.CascadeClassifier(
|
| 149 |
+
cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
|
| 150 |
+
)
|
| 151 |
+
MTCNN_DETECTOR = MTCNN(keep_all=False, device=DEVICE) if MTCNN is not None else None
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _encode_image_base64(image_array: np.ndarray) -> str:
|
| 155 |
+
image_pil = Image.fromarray(image_array.astype(np.uint8))
|
| 156 |
+
buf = BytesIO()
|
| 157 |
+
image_pil.save(buf, format='PNG')
|
| 158 |
+
return base64.b64encode(buf.getvalue()).decode()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _detect_primary_face(image: Image.Image):
|
| 162 |
+
# Prefer MTCNN when available because it gives stronger boxes and landmark points.
|
| 163 |
+
if MTCNN_DETECTOR is not None:
|
| 164 |
+
try:
|
| 165 |
+
boxes, probs, points = MTCNN_DETECTOR.detect(image, landmarks=True)
|
| 166 |
+
if boxes is not None and len(boxes) > 0:
|
| 167 |
+
# Use the highest-probability detection when multiple faces appear.
|
| 168 |
+
best_idx = int(np.argmax(probs)) if probs is not None else 0
|
| 169 |
+
x1, y1, x2, y2 = boxes[best_idx]
|
| 170 |
+
# Convert from [x1,y1,x2,y2] to [x,y,w,h]
|
| 171 |
+
x, y, w, h = int(x1), int(y1), int(x2 - x1), int(y2 - y1)
|
| 172 |
+
return (x, y, w, h), (points[best_idx] if points is not None else None)
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logger.debug(f"MTCNN face detection fallback: {e}")
|
| 175 |
+
|
| 176 |
+
# Haar cascade is the fallback path so the app still works without facenet-pytorch.
|
| 177 |
+
img_array = np.array(image)
|
| 178 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 179 |
+
faces = FACE_CASCADE.detectMultiScale(
|
| 180 |
+
gray,
|
| 181 |
+
scaleFactor=1.1,
|
| 182 |
+
minNeighbors=HAAR_MIN_NEIGHBORS,
|
| 183 |
+
minSize=(HAAR_MIN_SIZE, HAAR_MIN_SIZE)
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
if faces is None or len(faces) == 0:
|
| 187 |
+
return None, None
|
| 188 |
+
best_face = max(faces, key=lambda b: b[2] * b[3])
|
| 189 |
+
return tuple(int(v) for v in best_face), None
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _rotate_image_to_level(image: Image.Image, points) -> Image.Image:
|
| 193 |
+
if not ENABLE_FACE_ROTATION:
|
| 194 |
+
return image
|
| 195 |
+
|
| 196 |
+
if points is None:
|
| 197 |
+
return image
|
| 198 |
+
|
| 199 |
+
try:
|
| 200 |
+
# Estimate head tilt from the eye landmarks and keep the correction bounded.
|
| 201 |
+
left_eye, right_eye = points[0], points[1]
|
| 202 |
+
angle = np.degrees(np.arctan2(right_eye[1] - left_eye[1], right_eye[0] - left_eye[0]))
|
| 203 |
+
if abs(angle) < 1.0:
|
| 204 |
+
return image
|
| 205 |
+
if abs(angle) > MAX_FACE_ROTATION_DEGREES:
|
| 206 |
+
logger.debug("Skipping face rotation due to large angle: %.2f", angle)
|
| 207 |
+
return image
|
| 208 |
+
center_x = image.width / 2
|
| 209 |
+
center_y = image.height / 2
|
| 210 |
+
return image.rotate(-angle, resample=Image.Resampling.BICUBIC, expand=True, center=(center_x, center_y), fillcolor=(0, 0, 0))
|
| 211 |
+
except Exception:
|
| 212 |
+
return image
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _crop_face_with_margin(image_array: np.ndarray, face_box, margin_ratio: float = 0.12):
|
| 216 |
+
# Expand the detected face slightly so the classifier keeps some surrounding context.
|
| 217 |
+
x, y, w, h = [int(v) for v in face_box]
|
| 218 |
+
h_img, w_img = image_array.shape[:2]
|
| 219 |
+
mx = int(w * margin_ratio)
|
| 220 |
+
my = int(h * margin_ratio)
|
| 221 |
+
|
| 222 |
+
x1 = max(0, x - mx)
|
| 223 |
+
y1 = max(0, y - my)
|
| 224 |
+
x2 = min(w_img, x + w + mx)
|
| 225 |
+
y2 = min(h_img, y + h + my)
|
| 226 |
+
|
| 227 |
+
return image_array[y1:y2, x1:x2], (x1, y1, x2 - x1, y2 - y1)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _shrink_box(face_box, shrink_ratio: float = 0.12):
|
| 231 |
+
# Draw a tighter outline for annotation so the face box looks cleaner on the preview image.
|
| 232 |
+
x, y, w, h = [int(v) for v in face_box]
|
| 233 |
+
dx = int(w * shrink_ratio / 2)
|
| 234 |
+
dy = int(h * shrink_ratio / 2)
|
| 235 |
+
x1 = x + dx
|
| 236 |
+
y1 = y + dy
|
| 237 |
+
width = max(1, w - (dx * 2))
|
| 238 |
+
height = max(1, h - (dy * 2))
|
| 239 |
+
return x1, y1, width, height
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def _trim_audio_window(audio: np.ndarray, sr: int, max_seconds: int) -> np.ndarray:
|
| 243 |
+
# Long recordings are centered and clipped so inference stays fast and consistent.
|
| 244 |
+
if audio is None or sr <= 0:
|
| 245 |
+
return audio
|
| 246 |
+
max_len = int(sr * max_seconds)
|
| 247 |
+
if max_len <= 0 or len(audio) <= max_len:
|
| 248 |
+
return audio
|
| 249 |
+
start = (len(audio) - max_len) // 2
|
| 250 |
+
end = start + max_len
|
| 251 |
+
return audio[start:end]
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
logger.info(f"Device: {DEVICE}")
|
| 255 |
+
logger.info(f"Environment: {ENV}")
|
| 256 |
+
|
| 257 |
+
# ========== MODEL LOADING ==========
|
| 258 |
+
|
| 259 |
+
def load_facial_model():
|
| 260 |
+
"""Load ViT model for facial emotion"""
|
| 261 |
+
global vit_model, facial_processor, facial_loaded
|
| 262 |
+
if vit_model is not None and facial_processor is not None:
|
| 263 |
+
facial_loaded = True
|
| 264 |
+
return True
|
| 265 |
+
|
| 266 |
+
with _facial_model_lock:
|
| 267 |
+
if vit_model is not None and facial_processor is not None:
|
| 268 |
+
facial_loaded = True
|
| 269 |
+
return True
|
| 270 |
+
|
| 271 |
+
try:
|
| 272 |
+
logger.info("Loading Facial Emotion Model (ViT)...")
|
| 273 |
+
# Keep the pretrained ViT backbone but swap in the emotion-class head size.
|
| 274 |
+
facial_processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
|
| 275 |
+
vit_model = AutoModelForImageClassification.from_pretrained(
|
| 276 |
+
'google/vit-base-patch16-224-in21k',
|
| 277 |
+
num_labels=len(EMOTIONS_FACIAL),
|
| 278 |
+
ignore_mismatched_sizes=True,
|
| 279 |
+
attn_implementation='eager'
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# Load either a full checkpoint or a plain state_dict depending on how the file was saved.
|
| 283 |
+
checkpoint = torch.load(FACIAL_MODEL_PATH, map_location=DEVICE)
|
| 284 |
+
if 'model_state_dict' in checkpoint:
|
| 285 |
+
vit_model.load_state_dict(checkpoint['model_state_dict'])
|
| 286 |
+
else:
|
| 287 |
+
vit_model.load_state_dict(checkpoint)
|
| 288 |
+
logger.info("✓ Loaded ViT checkpoint")
|
| 289 |
+
|
| 290 |
+
vit_model = vit_model.to(DEVICE)
|
| 291 |
+
vit_model.eval()
|
| 292 |
+
facial_loaded = True
|
| 293 |
+
logger.info("✓ Facial model ready")
|
| 294 |
+
return True
|
| 295 |
+
except Exception as e:
|
| 296 |
+
facial_loaded = False
|
| 297 |
+
logger.error(f"❌ Error loading facial model: {e}")
|
| 298 |
+
return False
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def load_speech_model():
|
| 302 |
+
"""Load HuBERT model for speech emotion"""
|
| 303 |
+
global speech_model, speech_processor, speech_loaded
|
| 304 |
+
if speech_model is not None and speech_processor is not None:
|
| 305 |
+
speech_loaded = True
|
| 306 |
+
return True
|
| 307 |
+
|
| 308 |
+
with _speech_model_lock:
|
| 309 |
+
if speech_model is not None and speech_processor is not None:
|
| 310 |
+
speech_loaded = True
|
| 311 |
+
return True
|
| 312 |
+
|
| 313 |
+
try:
|
| 314 |
+
logger.info("Loading Speech Emotion Model (HuBERT)...")
|
| 315 |
+
# Match the pretrained audio backbone to the project-specific emotion label set.
|
| 316 |
+
speech_processor = AutoFeatureExtractor.from_pretrained('facebook/hubert-large-ls960-ft')
|
| 317 |
+
speech_model = AutoModelForAudioClassification.from_pretrained(
|
| 318 |
+
'facebook/hubert-large-ls960-ft',
|
| 319 |
+
num_labels=len(EMOTIONS_SPEECH),
|
| 320 |
+
ignore_mismatched_sizes=True
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# Support both checkpoint formats used across training experiments.
|
| 324 |
+
checkpoint = torch.load(SPEECH_MODEL_PATH, map_location=DEVICE)
|
| 325 |
+
if 'model_state_dict' in checkpoint:
|
| 326 |
+
speech_model.load_state_dict(checkpoint['model_state_dict'])
|
| 327 |
+
else:
|
| 328 |
+
speech_model.load_state_dict(checkpoint)
|
| 329 |
+
logger.info("✓ Loaded HuBERT checkpoint")
|
| 330 |
+
|
| 331 |
+
speech_model = speech_model.to(DEVICE)
|
| 332 |
+
speech_model.eval()
|
| 333 |
+
speech_loaded = True
|
| 334 |
+
logger.info("✓ Speech model ready")
|
| 335 |
+
return True
|
| 336 |
+
except Exception as e:
|
| 337 |
+
speech_loaded = False
|
| 338 |
+
logger.error(f"❌ Error loading speech model: {e}")
|
| 339 |
+
return False
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def ensure_facial_model_loaded() -> bool:
|
| 343 |
+
if vit_model is not None and facial_processor is not None:
|
| 344 |
+
return True
|
| 345 |
+
return load_facial_model()
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def ensure_speech_model_loaded() -> bool:
|
| 349 |
+
if speech_model is not None and speech_processor is not None:
|
| 350 |
+
return True
|
| 351 |
+
return load_speech_model()
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
# Optional eager loading for environments that prefer warm startup.
|
| 355 |
+
if PRELOAD_MODELS:
|
| 356 |
+
facial_loaded = load_facial_model()
|
| 357 |
+
speech_loaded = load_speech_model()
|
| 358 |
+
|
| 359 |
+
# ========== VIDEO PROCESSOR ==========
|
| 360 |
+
|
| 361 |
+
class VideoProcessor:
|
| 362 |
+
@staticmethod
|
| 363 |
+
def extract_frames_and_audio(video_path: str, fps_sample: int = 5):
|
| 364 |
+
"""Extract frames and audio from video"""
|
| 365 |
+
frames = []
|
| 366 |
+
cap = cv2.VideoCapture(video_path)
|
| 367 |
+
|
| 368 |
+
if not cap.isOpened():
|
| 369 |
+
raise ValueError(f"Cannot open video: {video_path}")
|
| 370 |
+
|
| 371 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 372 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 373 |
+
if fps <= 0 or fps > 120:
|
| 374 |
+
fps = 30.0
|
| 375 |
+
|
| 376 |
+
frame_count = 0
|
| 377 |
+
while cap.isOpened():
|
| 378 |
+
ret, frame = cap.read()
|
| 379 |
+
if not ret:
|
| 380 |
+
break
|
| 381 |
+
|
| 382 |
+
if frame_count % fps_sample == 0:
|
| 383 |
+
# Sample every Nth frame so we analyze representative facial expressions without processing the full video.
|
| 384 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 385 |
+
frames.append(Image.fromarray(frame_rgb))
|
| 386 |
+
|
| 387 |
+
frame_count += 1
|
| 388 |
+
|
| 389 |
+
cap.release()
|
| 390 |
+
|
| 391 |
+
# librosa reads the audio track directly from the same file, giving us a single mono stream for speech inference.
|
| 392 |
+
audio, sr = librosa.load(video_path, sr=16000, mono=True)
|
| 393 |
+
|
| 394 |
+
return frames, audio, sr, fps
|
| 395 |
+
|
| 396 |
+
# ========== PREDICTION FUNCTIONS ==========
|
| 397 |
+
|
| 398 |
+
def predict_facial_emotion(image: Image.Image, generate_explainability: bool = False):
|
| 399 |
+
"""Predict emotion from image"""
|
| 400 |
+
try:
|
| 401 |
+
if not ensure_facial_model_loaded():
|
| 402 |
+
return None
|
| 403 |
+
|
| 404 |
+
# Normalize EXIF orientation first so mobile uploads and camera captures behave consistently.
|
| 405 |
+
image = ImageOps.exif_transpose(image).convert('RGB')
|
| 406 |
+
|
| 407 |
+
# Detect the most likely face before deciding whether to crop or rotate the input.
|
| 408 |
+
detected = _detect_primary_face(image)
|
| 409 |
+
face_box, face_points = detected if isinstance(detected, tuple) else (None, None)
|
| 410 |
+
|
| 411 |
+
# If we have eye landmarks, try a small rotation pass to correct head tilt.
|
| 412 |
+
rotated_image = _rotate_image_to_level(image, face_points)
|
| 413 |
+
if rotated_image is not image:
|
| 414 |
+
rotated_detected = _detect_primary_face(rotated_image)
|
| 415 |
+
if isinstance(rotated_detected, tuple):
|
| 416 |
+
rotated_box, rotated_points = rotated_detected
|
| 417 |
+
if rotated_box is not None:
|
| 418 |
+
image = rotated_image
|
| 419 |
+
face_box = rotated_box
|
| 420 |
+
face_points = rotated_points
|
| 421 |
+
|
| 422 |
+
input_array = np.array(image)
|
| 423 |
+
|
| 424 |
+
model_image = image
|
| 425 |
+
|
| 426 |
+
# Crop to the detected face when possible so the classifier sees the most relevant region.
|
| 427 |
+
if face_box is not None:
|
| 428 |
+
face_crop, _ = _crop_face_with_margin(input_array, face_box)
|
| 429 |
+
if face_crop.size > 0:
|
| 430 |
+
model_image = Image.fromarray(face_crop)
|
| 431 |
+
|
| 432 |
+
# Draw the face box on the preview image to make the detection step visible to the user.
|
| 433 |
+
annotated = input_array.copy()
|
| 434 |
+
if face_box is not None:
|
| 435 |
+
x, y, w, h = _shrink_box(face_box, shrink_ratio=0.08)
|
| 436 |
+
cv2.rectangle(annotated, (x, y), (x + w, y + h), (255, 128, 0), 2)
|
| 437 |
+
cv2.putText(
|
| 438 |
+
annotated,
|
| 439 |
+
'Face detected',
|
| 440 |
+
(x, max(20, y - 8)),
|
| 441 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 442 |
+
0.6,
|
| 443 |
+
(255, 128, 0),
|
| 444 |
+
2,
|
| 445 |
+
cv2.LINE_AA
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
inputs = facial_processor(model_image, return_tensors='pt').to(DEVICE)
|
| 449 |
+
with torch.no_grad():
|
| 450 |
+
outputs = vit_model(**inputs)
|
| 451 |
+
logits = outputs.logits.cpu().numpy()[0]
|
| 452 |
+
# Convert raw logits into probabilities for easier interpretation in the UI.
|
| 453 |
+
probs = torch.softmax(torch.from_numpy(logits), dim=0).numpy()
|
| 454 |
+
|
| 455 |
+
top_idx = np.argmax(probs)
|
| 456 |
+
result = {
|
| 457 |
+
"emotion": EMOTIONS_FACIAL[top_idx],
|
| 458 |
+
"confidence": float(probs[top_idx]),
|
| 459 |
+
"probabilities": {e: float(p) for e, p in zip(EMOTIONS_FACIAL, probs)},
|
| 460 |
+
"face_detected": face_box is not None,
|
| 461 |
+
"annotated_image": _encode_image_base64(annotated)
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
if face_box is not None:
|
| 465 |
+
x, y, w, h = [int(v) for v in face_box]
|
| 466 |
+
result["face_box"] = {"x": x, "y": y, "width": w, "height": h}
|
| 467 |
+
|
| 468 |
+
if generate_explainability:
|
| 469 |
+
# Explainability is optional because Grad-CAM adds compute cost.
|
| 470 |
+
result["explainability_status"] = {
|
| 471 |
+
"requested": True,
|
| 472 |
+
"generated": False,
|
| 473 |
+
"error": None
|
| 474 |
+
}
|
| 475 |
+
try:
|
| 476 |
+
original_base64, heatmap_base64 = generate_grad_cam(
|
| 477 |
+
model_image,
|
| 478 |
+
vit_model,
|
| 479 |
+
facial_processor,
|
| 480 |
+
top_idx,
|
| 481 |
+
EMOTIONS_FACIAL,
|
| 482 |
+
DEVICE
|
| 483 |
+
)
|
| 484 |
+
if original_base64:
|
| 485 |
+
result["original_image"] = original_base64
|
| 486 |
+
if heatmap_base64:
|
| 487 |
+
result["grad_cam"] = heatmap_base64
|
| 488 |
+
result["explainability_status"]["generated"] = True
|
| 489 |
+
else:
|
| 490 |
+
result["explainability_status"]["error"] = "Grad-CAM map returned empty output"
|
| 491 |
+
except Exception as e:
|
| 492 |
+
logger.warning(f"Could not generate Grad-CAM: {e}")
|
| 493 |
+
result["explainability_status"]["error"] = str(e)
|
| 494 |
+
|
| 495 |
+
return result
|
| 496 |
+
except Exception as e:
|
| 497 |
+
logger.error(f"Error predicting facial emotion: {e}")
|
| 498 |
+
return None
|
| 499 |
+
|
| 500 |
+
def predict_speech_emotion(audio: np.ndarray, sr: int = 16000, generate_explainability: bool = False):
|
| 501 |
+
"""Predict emotion from audio"""
|
| 502 |
+
try:
|
| 503 |
+
if not ensure_speech_model_loaded():
|
| 504 |
+
return None
|
| 505 |
+
|
| 506 |
+
if sr != 16000:
|
| 507 |
+
# Resample every input to the model's expected sampling rate.
|
| 508 |
+
audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
|
| 509 |
+
|
| 510 |
+
# Keep inference fast and stable for long recordings.
|
| 511 |
+
audio_for_infer = _trim_audio_window(audio, 16000, MAX_SPEECH_INFER_SECONDS)
|
| 512 |
+
|
| 513 |
+
inputs = speech_processor(audio_for_infer, sampling_rate=16000, return_tensors="pt", padding=True)
|
| 514 |
+
with torch.no_grad():
|
| 515 |
+
outputs = speech_model(inputs['input_values'].to(DEVICE))
|
| 516 |
+
logits = outputs.logits.cpu().numpy()[0]
|
| 517 |
+
# Softmax keeps the output distribution easy to display and compare.
|
| 518 |
+
probs = np.exp(logits) / np.sum(np.exp(logits))
|
| 519 |
+
|
| 520 |
+
top_idx = np.argmax(probs)
|
| 521 |
+
result = {
|
| 522 |
+
"emotion": EMOTIONS_SPEECH[top_idx],
|
| 523 |
+
"confidence": float(probs[top_idx]),
|
| 524 |
+
"probabilities": {e: float(p) for e, p in zip(EMOTIONS_SPEECH, probs)}
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
if generate_explainability:
|
| 528 |
+
# Saliency is computed on a shorter slice to avoid long XAI runs on large clips.
|
| 529 |
+
result["explainability_status"] = {
|
| 530 |
+
"requested": True,
|
| 531 |
+
"generated": False,
|
| 532 |
+
"error": None
|
| 533 |
+
}
|
| 534 |
+
try:
|
| 535 |
+
# Saliency on a shorter centered chunk avoids multi-minute stalls.
|
| 536 |
+
audio_for_xai = _trim_audio_window(audio_for_infer, 16000, MAX_SPEECH_XAI_SECONDS)
|
| 537 |
+
spec_base64, saliency_base64 = generate_audio_saliency(
|
| 538 |
+
audio_for_xai,
|
| 539 |
+
speech_model,
|
| 540 |
+
speech_processor,
|
| 541 |
+
top_idx,
|
| 542 |
+
EMOTIONS_SPEECH,
|
| 543 |
+
DEVICE,
|
| 544 |
+
sr=16000
|
| 545 |
+
)
|
| 546 |
+
if spec_base64:
|
| 547 |
+
result["waveform"] = spec_base64
|
| 548 |
+
if saliency_base64:
|
| 549 |
+
result["saliency"] = saliency_base64
|
| 550 |
+
result["explainability_status"]["generated"] = True
|
| 551 |
+
else:
|
| 552 |
+
result["explainability_status"]["error"] = "Audio saliency map returned empty output"
|
| 553 |
+
except Exception as e:
|
| 554 |
+
logger.warning(f"Could not generate audio saliency: {e}")
|
| 555 |
+
result["explainability_status"]["error"] = str(e)
|
| 556 |
+
|
| 557 |
+
return result
|
| 558 |
+
except Exception as e:
|
| 559 |
+
logger.error(f"Error predicting speech emotion: {e}")
|
| 560 |
+
return None
|
| 561 |
+
|
| 562 |
+
# ========== API ENDPOINTS ==========
|
| 563 |
+
|
| 564 |
+
@app.get("/")
|
| 565 |
+
async def root():
|
| 566 |
+
return {"message": "Multi-Modal Emotion Recognition API v2.0", "status": "active"}
|
| 567 |
+
|
| 568 |
+
@app.get("/health")
|
| 569 |
+
async def health():
|
| 570 |
+
facial_ready = vit_model is not None and facial_processor is not None
|
| 571 |
+
speech_ready = speech_model is not None and speech_processor is not None
|
| 572 |
+
return {
|
| 573 |
+
"status": "healthy",
|
| 574 |
+
"facial_model": facial_ready,
|
| 575 |
+
"speech_model": speech_ready,
|
| 576 |
+
"lazy_loading": not PRELOAD_MODELS,
|
| 577 |
+
"device": str(DEVICE)
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
@app.post("/api/predict/facial")
|
| 581 |
+
async def predict_facial(file: UploadFile = File(...), explain: bool = False):
|
| 582 |
+
"""Predict emotion from image"""
|
| 583 |
+
try:
|
| 584 |
+
logger.info(f"Received file: {file.filename}, content_type: {file.content_type}")
|
| 585 |
+
contents = await file.read()
|
| 586 |
+
logger.info(f"File size: {len(contents)} bytes")
|
| 587 |
+
if len(contents) == 0:
|
| 588 |
+
return JSONResponse(status_code=400, content={"error": "Empty file received"})
|
| 589 |
+
image = ImageOps.exif_transpose(Image.open(BytesIO(contents))).convert('RGB')
|
| 590 |
+
result = predict_facial_emotion(image, generate_explainability=explain)
|
| 591 |
+
return {"success": True, **result} if result else {"success": False, "error": "Prediction failed"}
|
| 592 |
+
except Exception as e:
|
| 593 |
+
logger.error(f"Error in predict_facial: {e}", exc_info=True)
|
| 594 |
+
return JSONResponse(status_code=400, content={"error": str(e)})
|
| 595 |
+
|
| 596 |
+
@app.post("/api/predict/speech")
|
| 597 |
+
async def predict_speech(file: UploadFile = File(...), explain: bool = False):
|
| 598 |
+
"""Predict emotion from audio"""
|
| 599 |
+
try:
|
| 600 |
+
contents = await file.read()
|
| 601 |
+
suffix = _upload_suffix(file.filename, '.wav')
|
| 602 |
+
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
| 603 |
+
tmp.write(contents)
|
| 604 |
+
tmp_path = tmp.name
|
| 605 |
+
|
| 606 |
+
try:
|
| 607 |
+
audio, sr = librosa.load(tmp_path, sr=16000)
|
| 608 |
+
result = predict_speech_emotion(audio, sr, generate_explainability=explain)
|
| 609 |
+
return {"success": True, **result} if result else {"success": False, "error": "Prediction failed"}
|
| 610 |
+
finally:
|
| 611 |
+
os.unlink(tmp_path)
|
| 612 |
+
except Exception as e:
|
| 613 |
+
return JSONResponse(status_code=400, content={"error": str(e)})
|
| 614 |
+
|
| 615 |
+
@app.post("/api/predict/combined")
|
| 616 |
+
async def predict_combined(image_file: UploadFile = File(...), audio_file: UploadFile = File(...), explain: bool = False):
|
| 617 |
+
"""Predict emotions from both image and audio, then compare results"""
|
| 618 |
+
try:
|
| 619 |
+
image_contents = await image_file.read()
|
| 620 |
+
image = ImageOps.exif_transpose(Image.open(BytesIO(image_contents))).convert('RGB')
|
| 621 |
+
facial_result = predict_facial_emotion(image, generate_explainability=explain)
|
| 622 |
+
|
| 623 |
+
audio_contents = await audio_file.read()
|
| 624 |
+
audio_suffix = _upload_suffix(audio_file.filename, '.wav')
|
| 625 |
+
with tempfile.NamedTemporaryFile(suffix=audio_suffix, delete=False) as tmp:
|
| 626 |
+
tmp.write(audio_contents)
|
| 627 |
+
tmp_path = tmp.name
|
| 628 |
+
|
| 629 |
+
try:
|
| 630 |
+
audio, sr = librosa.load(tmp_path, sr=16000)
|
| 631 |
+
speech_result = predict_speech_emotion(audio, sr, generate_explainability=explain)
|
| 632 |
+
finally:
|
| 633 |
+
os.unlink(tmp_path)
|
| 634 |
+
|
| 635 |
+
facial_emotion = facial_result["emotion"] if facial_result else None
|
| 636 |
+
facial_confidence = facial_result["confidence"] if facial_result else 0.0
|
| 637 |
+
|
| 638 |
+
speech_emotion = speech_result["emotion"] if speech_result else None
|
| 639 |
+
speech_confidence = speech_result["confidence"] if speech_result else 0.0
|
| 640 |
+
|
| 641 |
+
concordance, concordance_score = _calculate_concordance(
|
| 642 |
+
facial_emotion,
|
| 643 |
+
speech_emotion,
|
| 644 |
+
facial_confidence,
|
| 645 |
+
speech_confidence,
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
# The combined label should prefer the more confident modality when both are present.
|
| 649 |
+
combined_emotion = None
|
| 650 |
+
combined_confidence = 0.0
|
| 651 |
+
|
| 652 |
+
if facial_emotion and speech_emotion:
|
| 653 |
+
if facial_confidence > speech_confidence:
|
| 654 |
+
combined_emotion = facial_emotion
|
| 655 |
+
combined_confidence = facial_confidence
|
| 656 |
+
else:
|
| 657 |
+
combined_emotion = speech_emotion
|
| 658 |
+
combined_confidence = speech_confidence
|
| 659 |
+
elif facial_emotion:
|
| 660 |
+
combined_emotion = facial_emotion
|
| 661 |
+
combined_confidence = facial_confidence
|
| 662 |
+
elif speech_emotion:
|
| 663 |
+
combined_emotion = speech_emotion
|
| 664 |
+
combined_confidence = speech_confidence
|
| 665 |
+
|
| 666 |
+
response = {
|
| 667 |
+
"success": True,
|
| 668 |
+
"facial_emotion": {
|
| 669 |
+
"emotion": facial_emotion or "unknown",
|
| 670 |
+
"confidence": float(facial_confidence),
|
| 671 |
+
"probabilities": facial_result["probabilities"] if facial_result else {},
|
| 672 |
+
"face_detected": facial_result.get("face_detected", False) if facial_result else False,
|
| 673 |
+
"face_box": facial_result.get("face_box") if facial_result else None,
|
| 674 |
+
"annotated_image": facial_result.get("annotated_image") if facial_result else None
|
| 675 |
+
},
|
| 676 |
+
"speech_emotion": {
|
| 677 |
+
"emotion": speech_emotion or "unknown",
|
| 678 |
+
"confidence": float(speech_confidence),
|
| 679 |
+
"probabilities": speech_result["probabilities"] if speech_result else {}
|
| 680 |
+
},
|
| 681 |
+
"combined_emotion": combined_emotion or "unknown",
|
| 682 |
+
"combined_confidence": float(combined_confidence),
|
| 683 |
+
"concordance": concordance,
|
| 684 |
+
"concordance_score": concordance_score,
|
| 685 |
+
"analysis": {
|
| 686 |
+
"match": concordance == "MATCH",
|
| 687 |
+
"agreement_details": f"Face: {facial_emotion} (conf: {facial_confidence:.2f}) | Voice: {speech_emotion} (conf: {speech_confidence:.2f})"
|
| 688 |
+
}
|
| 689 |
+
}
|
| 690 |
+
|
| 691 |
+
if explain:
|
| 692 |
+
# Keep the response shape stable even when one modality fails to generate XAI output.
|
| 693 |
+
explainability = {}
|
| 694 |
+
errors = []
|
| 695 |
+
|
| 696 |
+
facial_status = (facial_result or {}).get("explainability_status") or {
|
| 697 |
+
"requested": True,
|
| 698 |
+
"generated": False,
|
| 699 |
+
"error": "Facial explainability unavailable"
|
| 700 |
+
}
|
| 701 |
+
speech_status = (speech_result or {}).get("explainability_status") or {
|
| 702 |
+
"requested": True,
|
| 703 |
+
"generated": False,
|
| 704 |
+
"error": "Speech explainability unavailable"
|
| 705 |
+
}
|
| 706 |
+
|
| 707 |
+
if facial_result and facial_result.get("grad_cam"):
|
| 708 |
+
explainability["grad_cam"] = facial_result.get("grad_cam")
|
| 709 |
+
elif facial_status.get("error"):
|
| 710 |
+
errors.append(f"Facial: {facial_status.get('error')}")
|
| 711 |
+
|
| 712 |
+
if speech_result and speech_result.get("saliency"):
|
| 713 |
+
explainability["saliency"] = speech_result.get("saliency")
|
| 714 |
+
elif speech_status.get("error"):
|
| 715 |
+
errors.append(f"Speech: {speech_status.get('error')}")
|
| 716 |
+
|
| 717 |
+
if speech_result and speech_result.get("waveform"):
|
| 718 |
+
explainability["waveform"] = speech_result.get("waveform")
|
| 719 |
+
|
| 720 |
+
response["explainability_status"] = {
|
| 721 |
+
"requested": True,
|
| 722 |
+
"generated": bool(explainability),
|
| 723 |
+
"facial": facial_status,
|
| 724 |
+
"speech": speech_status,
|
| 725 |
+
"errors": errors
|
| 726 |
+
}
|
| 727 |
+
|
| 728 |
+
if explainability:
|
| 729 |
+
response["explainability"] = explainability
|
| 730 |
+
|
| 731 |
+
return response
|
| 732 |
+
except Exception as e:
|
| 733 |
+
return JSONResponse(status_code=400, content={"error": str(e)})
|
| 734 |
+
|
| 735 |
+
@app.post("/api/predict/video")
|
| 736 |
+
async def predict_video_emotion(file: UploadFile = File(...), explain: bool = False):
|
| 737 |
+
"""Predict emotions from video (facial + speech)"""
|
| 738 |
+
try:
|
| 739 |
+
video_suffix = _upload_suffix(file.filename, '.mp4')
|
| 740 |
+
with tempfile.NamedTemporaryFile(suffix=video_suffix, delete=False) as tmp:
|
| 741 |
+
contents = await file.read()
|
| 742 |
+
tmp.write(contents)
|
| 743 |
+
tmp_path = tmp.name
|
| 744 |
+
|
| 745 |
+
try:
|
| 746 |
+
processor = VideoProcessor()
|
| 747 |
+
frames, audio, sr, fps = processor.extract_frames_and_audio(tmp_path, fps_sample=5)
|
| 748 |
+
|
| 749 |
+
facial_results = []
|
| 750 |
+
for frame in frames[:10]:
|
| 751 |
+
result = predict_facial_emotion(frame)
|
| 752 |
+
if result:
|
| 753 |
+
facial_results.append(result)
|
| 754 |
+
|
| 755 |
+
if facial_results:
|
| 756 |
+
facial_emotions = [r["emotion"] for r in facial_results]
|
| 757 |
+
facial_confidence = np.mean([r["confidence"] for r in facial_results])
|
| 758 |
+
facial_emotion = max(set(facial_emotions), key=facial_emotions.count)
|
| 759 |
+
facial_probs = {}
|
| 760 |
+
for emotion in EMOTIONS_FACIAL:
|
| 761 |
+
facial_probs[emotion] = float(np.mean([r["probabilities"].get(emotion, 0) for r in facial_results]))
|
| 762 |
+
else:
|
| 763 |
+
facial_emotion = "unknown"
|
| 764 |
+
facial_confidence = 0.0
|
| 765 |
+
facial_probs = {e: 0.0 for e in EMOTIONS_FACIAL}
|
| 766 |
+
|
| 767 |
+
speech_result = predict_speech_emotion(audio, sr)
|
| 768 |
+
speech_emotion = speech_result["emotion"] if speech_result else "unknown"
|
| 769 |
+
speech_confidence = float(speech_result["confidence"]) if speech_result else 0.0
|
| 770 |
+
concordance, concordance_score = _calculate_concordance(
|
| 771 |
+
facial_emotion,
|
| 772 |
+
speech_emotion,
|
| 773 |
+
facial_confidence,
|
| 774 |
+
speech_confidence,
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
response = {
|
| 778 |
+
"success": True,
|
| 779 |
+
"facial_emotion": {
|
| 780 |
+
"emotion": facial_emotion,
|
| 781 |
+
"confidence": float(facial_confidence),
|
| 782 |
+
"frames_analyzed": len(facial_results),
|
| 783 |
+
"probabilities": facial_probs
|
| 784 |
+
},
|
| 785 |
+
"speech_emotion": {
|
| 786 |
+
"emotion": speech_emotion,
|
| 787 |
+
"confidence": speech_confidence,
|
| 788 |
+
"probabilities": speech_result["probabilities"] if speech_result else {e: 0.0 for e in EMOTIONS_SPEECH}
|
| 789 |
+
},
|
| 790 |
+
"combined_emotion": facial_emotion if facial_confidence > 0.5 else (speech_result["emotion"] if speech_result else "unknown"),
|
| 791 |
+
"concordance": concordance,
|
| 792 |
+
"concordance_score": concordance_score,
|
| 793 |
+
"video_duration": float(len(audio) / sr),
|
| 794 |
+
"frames_processed": len(frames),
|
| 795 |
+
"fps": float(fps)
|
| 796 |
+
}
|
| 797 |
+
|
| 798 |
+
if explain:
|
| 799 |
+
explainability = {}
|
| 800 |
+
errors = []
|
| 801 |
+
|
| 802 |
+
facial_exp_status = {"requested": True, "generated": False, "error": None}
|
| 803 |
+
speech_exp_status = {"requested": True, "generated": False, "error": None}
|
| 804 |
+
|
| 805 |
+
if frames and facial_emotion != "unknown":
|
| 806 |
+
try:
|
| 807 |
+
# Run GradCAM on the best frame that predicted the aggregated facial_emotion
|
| 808 |
+
best_frame = None
|
| 809 |
+
best_result = None
|
| 810 |
+
best_conf = 0
|
| 811 |
+
for frame in frames[:10]:
|
| 812 |
+
r = predict_facial_emotion(frame)
|
| 813 |
+
# Find the frame that predicted the aggregated emotion with highest confidence
|
| 814 |
+
if r and r.get("emotion") == facial_emotion and r.get("confidence", 0) > best_conf:
|
| 815 |
+
best_conf = r["confidence"]
|
| 816 |
+
best_frame = frame
|
| 817 |
+
best_result = r
|
| 818 |
+
|
| 819 |
+
# If no frame predicted the aggregated emotion, use the first frame
|
| 820 |
+
if best_frame is None and frames:
|
| 821 |
+
best_frame = frames[0]
|
| 822 |
+
best_result = predict_facial_emotion(best_frame)
|
| 823 |
+
|
| 824 |
+
if best_frame is not None:
|
| 825 |
+
top_idx = EMOTIONS_FACIAL.index(facial_emotion) \
|
| 826 |
+
if facial_emotion in EMOTIONS_FACIAL else 0
|
| 827 |
+
# Crop face before passing to GradCAM
|
| 828 |
+
face_box, _ = _detect_primary_face(best_frame)
|
| 829 |
+
if face_box is not None:
|
| 830 |
+
frame_array = np.array(best_frame)
|
| 831 |
+
face_crop_array, _ = _crop_face_with_margin(frame_array, face_box)
|
| 832 |
+
gradcam_input = Image.fromarray(face_crop_array) if face_crop_array.size > 0 else best_frame
|
| 833 |
+
else:
|
| 834 |
+
gradcam_input = best_frame
|
| 835 |
+
orig_b64, heatmap_b64 = generate_grad_cam(
|
| 836 |
+
gradcam_input, vit_model, facial_processor,
|
| 837 |
+
top_idx, EMOTIONS_FACIAL, DEVICE
|
| 838 |
+
)
|
| 839 |
+
if heatmap_b64:
|
| 840 |
+
explainability["grad_cam"] = heatmap_b64
|
| 841 |
+
facial_exp_status["generated"] = True
|
| 842 |
+
else:
|
| 843 |
+
facial_exp_status["error"] = "GradCAM returned empty output"
|
| 844 |
+
except Exception as e:
|
| 845 |
+
facial_exp_status["error"] = str(e)
|
| 846 |
+
else:
|
| 847 |
+
facial_exp_status["error"] = "No valid frame prediction found for facial explainability"
|
| 848 |
+
|
| 849 |
+
if speech_result and speech_emotion != "unknown":
|
| 850 |
+
try:
|
| 851 |
+
top_idx = EMOTIONS_SPEECH.index(speech_emotion) \
|
| 852 |
+
if speech_emotion in EMOTIONS_SPEECH else 0
|
| 853 |
+
audio_for_xai = _trim_audio_window(audio, sr, max_seconds=MAX_SPEECH_XAI_SECONDS)
|
| 854 |
+
spec_b64, saliency_b64 = generate_audio_saliency(
|
| 855 |
+
audio_for_xai,
|
| 856 |
+
speech_model,
|
| 857 |
+
speech_processor,
|
| 858 |
+
top_idx,
|
| 859 |
+
EMOTIONS_SPEECH,
|
| 860 |
+
DEVICE,
|
| 861 |
+
sr=16000
|
| 862 |
+
)
|
| 863 |
+
if spec_b64:
|
| 864 |
+
explainability["waveform"] = spec_b64
|
| 865 |
+
if saliency_b64:
|
| 866 |
+
explainability["saliency"] = saliency_b64
|
| 867 |
+
speech_exp_status["generated"] = True
|
| 868 |
+
else:
|
| 869 |
+
speech_exp_status["error"] = "Audio saliency map returned empty output"
|
| 870 |
+
except Exception as e:
|
| 871 |
+
speech_exp_status["error"] = str(e)
|
| 872 |
+
else:
|
| 873 |
+
speech_exp_status["error"] = "No valid audio prediction found for explainability"
|
| 874 |
+
|
| 875 |
+
if facial_exp_status.get("error"):
|
| 876 |
+
errors.append(f"Facial: {facial_exp_status.get('error')}")
|
| 877 |
+
if speech_exp_status.get("error"):
|
| 878 |
+
errors.append(f"Speech: {speech_exp_status.get('error')}")
|
| 879 |
+
|
| 880 |
+
response["explainability_status"] = {
|
| 881 |
+
"requested": True,
|
| 882 |
+
"generated": bool(explainability),
|
| 883 |
+
"facial": facial_exp_status,
|
| 884 |
+
"speech": speech_exp_status,
|
| 885 |
+
"errors": errors
|
| 886 |
+
}
|
| 887 |
+
|
| 888 |
+
if explainability:
|
| 889 |
+
response["explainability"] = explainability
|
| 890 |
+
|
| 891 |
+
return response
|
| 892 |
+
finally:
|
| 893 |
+
os.unlink(tmp_path)
|
| 894 |
+
except Exception as e:
|
| 895 |
+
return JSONResponse(status_code=400, content={"error": str(e)})
|
| 896 |
+
|
| 897 |
+
@app.get("/api/emotions/facial")
|
| 898 |
+
async def get_facial_emotions():
|
| 899 |
+
return {"emotions": EMOTIONS_FACIAL}
|
| 900 |
+
|
| 901 |
+
@app.get("/api/emotions/speech")
|
| 902 |
+
async def get_speech_emotions():
|
| 903 |
+
return {"emotions": EMOTIONS_SPEECH}
|
| 904 |
+
|
| 905 |
+
@app.get("/api/models/status")
|
| 906 |
+
async def get_models_status():
|
| 907 |
+
facial_ready = vit_model is not None and facial_processor is not None
|
| 908 |
+
speech_ready = speech_model is not None and speech_processor is not None
|
| 909 |
+
return {
|
| 910 |
+
"facial": {"loaded": facial_ready, "accuracy": 0.7129, "emotions": len(EMOTIONS_FACIAL)},
|
| 911 |
+
"speech": {"loaded": speech_ready, "accuracy": 0.8750, "emotions": len(EMOTIONS_SPEECH)},
|
| 912 |
+
"lazy_loading": not PRELOAD_MODELS,
|
| 913 |
+
"device": str(DEVICE)
|
| 914 |
+
}
|
backend/services/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Service-layer utilities for backend inference and explainability."""
|
backend/services/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (263 Bytes). View file
|
|
|
backend/services/__pycache__/explainability.cpython-314.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
backend/services/data_loader.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset loaders used by the training and experimentation workflows."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
+
import librosa
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# ==================== FACIAL DATASET ====================
|
| 13 |
+
class FER2013Dataset(Dataset):
|
| 14 |
+
"""FER2013 facial emotion dataset loader."""
|
| 15 |
+
|
| 16 |
+
def __init__(self, root_dir: str, split: str = "train", transform=None):
|
| 17 |
+
"""
|
| 18 |
+
Initialize FER2013 dataset.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
root_dir: Root directory containing 'train' and 'test' folders
|
| 22 |
+
split: 'train' or 'test'
|
| 23 |
+
transform: Torchvision transforms to apply
|
| 24 |
+
"""
|
| 25 |
+
self.root_dir = root_dir
|
| 26 |
+
self.split = split
|
| 27 |
+
self.transform = transform
|
| 28 |
+
self.emotions = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
|
| 29 |
+
self.emotion2idx = {e: i for i, e in enumerate(self.emotions)}
|
| 30 |
+
|
| 31 |
+
self.samples = []
|
| 32 |
+
self._load_samples()
|
| 33 |
+
|
| 34 |
+
def _load_samples(self):
|
| 35 |
+
"""Load all image paths and labels."""
|
| 36 |
+
split_dir = os.path.join(self.root_dir, self.split)
|
| 37 |
+
|
| 38 |
+
for emotion in self.emotions:
|
| 39 |
+
emotion_dir = os.path.join(split_dir, emotion)
|
| 40 |
+
if not os.path.exists(emotion_dir):
|
| 41 |
+
continue
|
| 42 |
+
|
| 43 |
+
for img_file in os.listdir(emotion_dir):
|
| 44 |
+
if img_file.endswith(('.jpg', '.jpeg', '.png')):
|
| 45 |
+
img_path = os.path.join(emotion_dir, img_file)
|
| 46 |
+
self.samples.append((img_path, self.emotion2idx[emotion]))
|
| 47 |
+
|
| 48 |
+
def __len__(self):
|
| 49 |
+
return len(self.samples)
|
| 50 |
+
|
| 51 |
+
def __getitem__(self, idx):
|
| 52 |
+
img_path, label = self.samples[idx]
|
| 53 |
+
|
| 54 |
+
# Load image
|
| 55 |
+
image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
|
| 56 |
+
if image is None:
|
| 57 |
+
return torch.zeros(3, 224, 224), torch.tensor(label, dtype=torch.long)
|
| 58 |
+
|
| 59 |
+
# Convert to RGB (3 channels)
|
| 60 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
| 61 |
+
|
| 62 |
+
if self.transform:
|
| 63 |
+
image = self.transform(image)
|
| 64 |
+
else:
|
| 65 |
+
# Default transform
|
| 66 |
+
image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0
|
| 67 |
+
|
| 68 |
+
return image, torch.tensor(label, dtype=torch.long)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ==================== AUDIO DATASET ====================
|
| 72 |
+
class RAVDESSDataset(Dataset):
|
| 73 |
+
"""RAVDESS audio emotion dataset loader."""
|
| 74 |
+
|
| 75 |
+
def __init__(self, root_dir: str, n_mfcc: int = 13, target_sr: int = 22050):
|
| 76 |
+
"""
|
| 77 |
+
Initialize RAVDESS dataset.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
root_dir: Root directory containing audio files
|
| 81 |
+
n_mfcc: Number of MFCCs to extract
|
| 82 |
+
target_sr: Target sampling rate
|
| 83 |
+
"""
|
| 84 |
+
self.root_dir = root_dir
|
| 85 |
+
self.n_mfcc = n_mfcc
|
| 86 |
+
self.target_sr = target_sr
|
| 87 |
+
self.emotion_map = {
|
| 88 |
+
'01': 'neutral',
|
| 89 |
+
'02': 'calm',
|
| 90 |
+
'03': 'happy',
|
| 91 |
+
'04': 'sad',
|
| 92 |
+
'05': 'angry',
|
| 93 |
+
'06': 'fear',
|
| 94 |
+
'07': 'disgust',
|
| 95 |
+
'08': 'surprise'
|
| 96 |
+
}
|
| 97 |
+
self.emotion2idx = {v: i for i, v in enumerate(set(self.emotion_map.values()))}
|
| 98 |
+
|
| 99 |
+
self.samples = []
|
| 100 |
+
self._load_samples()
|
| 101 |
+
|
| 102 |
+
def _load_samples(self):
|
| 103 |
+
"""Load all audio file paths and labels."""
|
| 104 |
+
for file in os.listdir(self.root_dir):
|
| 105 |
+
if file.endswith('.wav'):
|
| 106 |
+
emotion_code = file.split('-')[2]
|
| 107 |
+
if emotion_code in self.emotion_map:
|
| 108 |
+
emotion = self.emotion_map[emotion_code]
|
| 109 |
+
audio_path = os.path.join(self.root_dir, file)
|
| 110 |
+
self.samples.append((audio_path, self.emotion2idx[emotion]))
|
| 111 |
+
|
| 112 |
+
def __len__(self):
|
| 113 |
+
return len(self.samples)
|
| 114 |
+
|
| 115 |
+
def __getitem__(self, idx):
|
| 116 |
+
audio_path, label = self.samples[idx]
|
| 117 |
+
|
| 118 |
+
try:
|
| 119 |
+
y, sr = librosa.load(audio_path, sr=self.target_sr, mono=True)
|
| 120 |
+
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=self.n_mfcc)
|
| 121 |
+
|
| 122 |
+
# Normalize MFCC
|
| 123 |
+
mfcc = (mfcc - mfcc.mean()) / (mfcc.std() + 1e-8)
|
| 124 |
+
|
| 125 |
+
# Pad or truncate to fixed size (100 time steps)
|
| 126 |
+
if mfcc.shape[1] < 100:
|
| 127 |
+
mfcc = np.pad(mfcc, ((0, 0), (0, 100 - mfcc.shape[1])), mode='constant')
|
| 128 |
+
else:
|
| 129 |
+
mfcc = mfcc[:, :100]
|
| 130 |
+
|
| 131 |
+
return torch.from_numpy(mfcc).float(), torch.tensor(label, dtype=torch.long)
|
| 132 |
+
except Exception as e:
|
| 133 |
+
print(f"Error loading {audio_path}: {e}")
|
| 134 |
+
return torch.zeros(self.n_mfcc, 100), torch.tensor(label, dtype=torch.long)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ==================== DATALOADER FACTORY ====================
|
| 138 |
+
def create_dataloaders(
|
| 139 |
+
fer2013_dir: str = None,
|
| 140 |
+
ravdess_dir: str = None,
|
| 141 |
+
batch_size: int = 32,
|
| 142 |
+
num_workers: int = 0,
|
| 143 |
+
img_size: int = 224
|
| 144 |
+
) -> dict:
|
| 145 |
+
"""
|
| 146 |
+
Create dataloaders for FER2013 and RAVDESS datasets.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
fer2013_dir: Path to FER2013 dataset root
|
| 150 |
+
ravdess_dir: Path to RAVDESS dataset root
|
| 151 |
+
batch_size: Batch size for training
|
| 152 |
+
num_workers: Number of workers for data loading
|
| 153 |
+
img_size: Image size for FER2013
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
Dictionary with dataloaders for each dataset
|
| 157 |
+
"""
|
| 158 |
+
transform = transforms.Compose([
|
| 159 |
+
transforms.ToPILImage(),
|
| 160 |
+
transforms.Resize((img_size, img_size)),
|
| 161 |
+
transforms.ToTensor(),
|
| 162 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 163 |
+
std=[0.229, 0.224, 0.225])
|
| 164 |
+
])
|
| 165 |
+
|
| 166 |
+
dataloaders = {}
|
| 167 |
+
|
| 168 |
+
if fer2013_dir and os.path.exists(fer2013_dir):
|
| 169 |
+
train_dataset = FER2013Dataset(fer2013_dir, split='train', transform=transform)
|
| 170 |
+
test_dataset = FER2013Dataset(fer2013_dir, split='test', transform=transform)
|
| 171 |
+
|
| 172 |
+
dataloaders['fer2013_train'] = DataLoader(
|
| 173 |
+
train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
|
| 174 |
+
)
|
| 175 |
+
dataloaders['fer2013_test'] = DataLoader(
|
| 176 |
+
test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
if ravdess_dir and os.path.exists(ravdess_dir):
|
| 180 |
+
audio_dataset = RAVDESSDataset(ravdess_dir)
|
| 181 |
+
dataloaders['ravdess'] = DataLoader(
|
| 182 |
+
audio_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
return dataloaders
|
backend/services/explainability.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Explainability utilities for multimodal emotion recognition outputs."""
|
| 2 |
+
import os
|
| 3 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "0"
|
| 4 |
+
os.environ["QT_QPA_PLATFORM"] = "offscreen"
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
import librosa
|
| 10 |
+
import librosa.display
|
| 11 |
+
import matplotlib
|
| 12 |
+
matplotlib.use('Agg')
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from io import BytesIO
|
| 16 |
+
import base64
|
| 17 |
+
|
| 18 |
+
from pytorch_grad_cam import GradCAM, EigenCAM
|
| 19 |
+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
| 20 |
+
from pytorch_grad_cam.utils.reshape_transforms import vit_reshape_transform
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ==================== MODEL WRAPPER ====================
|
| 24 |
+
class ViTLogitsWrapper(nn.Module):
|
| 25 |
+
def __init__(self, model):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.model = model
|
| 28 |
+
|
| 29 |
+
def forward(self, x):
|
| 30 |
+
# Grad-CAM expects a standard forward() that returns logits for the selected class.
|
| 31 |
+
return self.model(pixel_values=x).logits
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ==================== FACIAL EXPLAINABILITY ====================
|
| 35 |
+
def generate_grad_cam(image, model, processor, emotion_idx, emotions_list, device):
|
| 36 |
+
try:
|
| 37 |
+
img_rgb = np.array(image.convert('RGB'))
|
| 38 |
+
h, w = img_rgb.shape[:2]
|
| 39 |
+
img_pil = Image.fromarray(img_rgb)
|
| 40 |
+
|
| 41 |
+
inputs = processor(img_pil, return_tensors='pt').to(device)
|
| 42 |
+
input_tensor = inputs['pixel_values']
|
| 43 |
+
|
| 44 |
+
wrapped_model = ViTLogitsWrapper(model)
|
| 45 |
+
wrapped_model.eval()
|
| 46 |
+
|
| 47 |
+
# Try multiple layers because the last block can become too saturated for a usable heatmap.
|
| 48 |
+
layers_to_try = [
|
| 49 |
+
model.vit.encoder.layer[-1].layernorm_after,
|
| 50 |
+
model.vit.encoder.layer[-2].layernorm_after,
|
| 51 |
+
model.vit.encoder.layer[-3].layernorm_after,
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
cam_map = None
|
| 55 |
+
method_used = None
|
| 56 |
+
|
| 57 |
+
for i, layer in enumerate(layers_to_try):
|
| 58 |
+
try:
|
| 59 |
+
cam = GradCAM(
|
| 60 |
+
model=wrapped_model,
|
| 61 |
+
target_layers=[layer],
|
| 62 |
+
reshape_transform=vit_reshape_transform,
|
| 63 |
+
)
|
| 64 |
+
targets = [ClassifierOutputTarget(emotion_idx)]
|
| 65 |
+
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
|
| 66 |
+
result = grayscale_cam[0]
|
| 67 |
+
|
| 68 |
+
# Reject degenerate maps so the UI never shows a blank explanation as if it were valid.
|
| 69 |
+
if result.max() > 0.01:
|
| 70 |
+
cam_map = result
|
| 71 |
+
method_used = f"GradCAM (encoder block {12 - (i+1)})"
|
| 72 |
+
break
|
| 73 |
+
else:
|
| 74 |
+
print(f"[explainability] layer[-{i+1}] all zeros, trying next")
|
| 75 |
+
|
| 76 |
+
except Exception as e:
|
| 77 |
+
print(f"[explainability] GradCAM layer[-{i+1}] failed: {e}")
|
| 78 |
+
|
| 79 |
+
# Final fallback: EigenCAM gives a stable PCA-based map when gradients are unhelpful.
|
| 80 |
+
if cam_map is None:
|
| 81 |
+
print("[explainability] All GradCAM layers zero, using EigenCAM")
|
| 82 |
+
try:
|
| 83 |
+
eigen = EigenCAM(
|
| 84 |
+
model=wrapped_model,
|
| 85 |
+
target_layers=[model.vit.encoder.layer[-1].layernorm_after],
|
| 86 |
+
reshape_transform=vit_reshape_transform,
|
| 87 |
+
)
|
| 88 |
+
grayscale_cam = eigen(input_tensor=input_tensor)
|
| 89 |
+
cam_map = grayscale_cam[0]
|
| 90 |
+
method_used = "EigenCAM"
|
| 91 |
+
except Exception as e:
|
| 92 |
+
print(f"[explainability] EigenCAM failed: {e}")
|
| 93 |
+
return None, None
|
| 94 |
+
|
| 95 |
+
print(f"[explainability] {method_used} — min={cam_map.min():.3f}, max={cam_map.max():.3f}")
|
| 96 |
+
|
| 97 |
+
# Upscale and smooth the heatmap so it overlays cleanly on the source image.
|
| 98 |
+
cam_resized = cv2.resize(cam_map.astype(np.float32), (w, h), interpolation=cv2.INTER_CUBIC)
|
| 99 |
+
cam_resized = cv2.GaussianBlur(cam_resized, (13, 13), 0)
|
| 100 |
+
|
| 101 |
+
c_min, c_max = cam_resized.min(), cam_resized.max()
|
| 102 |
+
if c_max > c_min:
|
| 103 |
+
cam_resized = (cam_resized - c_min) / (c_max - c_min)
|
| 104 |
+
|
| 105 |
+
# Build the colored overlay and blend only the most salient regions.
|
| 106 |
+
cam_uint8 = np.uint8(255 * cam_resized)
|
| 107 |
+
heatmap_bgr = cv2.applyColorMap(cam_uint8, cv2.COLORMAP_JET)
|
| 108 |
+
heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
|
| 109 |
+
|
| 110 |
+
threshold = np.percentile(cam_resized, 70)
|
| 111 |
+
blend_mask = (cam_resized > threshold).astype(np.float32)
|
| 112 |
+
blend_mask = cv2.GaussianBlur(blend_mask, (31, 31), 0)[..., None]
|
| 113 |
+
|
| 114 |
+
blended = (
|
| 115 |
+
(1 - blend_mask * 0.65) * img_rgb.astype(np.float32)
|
| 116 |
+
+ blend_mask * 0.65 * heatmap_rgb.astype(np.float32)
|
| 117 |
+
).clip(0, 255).astype(np.uint8)
|
| 118 |
+
|
| 119 |
+
orig_buf = BytesIO()
|
| 120 |
+
Image.fromarray(img_rgb).save(orig_buf, format='PNG')
|
| 121 |
+
orig_b64 = base64.b64encode(orig_buf.getvalue()).decode()
|
| 122 |
+
|
| 123 |
+
blend_buf = BytesIO()
|
| 124 |
+
Image.fromarray(blended).save(blend_buf, format='PNG')
|
| 125 |
+
blend_b64 = base64.b64encode(blend_buf.getvalue()).decode()
|
| 126 |
+
|
| 127 |
+
return orig_b64, blend_b64
|
| 128 |
+
|
| 129 |
+
except Exception as e:
|
| 130 |
+
print(f"[explainability] GradCAM generation failed: {e}")
|
| 131 |
+
return None, None
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# ==================== AUDIO EXPLAINABILITY ====================
|
| 135 |
+
def generate_audio_saliency(audio, model, processor, emotion_idx, emotions_list, device, sr=16000):
|
| 136 |
+
try:
|
| 137 |
+
if audio is None or len(audio) == 0:
|
| 138 |
+
raise ValueError("Audio input is empty")
|
| 139 |
+
|
| 140 |
+
# Sanitize the audio before passing it into the speech backbone.
|
| 141 |
+
audio = np.asarray(audio, dtype=np.float32)
|
| 142 |
+
audio = np.nan_to_num(audio)
|
| 143 |
+
|
| 144 |
+
inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
|
| 145 |
+
input_values = inputs['input_values'].to(device)
|
| 146 |
+
|
| 147 |
+
input_values.requires_grad = True
|
| 148 |
+
model.zero_grad()
|
| 149 |
+
|
| 150 |
+
outputs = model(input_values)
|
| 151 |
+
score = outputs.logits[0, emotion_idx]
|
| 152 |
+
score.backward()
|
| 153 |
+
|
| 154 |
+
if input_values.grad is None:
|
| 155 |
+
raise RuntimeError("No gradients captured")
|
| 156 |
+
|
| 157 |
+
saliency = torch.abs(input_values.grad).cpu().detach().numpy()
|
| 158 |
+
|
| 159 |
+
if saliency.ndim == 3:
|
| 160 |
+
saliency = np.mean(saliency, axis=1)[0]
|
| 161 |
+
elif saliency.ndim == 2:
|
| 162 |
+
saliency = saliency[0]
|
| 163 |
+
saliency = saliency.reshape(-1).astype(np.float32)
|
| 164 |
+
|
| 165 |
+
if saliency.size > 11:
|
| 166 |
+
# Smooth the gradient spikes so the curve is readable in the plot.
|
| 167 |
+
kernel = np.ones(11, dtype=np.float32) / 11.0
|
| 168 |
+
saliency = np.convolve(saliency, kernel, mode='same')
|
| 169 |
+
|
| 170 |
+
s_min, s_max = saliency.min(), saliency.max()
|
| 171 |
+
if s_max > s_min:
|
| 172 |
+
saliency = (saliency - s_min) / (s_max - s_min)
|
| 173 |
+
else:
|
| 174 |
+
saliency = np.zeros_like(saliency)
|
| 175 |
+
|
| 176 |
+
# Build both the spectrogram and the saliency overlay for a side-by-side explanation.
|
| 177 |
+
S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
|
| 178 |
+
S_db = librosa.power_to_db(S, ref=np.max)
|
| 179 |
+
|
| 180 |
+
fig1, ax1 = plt.subplots(figsize=(10, 4), dpi=100)
|
| 181 |
+
img1 = librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='mel', ax=ax1)
|
| 182 |
+
ax1.set_title(f'Audio Spectrogram — {emotions_list[emotion_idx]}')
|
| 183 |
+
fig1.colorbar(img1, ax=ax1, format='%+2.0f dB')
|
| 184 |
+
spec_buf = BytesIO()
|
| 185 |
+
fig1.savefig(spec_buf, format='PNG', bbox_inches='tight', dpi=100)
|
| 186 |
+
spec_b64 = base64.b64encode(spec_buf.getvalue()).decode()
|
| 187 |
+
plt.close(fig1)
|
| 188 |
+
|
| 189 |
+
fig2, (ax2, ax3) = plt.subplots(2, 1, figsize=(10, 5.5), dpi=100,
|
| 190 |
+
gridspec_kw={'height_ratios': [3, 1]}, sharex=False)
|
| 191 |
+
|
| 192 |
+
# Normalize the spectrogram so the saliency colors stay visible across different recordings.
|
| 193 |
+
S_norm = (S_db - S_db.min()) / max(S_db.max() - S_db.min(), 1e-8)
|
| 194 |
+
sal_resized = np.interp(np.linspace(0, 1, S_db.shape[1]),
|
| 195 |
+
np.linspace(0, 1, saliency.shape[0]), saliency)
|
| 196 |
+
sal_map = np.tile(sal_resized, (S_db.shape[0], 1))
|
| 197 |
+
|
| 198 |
+
ax2.imshow(S_norm, aspect='auto', origin='lower', cmap='viridis', interpolation='bilinear')
|
| 199 |
+
ax2.imshow(sal_map, aspect='auto', origin='lower', cmap='magma', alpha=0.6, interpolation='bilinear')
|
| 200 |
+
ax2.set_title(f'Audio Saliency — {emotions_list[emotion_idx]} (bright = important)')
|
| 201 |
+
ax2.set_ylabel('Mel Frequency')
|
| 202 |
+
|
| 203 |
+
# Highlight the strongest time steps to give the user a clear peak view.
|
| 204 |
+
peak_thr = np.percentile(sal_resized, 85)
|
| 205 |
+
x = np.arange(len(sal_resized))
|
| 206 |
+
ax3.plot(x, sal_resized, color='#f97316', linewidth=1.5)
|
| 207 |
+
ax3.fill_between(x, 0, sal_resized, where=sal_resized >= peak_thr, color='#ef4444', alpha=0.4)
|
| 208 |
+
ax3.axhline(peak_thr, color='#ef4444', linestyle='--', linewidth=1, alpha=0.8)
|
| 209 |
+
ax3.set_ylim(0, 1.05)
|
| 210 |
+
ax3.set_ylabel('Saliency')
|
| 211 |
+
ax3.set_xlabel('Time steps')
|
| 212 |
+
ax3.grid(alpha=0.2)
|
| 213 |
+
|
| 214 |
+
sal_buf = BytesIO()
|
| 215 |
+
fig2.tight_layout()
|
| 216 |
+
fig2.savefig(sal_buf, format='PNG', bbox_inches='tight', dpi=100)
|
| 217 |
+
sal_b64 = base64.b64encode(sal_buf.getvalue()).decode()
|
| 218 |
+
plt.close(fig2)
|
| 219 |
+
|
| 220 |
+
return spec_b64, sal_b64
|
| 221 |
+
|
| 222 |
+
except Exception as e:
|
| 223 |
+
print(f"[explainability] Audio saliency failed: {e}")
|
| 224 |
+
return None, None
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# ==================== COMBINED VISUALIZATION ====================
|
| 228 |
+
def create_combined_visualization(grad_cam_base64, saliency_base64, facial_emotion, speech_emotion, concordance):
|
| 229 |
+
try:
|
| 230 |
+
# Use a soft status tint so the combined report communicates agreement at a glance.
|
| 231 |
+
bg_color = '#d4edda' if concordance == 'MATCH' else '#f8d7da'
|
| 232 |
+
html = f"""
|
| 233 |
+
<div style="display:flex;gap:20px;padding:20px;background:#f5f5f5;border-radius:10px;">
|
| 234 |
+
<div style="flex:1;">
|
| 235 |
+
<h3>Facial GradCAM — {facial_emotion}</h3>
|
| 236 |
+
<img src="data:image/png;base64,{grad_cam_base64}" style="width:100%;border-radius:8px;">
|
| 237 |
+
<p style="font-size:12px;color:#666;">Red/warm = regions that most influenced the {facial_emotion} prediction.</p>
|
| 238 |
+
</div>
|
| 239 |
+
<div style="flex:1;">
|
| 240 |
+
<h3>Speech Saliency — {speech_emotion}</h3>
|
| 241 |
+
<img src="data:image/png;base64,{saliency_base64}" style="width:100%;border-radius:8px;">
|
| 242 |
+
<p style="font-size:12px;color:#666;">Bright = time-frequency regions with strongest influence.</p>
|
| 243 |
+
</div>
|
| 244 |
+
</div>
|
| 245 |
+
<div style="margin-top:20px;padding:15px;background:{bg_color};border-radius:8px;text-align:center;">
|
| 246 |
+
<h4 style="margin:0;">Concordance: <strong>{concordance}</strong></h4>
|
| 247 |
+
</div>
|
| 248 |
+
"""
|
| 249 |
+
return base64.b64encode(html.encode()).decode()
|
| 250 |
+
except Exception as e:
|
| 251 |
+
print(f"[explainability] Combined visualisation failed: {e}")
|
| 252 |
+
return None
|
requirements.txt
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Web Framework
|
| 2 |
+
fastapi>=0.104.0
|
| 3 |
+
uvicorn[standard]>=0.24.0
|
| 4 |
+
gunicorn>=21.2.0
|
| 5 |
+
python-multipart>=0.0.6
|
| 6 |
+
|
| 7 |
+
# Core ML Libraries
|
| 8 |
+
torch>=2.6.0
|
| 9 |
+
torchvision>=0.20.0
|
| 10 |
+
torchaudio>=2.6.0
|
| 11 |
+
transformers>=4.46.0
|
| 12 |
+
timm>=1.0.0
|
| 13 |
+
|
| 14 |
+
# Face Detection
|
| 15 |
+
facenet-pytorch>=2.5.3
|
| 16 |
+
|
| 17 |
+
# Audio Processing
|
| 18 |
+
librosa>=0.10.2
|
| 19 |
+
soundfile>=0.12.1
|
| 20 |
+
pydub>=0.25.1
|
| 21 |
+
|
| 22 |
+
# Computer Vision
|
| 23 |
+
opencv-python-headless>=4.10.0
|
| 24 |
+
pillow>=10.4.0
|
| 25 |
+
|
| 26 |
+
# Deep Learning
|
| 27 |
+
numpy>=1.26.0
|
| 28 |
+
scipy>=1.14.0
|
| 29 |
+
scikit-learn>=1.5.0
|
| 30 |
+
|
| 31 |
+
# Visualization
|
| 32 |
+
matplotlib>=3.9.0
|
| 33 |
+
seaborn>=0.13.2
|
| 34 |
+
|
| 35 |
+
# Configuration
|
| 36 |
+
python-dotenv>=1.0.0
|
| 37 |
+
|
| 38 |
+
# Database
|
| 39 |
+
supabase>=2.0.0
|
| 40 |
+
|
| 41 |
+
# HuggingFace
|
| 42 |
+
huggingface_hub>=0.24.0
|
| 43 |
+
grad-cam
|
| 44 |
+
|
start.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
exec python -m gunicorn backend.main:app -w 1 -k uvicorn.workers.UvicornWorker --timeout 600 --bind 0.0.0.0:${PORT:-8080}
|