"""HuBERT emotion recognition REST API (FastAPI).

Loads an audio-classification model from the Hugging Face Hub at import time
and exposes endpoints for single, base64, batch, top-k and thresholded
emotion prediction, plus metadata/health endpoints.
"""
import base64
import io
import logging
import os
import time
from datetime import datetime
from typing import Dict, List, Optional

import librosa
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from starlette.exceptions import HTTPException as StarletteHTTPException
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor

logger = logging.getLogger(__name__)

# -------------------- CONFIG --------------------
MODEL_ID = os.getenv("HF_MODEL_ID", "abedir/emotion-detector")

# Model output index -> human-readable label. Assumed to match the order of the
# model's classification head (only the label *count* is checked after loading).
label_map = {
    0: "Angry/Fearful",
    1: "Happy/Laugh",
    2: "Neutral/Calm",
    3: "Sad/Cry",
    4: "Surprised/Amazed",
}

MAX_DURATION = 3.0  # seconds; every clip is truncated / zero-padded to this length
API_VERSION = "1.0.0"

# File extensions accepted by the upload endpoints (compared case-insensitively).
SUPPORTED_EXTENSIONS = (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm")
UNSUPPORTED_FORMAT_DETAIL = "Invalid file format. Supported: " + ", ".join(SUPPORTED_EXTENSIONS)

# Comma-separated list of allowed CORS origins; "*" (the default) allows any origin.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

# -------------------- LOAD MODEL FROM HF HUB --------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

print("=" * 60)
print("HuBERT Emotion Recognition API - Starting")
print("=" * 60)
print(f"Device: {device}")
print(f"Loading model from Hugging Face Hub: {MODEL_ID}")

try:
    processor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)
    model = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
    model.to(device)
    model.eval()
    print("✓ Model loaded successfully from Hugging Face Hub")
    print("=" * 60)
except Exception as e:
    print("=" * 60)
    print("✗ ERROR: Failed to load model from Hugging Face Hub")
    print("=" * 60)
    print(f"\nError details: {e}\n")
    print("Please ensure:")
    print(f"1. Model ID is correct: {MODEL_ID}")
    print("2. Model repository exists and is accessible")
    print("3. Model contains all required files:")
    print("   - config.json")
    print("   - preprocessor_config.json")
    print("   - model.safetensors")
    print("=" * 60)
    raise

# A head with a different number of outputs than label_map would raise KeyError
# (more outputs) or silently omit classes (fewer outputs) at prediction time.
_model_num_labels = getattr(model.config, "num_labels", len(label_map))
if _model_num_labels != len(label_map):
    print(
        f"WARNING: model reports {_model_num_labels} labels but label_map defines "
        f"{len(label_map)}; predictions may be mislabeled"
    )

sampling_rate = processor.sampling_rate
max_length = int(MAX_DURATION * sampling_rate)  # fixed model input length in samples


# -------------------- PYDANTIC MODELS --------------------
class EmotionPrediction(BaseModel):
    emotion: str = Field(..., description="Predicted emotion label")
    confidence: float = Field(..., description="Confidence score (0-1)")
    probabilities: Dict[str, float] = Field(..., description="Probability distribution across all emotions")


class BatchPredictionResponse(BaseModel):
    predictions: List[EmotionPrediction]
    total_files: int  # number of successful predictions (len(predictions))
    processing_time_seconds: float
    skipped_files: List[str] = Field(
        default_factory=list,
        description="Filenames that were skipped (unsupported format or processing error)",
    )


class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    device: str
    supported_emotions: List[str]
    api_version: str
    model_id: str
    timestamp: str


class ModelInfoResponse(BaseModel):
    model_name: str
    model_id: str
    model_type: str
    num_labels: int
    emotion_labels: Dict[int, str]
    sample_rate: int
    max_duration_seconds: float
    device: str


class AudioInfoResponse(BaseModel):
    duration_seconds: float
    sample_rate: int
    num_samples: int
    is_truncated: bool
    is_padded: bool


class EmotionsResponse(BaseModel):
    emotions: List[str]
    count: int


class Base64PredictionRequest(BaseModel):
    audio_base64: str = Field(..., description="Base64 encoded audio file")
    filename: Optional[str] = Field(None, description="Original filename for reference")


class ErrorResponse(BaseModel):
    error: str
    detail: str
    timestamp: str


# -------------------- FASTAPI APP --------------------
app = FastAPI(
    title="HuBERT Emotion Recognition API",
    description=f"Advanced emotion recognition API using HuBERT model - Model: {MODEL_ID}",
    version=API_VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
)

# With a literal "*" origin plus allow_credentials=True, Starlette reflects the
# caller's Origin header, effectively granting credentialed access to any site.
# This API uses no cookies/auth, so credentials are only enabled when an explicit
# origin allow-list is configured via CORS_ALLOW_ORIGINS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials="*" not in CORS_ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)


# -------------------- HELPER FUNCTIONS --------------------
def _has_supported_extension(filename: Optional[str]) -> bool:
    """Return True if ``filename`` ends with one of SUPPORTED_EXTENSIONS (case-insensitive)."""
    return bool(filename) and filename.lower().endswith(SUPPORTED_EXTENSIONS)


def _require_supported_format(file: UploadFile) -> None:
    """Raise HTTP 400 unless the upload has a supported audio extension.

    A missing filename is treated as unsupported instead of crashing with
    AttributeError (which previously surfaced as a 500).
    """
    if not _has_supported_extension(file.filename):
        raise HTTPException(status_code=400, detail=UNSUPPORTED_FORMAT_DETAIL)


def get_audio_info(audio: np.ndarray, sr: int) -> AudioInfoResponse:
    """Describe ``audio`` relative to the fixed MAX_DURATION model window.

    Truncation/padding are decided on sample counts, using the same
    ``int(MAX_DURATION * sr)`` target that ``preprocess_audio`` applies, so the
    flags always agree with what actually happens to the signal.
    """
    num_samples = len(audio)
    target_samples = int(MAX_DURATION * sr)
    return AudioInfoResponse(
        duration_seconds=float(num_samples / sr),
        sample_rate=sr,
        num_samples=num_samples,
        is_truncated=num_samples > target_samples,
        is_padded=num_samples < target_samples,
    )


def preprocess_audio(file_bytes: bytes) -> tuple[torch.Tensor, AudioInfoResponse]:
    """Decode audio bytes into model input values on ``device``.

    The audio is resampled to the processor's sampling rate, then truncated or
    zero-padded (at the end) to ``max_length`` samples.

    Returns:
        ``(input_values, audio_info)`` where ``audio_info`` describes the audio
        *before* truncation/padding.

    Raises:
        HTTPException: 400 if decoding or feature extraction fails.
    """
    try:
        # NOTE(review): decoding from an in-memory buffer relies on librosa's
        # soundfile backend; its audioread fallback expects a file path, so
        # .m4a/.webm uploads may fail here — confirm with real files.
        audio, sr = librosa.load(io.BytesIO(file_bytes), sr=sampling_rate)
        audio_info = get_audio_info(audio, sr)

        if len(audio) > max_length:
            audio = audio[:max_length]
        else:
            audio = np.pad(audio, (0, max_length - len(audio)))

        inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
        return inputs.input_values.to(device), audio_info
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing audio: {str(e)}") from e


def _predict_probs(input_values: torch.Tensor) -> torch.Tensor:
    """Run the model and return the softmax probabilities of the first batch item (on CPU)."""
    with torch.no_grad():
        logits = model(input_values).logits
    return torch.softmax(logits, dim=1)[0].cpu()


def predict_emotion(input_values: torch.Tensor) -> EmotionPrediction:
    """Return the most likely emotion plus the full probability distribution."""
    probs = _predict_probs(input_values)
    pred_id = int(torch.argmax(probs))
    return EmotionPrediction(
        emotion=label_map[pred_id],
        confidence=float(probs[pred_id]),
        probabilities={label: float(probs[idx]) for idx, label in label_map.items()},
    )


async def _run_inference(audio_bytes: bytes) -> tuple[EmotionPrediction, AudioInfoResponse]:
    """Decode and classify audio without blocking the event loop.

    librosa decoding and the model forward pass are blocking calls; executing
    them directly inside ``async def`` endpoints would stall every other
    request, so both are dispatched to the threadpool.
    """
    input_values, audio_info = await run_in_threadpool(preprocess_audio, audio_bytes)
    prediction = await run_in_threadpool(predict_emotion, input_values)
    return prediction, audio_info


# -------------------- ENDPOINTS --------------------
@app.get("/", response_model=Dict[str, str])
async def root():
    """Root endpoint - API welcome message"""
    return {
        "message": "HuBERT Emotion Recognition API",
        "version": API_VERSION,
        "model_id": MODEL_ID,
        "docs": "/docs",
        "health": "/health",
    }


@app.get("/health", response_model=HealthResponse)
async def health():
    """Comprehensive health check endpoint"""
    return HealthResponse(
        status="healthy",
        model_loaded=model is not None,
        device=device,
        supported_emotions=list(label_map.values()),
        api_version=API_VERSION,
        model_id=MODEL_ID,
        timestamp=datetime.now().isoformat(),
    )


@app.get("/model/info", response_model=ModelInfoResponse)
async def get_model_info():
    """Get detailed model information"""
    return ModelInfoResponse(
        model_name="HuBERT Emotion Detector",
        model_id=MODEL_ID,
        model_type="Audio Classification - Emotion Recognition",
        num_labels=len(label_map),
        emotion_labels=label_map,
        sample_rate=sampling_rate,
        max_duration_seconds=MAX_DURATION,
        device=device,
    )


@app.get("/emotions", response_model=EmotionsResponse)
async def list_emotions():
    """List all supported emotion labels.

    The previous ``Dict[str, List[str]]`` response model rejected the integer
    ``count`` value, so every call failed response validation with a 500.
    """
    return {
        "emotions": list(label_map.values()),
        "count": len(label_map),
    }


@app.post("/predict", response_model=EmotionPrediction)
async def predict(
    file: UploadFile = File(..., description="Audio file (.wav, .mp3, .flac, .ogg, .m4a, .webm)")
):
    """Predict emotion from uploaded audio file"""
    _require_supported_format(file)
    audio_bytes = await file.read()
    prediction, _ = await _run_inference(audio_bytes)
    return prediction


@app.post("/predict/detailed", response_model=Dict)
async def predict_detailed(
    file: UploadFile = File(..., description="Audio file")
):
    """Predict emotion with detailed audio information"""
    _require_supported_format(file)
    audio_bytes = await file.read()
    prediction, audio_info = await _run_inference(audio_bytes)
    return {
        "prediction": prediction.dict(),
        "audio_info": audio_info.dict(),
        "filename": file.filename,
        "timestamp": datetime.now().isoformat(),
    }


@app.post("/predict/base64", response_model=EmotionPrediction)
async def predict_base64(request: Base64PredictionRequest):
    """Predict emotion from base64 encoded audio.

    Accepts either a bare base64 payload or a data URI
    (``data:audio/wav;base64,<payload>``). Without stripping the URI header,
    its letters would be decoded as garbage bytes ahead of the audio.
    """
    payload = request.audio_base64
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    try:
        audio_bytes = base64.b64decode(payload)
    except ValueError as e:  # binascii.Error and non-ASCII input are ValueErrors
        raise HTTPException(
            status_code=400,
            detail=f"Invalid base64 encoding: {str(e)}",
        ) from e

    prediction, _ = await _run_inference(audio_bytes)
    return prediction


@app.post("/predict/batch", response_model=BatchPredictionResponse)
async def predict_batch(
    files: List[UploadFile] = File(..., description="Multiple audio files")
):
    """Batch prediction for multiple audio files.

    Best-effort: unsupported or undecodable files do not fail the request;
    they are listed in ``skipped_files`` and excluded from ``predictions``.
    """
    if len(files) > 50:
        raise HTTPException(
            status_code=400,
            detail="Maximum 50 files per batch request",
        )

    start_time = time.perf_counter()
    predictions: List[EmotionPrediction] = []
    skipped_files: List[str] = []

    for file in files:
        if not _has_supported_extension(file.filename):
            skipped_files.append(file.filename or "")
            continue
        try:
            audio_bytes = await file.read()
            prediction, _ = await _run_inference(audio_bytes)
        except Exception as e:
            logger.warning("Error processing %s: %s", file.filename, e)
            skipped_files.append(file.filename or "")
            continue
        predictions.append(prediction)

    return BatchPredictionResponse(
        predictions=predictions,
        total_files=len(predictions),
        processing_time_seconds=time.perf_counter() - start_time,
        skipped_files=skipped_files,
    )


@app.post("/analyze/audio", response_model=AudioInfoResponse)
async def analyze_audio(
    file: UploadFile = File(..., description="Audio file to analyze")
):
    """Analyze audio file and return metadata without prediction"""
    try:
        audio_bytes = await file.read()
        audio, sr = await run_in_threadpool(librosa.load, io.BytesIO(audio_bytes), sr=sampling_rate)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error analyzing audio: {str(e)}",
        ) from e
    return get_audio_info(audio, sr)


@app.post("/predict/top-k")
async def predict_top_k(
    file: UploadFile = File(...),
    k: int = Query(3, ge=1, le=5, description="Number of top predictions to return"),
):
    """Get top-k emotion predictions (k is clamped to the number of model outputs)."""
    _require_supported_format(file)
    audio_bytes = await file.read()
    input_values, _ = await run_in_threadpool(preprocess_audio, audio_bytes)
    probs = await run_in_threadpool(_predict_probs, input_values)

    top_k_probs, top_k_indices = torch.topk(probs, min(k, probs.numel()))
    top_predictions = [
        {
            "rank": rank,
            "emotion": label_map[idx.item()],
            "confidence": prob.item(),
        }
        for rank, (prob, idx) in enumerate(zip(top_k_probs, top_k_indices), start=1)
    ]

    return {
        "top_predictions": top_predictions,
        "total_emotions": len(label_map),
    }


@app.post("/predict/threshold")
async def predict_with_threshold(
    file: UploadFile = File(...),
    threshold: float = Query(0.5, ge=0.0, le=1.0, description="Confidence threshold"),
):
    """Predict emotion only if confidence exceeds threshold"""
    _require_supported_format(file)
    audio_bytes = await file.read()
    prediction, _ = await _run_inference(audio_bytes)

    if prediction.confidence >= threshold:
        return {
            "status": "confident",
            "prediction": prediction.dict(),
        }
    return {
        "status": "uncertain",
        "message": f"Confidence {prediction.confidence:.3f} below threshold {threshold}",
        "best_guess": prediction.dict(),
    }


@app.get("/stats")
async def get_stats():
    """Get API statistics and system information"""
    return {
        "model": {
            "name": "HuBERT Emotion Detector",
            "model_id": MODEL_ID,
            "device": device,
            "loaded": model is not None,
        },
        "configuration": {
            "max_duration_seconds": MAX_DURATION,
            "sample_rate": sampling_rate,
            "num_emotions": len(label_map),
        },
        "system": {
            "cuda_available": torch.cuda.is_available(),
            "torch_version": torch.__version__,
        },
        "api_version": API_VERSION,
        "timestamp": datetime.now().isoformat(),
    }


@app.get("/version")
async def get_version():
    """Get API version information"""
    return {
        "api_version": API_VERSION,
        "framework": "FastAPI",
        "model": "HuBERT Emotion Detector",
        "model_id": MODEL_ID,
        "timestamp": datetime.now().isoformat(),
    }


# -------------------- ERROR HANDLERS --------------------
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request, exc):
    """Render HTTP errors as ErrorResponse.

    Registered on Starlette's base HTTPException so router-level errors
    (404/405) are formatted too. ``detail`` is coerced to str because it may
    be a non-string payload, and ``exc.headers`` (e.g. ``Allow``) is preserved.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            error=str(exc.detail),
            detail=str(exc),
            timestamp=datetime.now().isoformat(),
        ).dict(),
        headers=getattr(exc, "headers", None),
    )


@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """General exception handler (returns 500; note: echoes str(exc) to the client)."""
    return JSONResponse(
        status_code=500,
        content=ErrorResponse(
            error="Internal server error",
            detail=str(exc),
            timestamp=datetime.now().isoformat(),
        ).dict(),
    )


# -------------------- STARTUP EVENT --------------------
@app.on_event("startup")
async def startup_event():
    """Log startup information"""
    print("API is ready to accept requests!")
    print("Visit /docs for interactive API documentation")