# Spaces: Sleeping
import base64
import io
import os
import time
from datetime import datetime
from typing import Optional, List, Dict

import librosa
import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
# -------------------- CONFIG --------------------
# Hub repository to load; overridable through the HF_MODEL_ID environment variable.
MODEL_ID = os.getenv("HF_MODEL_ID", "abedir/emotion-detector")

# Class index -> human-readable emotion label.
label_map: Dict[int, str] = {
    0: "Angry/Fearful",
    1: "Happy/Laugh",
    2: "Neutral/Calm",
    3: "Sad/Cry",
    4: "Surprised/Amazed"
}

MAX_DURATION = 3.0  # seconds
API_VERSION = "1.0.0"
# -------------------- LOAD MODEL FROM HF HUB --------------------
# Use the GPU when torch reports one, otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("=" * 60)
print("HuBERT Emotion Recognition API - Starting")
print("=" * 60)
print(f"Device: {device}")
print(f"Loading model from Hugging Face Hub: {MODEL_ID}")

try:
    processor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)
    model = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
    model.to(device)
    model.eval()  # evaluation mode: this process only runs inference
    print("✓ Model loaded successfully from Hugging Face Hub")
    print("=" * 60)
except Exception as e:
    # Print troubleshooting hints, then re-raise so the import fails loudly
    # instead of leaving `model` / `processor` undefined.
    print("=" * 60)
    print("✗ ERROR: Failed to load model from Hugging Face Hub")
    print("=" * 60)
    print(f"\nError details: {e}\n")
    print("Please ensure:")
    print(f"1. Model ID is correct: {MODEL_ID}")
    print("2. Model repository exists and is accessible")
    print("3. Model contains all required files:")
    print(" - config.json")
    print(" - preprocessor_config.json")
    print(" - model.safetensors")
    print("=" * 60)
    raise

# Sample rate the feature extractor reports, and the fixed clip length (in
# samples) that every input is truncated or padded to.
sampling_rate = processor.sampling_rate
max_length = int(MAX_DURATION * sampling_rate)
# -------------------- PYDANTIC MODELS --------------------
# Result of classifying one audio clip: the winning label, its score, and the
# score of every label.
class EmotionPrediction(BaseModel):
    emotion: str = Field(..., description="Predicted emotion label")
    confidence: float = Field(..., description="Confidence score (0-1)")
    probabilities: Dict[str, float] = Field(..., description="Probability distribution across all emotions")
# Response body of the batch endpoint.
class BatchPredictionResponse(BaseModel):
    # One entry per file that was successfully classified.
    predictions: List[EmotionPrediction]
    # Filled by predict_batch with len(predictions), i.e. the number of
    # successful predictions rather than the number of uploaded files.
    total_files: int
    # Elapsed time for the whole batch, in seconds.
    processing_time_seconds: float
# Response body of the health-check endpoint.
class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    device: str  # "cuda" or "cpu"
    supported_emotions: List[str]
    api_version: str
    model_id: str
    timestamp: str  # ISO-8601 string produced with datetime.isoformat()
# Static description of the loaded model and its preprocessing settings.
# NOTE(review): Pydantic v2 treats the `model_` prefix as a protected namespace
# and warns about field names such as `model_id`; confirm which Pydantic
# version this service is deployed with.
class ModelInfoResponse(BaseModel):
    model_name: str
    model_id: str
    model_type: str
    num_labels: int
    emotion_labels: Dict[int, str]  # class index -> label
    sample_rate: int
    max_duration_seconds: float
    device: str
# Metadata about a decoded audio clip, measured before truncation/padding.
class AudioInfoResponse(BaseModel):
    duration_seconds: float
    sample_rate: int
    num_samples: int
    is_truncated: bool  # clip is longer than MAX_DURATION
    is_padded: bool  # clip is shorter than MAX_DURATION
# Request body for predicting from a base64-encoded audio file.
class Base64PredictionRequest(BaseModel):
    audio_base64: str = Field(..., description="Base64 encoded audio file")
    filename: Optional[str] = Field(None, description="Original filename for reference")
# JSON body emitted by the exception handlers in this module.
class ErrorResponse(BaseModel):
    error: str
    detail: str
    timestamp: str  # ISO-8601 string produced with datetime.isoformat()
# -------------------- FASTAPI APP --------------------
app = FastAPI(
    title="HuBERT Emotion Recognition API",
    # Interpolate MODEL_ID so the description stays correct when the model is
    # overridden through the HF_MODEL_ID environment variable.
    description=f"Advanced emotion recognition API using HuBERT model - Model: {MODEL_ID}",
    version=API_VERSION,
    docs_url="/docs",
    redoc_url="/redoc"
)

# Add CORS middleware
# Wildcard origins must not be combined with credentialed requests, so
# credentials are disabled; list explicit origins before turning them back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -------------------- HELPER FUNCTIONS --------------------
def get_audio_info(audio: np.ndarray, sr: int) -> AudioInfoResponse:
    """Describe a decoded audio clip.

    Args:
        audio: Decoded samples; only ``len(audio)`` is used.
        sr: Sample rate of ``audio`` in Hz.

    Returns:
        Duration, sample rate and sample count, plus flags telling whether the
        clip is longer (``is_truncated``) or shorter (``is_padded``) than
        ``MAX_DURATION``.
    """
    duration = len(audio) / sr
    is_truncated = duration > MAX_DURATION
    is_padded = duration < MAX_DURATION
    return AudioInfoResponse(
        duration_seconds=float(duration),
        sample_rate=sr,
        num_samples=len(audio),
        is_truncated=is_truncated,
        is_padded=is_padded
    )
def preprocess_audio(file_bytes: bytes) -> tuple[torch.Tensor, AudioInfoResponse]:
    """Decode raw audio bytes into a fixed-length model input.

    The audio is loaded at ``sampling_rate``, described with
    ``get_audio_info`` (before any length change), then truncated or
    zero-padded to ``max_length`` samples and run through ``processor``.

    Args:
        file_bytes: Contents of an audio file.

    Returns:
        ``(input_values, audio_info)`` where ``input_values`` has been moved
        to ``device``.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded, contain no
            samples, or cannot be processed.
    """
    try:
        audio, sr = librosa.load(
            io.BytesIO(file_bytes),
            sr=sampling_rate
        )
        # A file that decodes to zero samples would otherwise be padded to
        # pure silence and classified as if it were real audio.
        if len(audio) == 0:
            raise ValueError("audio contains no samples")

        audio_info = get_audio_info(audio, sr)

        # Truncate or pad to max_length
        if len(audio) > max_length:
            audio = audio[:max_length]
        else:
            audio = np.pad(audio, (0, max_length - len(audio)))

        inputs = processor(
            audio,
            sampling_rate=sampling_rate,
            return_tensors="pt"
        )
        return inputs.input_values.to(device), audio_info
    except Exception as e:
        # Any decoding/processing failure is reported as a client error.
        raise HTTPException(status_code=400, detail=f"Error processing audio: {str(e)}") from e
def predict_emotion(input_values: torch.Tensor) -> EmotionPrediction:
    """Classify one preprocessed clip.

    Args:
        input_values: Model input as produced by ``preprocess_audio``; only
            the first item along dimension 0 is scored.

    Returns:
        The arg-max label, its softmax score, and the score of every label in
        ``label_map``.
    """
    with torch.no_grad():
        outputs = model(input_values)
        probs = torch.softmax(outputs.logits, dim=1)[0]
        pred_id = torch.argmax(probs).item()
        # Convert to Python floats in one go instead of once per label.
        scores = probs.tolist()

    return EmotionPrediction(
        emotion=label_map[pred_id],
        confidence=scores[pred_id],
        probabilities={
            label: scores[class_id]
            for class_id, label in label_map.items()
        }
    )
# -------------------- ENDPOINTS --------------------
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def root():
    """Root endpoint - API welcome message"""
    return {
        "message": "HuBERT Emotion Recognition API",
        "version": API_VERSION,
        "model_id": MODEL_ID,
        "docs": "/docs",
        "health": "/health"
    }
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def health():
    """Comprehensive health check endpoint"""
    # Reports module-level state only; no inference is run here.
    return HealthResponse(
        status="healthy",
        model_loaded=model is not None,
        device=device,
        supported_emotions=list(label_map.values()),
        api_version=API_VERSION,
        model_id=MODEL_ID,
        timestamp=datetime.now().isoformat()
    )
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def get_model_info():
    """Get detailed model information"""
    # Everything here comes from module-level configuration.
    return ModelInfoResponse(
        model_name="HuBERT Emotion Detector",
        model_id=MODEL_ID,
        model_type="Audio Classification - Emotion Recognition",
        num_labels=len(label_map),
        emotion_labels=label_map,
        sample_rate=sampling_rate,
        max_duration_seconds=MAX_DURATION,
        device=device
    )
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def list_emotions():
    """List all supported emotion labels"""
    return {
        "emotions": list(label_map.values()),
        "count": len(label_map)
    }
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def predict(
    file: UploadFile = File(..., description="Audio file (.wav, .mp3, .flac, .ogg, .m4a, .webm)")
):
    """Predict emotion from uploaded audio file"""
    # The upload's filename is optional; a missing name is treated as an
    # unsupported format instead of crashing on None.lower().
    filename = (file.filename or "").lower()
    if not filename.endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.webm')):
        raise HTTPException(
            status_code=400,
            detail="Invalid file format. Supported: .wav, .mp3, .flac, .ogg, .m4a, .webm"
        )

    audio_bytes = await file.read()
    input_values, _ = preprocess_audio(audio_bytes)
    return predict_emotion(input_values)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def predict_detailed(
    file: UploadFile = File(..., description="Audio file")
):
    """Predict emotion with detailed audio information"""
    # The upload's filename is optional; a missing name is treated as an
    # unsupported format instead of crashing on None.lower().
    filename = (file.filename or "").lower()
    if not filename.endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.webm')):
        raise HTTPException(
            status_code=400,
            detail="Invalid file format. Supported: .wav, .mp3, .flac, .ogg, .m4a, .webm"
        )

    audio_bytes = await file.read()
    input_values, audio_info = preprocess_audio(audio_bytes)
    prediction = predict_emotion(input_values)

    # audio_info describes the clip before truncation/padding.
    return {
        "prediction": prediction.dict(),
        "audio_info": audio_info.dict(),
        "filename": file.filename,
        "timestamp": datetime.now().isoformat()
    }
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def predict_base64(request: Base64PredictionRequest):
    """Predict emotion from base64 encoded audio"""
    payload = request.audio_base64.strip()
    # Accept data URLs ("data:audio/wav;base64,<data>") by dropping the header.
    # Neither ":" nor "," belongs to the base64 alphabet, so plain base64
    # input is never affected by this check.
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]

    try:
        audio_bytes = base64.b64decode(payload)
    except ValueError as e:
        # binascii.Error (bad padding) is a ValueError subclass; non-ASCII
        # input raises a plain ValueError.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid base64 encoding: {str(e)}"
        ) from e

    input_values, _ = preprocess_audio(audio_bytes)
    return predict_emotion(input_values)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def predict_batch(
    files: List[UploadFile] = File(..., description="Multiple audio files")
):
    """Batch prediction for multiple audio files"""
    if len(files) > 50:
        raise HTTPException(
            status_code=400,
            detail="Maximum 50 files per batch request"
        )

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()
    predictions = []

    for file in files:
        # Files with a missing name or unsupported extension are skipped.
        filename = (file.filename or "").lower()
        if not filename.endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.webm')):
            continue
        try:
            audio_bytes = await file.read()
            input_values, _ = preprocess_audio(audio_bytes)
            prediction = predict_emotion(input_values)
            predictions.append(prediction)
        except Exception as e:
            # Best effort: one bad file must not fail the whole batch.
            print(f"Error processing {file.filename}: {e}")
            continue

    processing_time = time.perf_counter() - start_time

    # total_files counts successful predictions, not uploaded files.
    return BatchPredictionResponse(
        predictions=predictions,
        total_files=len(predictions),
        processing_time_seconds=processing_time
    )
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def analyze_audio(
    file: UploadFile = File(..., description="Audio file to analyze")
):
    """Analyze audio file and return metadata without prediction"""
    # Unlike the prediction handlers, no file-extension check is applied here;
    # anything librosa can decode is accepted.
    try:
        audio_bytes = await file.read()
        audio, sr = librosa.load(io.BytesIO(audio_bytes), sr=sampling_rate)
        return get_audio_info(audio, sr)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error analyzing audio: {str(e)}"
        ) from e
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def predict_top_k(
    file: UploadFile = File(...),
    k: int = Query(3, ge=1, le=5, description="Number of top predictions to return")
):
    """Get top-k emotion predictions"""
    # The upload's filename is optional; a missing name is treated as an
    # unsupported format instead of crashing on None.lower().
    filename = (file.filename or "").lower()
    if not filename.endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.webm')):
        raise HTTPException(status_code=400, detail="Invalid file format")

    audio_bytes = await file.read()
    input_values, _ = preprocess_audio(audio_bytes)

    with torch.no_grad():
        outputs = model(input_values)
        probs = torch.softmax(outputs.logits, dim=1)[0]
        # The Query bound (le=5) is hard-coded; clamp to the number of classes
        # the model actually emits so torch.topk cannot raise.
        k = min(k, probs.shape[0])
        top_k_probs, top_k_indices = torch.topk(probs, k)

    top_predictions = [
        {
            "rank": i + 1,
            "emotion": label_map[idx.item()],
            "confidence": prob.item()
        }
        for i, (prob, idx) in enumerate(zip(top_k_probs, top_k_indices))
    ]

    return {
        "top_predictions": top_predictions,
        "total_emotions": len(label_map)
    }
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def predict_with_threshold(
    file: UploadFile = File(...),
    threshold: float = Query(0.5, ge=0.0, le=1.0, description="Confidence threshold")
):
    """Predict emotion only if confidence exceeds threshold"""
    # The upload's filename is optional; a missing name is treated as an
    # unsupported format instead of crashing on None.lower().
    filename = (file.filename or "").lower()
    if not filename.endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.webm')):
        raise HTTPException(status_code=400, detail="Invalid file format")

    audio_bytes = await file.read()
    input_values, _ = preprocess_audio(audio_bytes)
    prediction = predict_emotion(input_values)

    # A confidence exactly equal to the threshold counts as confident.
    if prediction.confidence >= threshold:
        return {
            "status": "confident",
            "prediction": prediction.dict()
        }

    # Below the threshold the best guess is still returned, flagged uncertain.
    return {
        "status": "uncertain",
        "message": f"Confidence {prediction.confidence:.3f} below threshold {threshold}",
        "best_guess": prediction.dict()
    }
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def get_stats():
    """Get API statistics and system information"""
    # Static configuration and environment facts only; no request counters
    # are tracked in this module.
    return {
        "model": {
            "name": "HuBERT Emotion Detector",
            "model_id": MODEL_ID,
            "device": device,
            "loaded": model is not None
        },
        "configuration": {
            "max_duration_seconds": MAX_DURATION,
            "sample_rate": sampling_rate,
            "num_emotions": len(label_map)
        },
        "system": {
            "cuda_available": torch.cuda.is_available(),
            "torch_version": torch.__version__
        },
        "api_version": API_VERSION,
        "timestamp": datetime.now().isoformat()
    }
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on the app.
async def get_version():
    """Get API version information"""
    return {
        "api_version": API_VERSION,
        "framework": "FastAPI",
        "model": "HuBERT Emotion Detector",
        "model_id": MODEL_ID,
        "timestamp": datetime.now().isoformat()
    }
# -------------------- ERROR HANDLERS --------------------
# NOTE(review): no exception-handler registration is visible for this function
# in this view; confirm it is attached to the app for HTTPException.
async def http_exception_handler(request, exc):
    """Render an HTTP exception as an ``ErrorResponse`` JSON body.

    Args:
        request: The incoming request (unused).
        exc: The raised exception; ``status_code`` and ``detail`` are read
            from it, and ``headers`` when present.

    Returns:
        A ``JSONResponse`` carrying the exception's status code.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            # detail is not guaranteed to be a string (it may be a dict or
            # list); stringify it so building the body cannot itself fail.
            error=str(exc.detail),
            detail=str(exc),
            timestamp=datetime.now().isoformat()
        ).dict(),
        # Preserve any headers attached to the exception.
        headers=getattr(exc, "headers", None)
    )
# NOTE(review): no exception-handler registration is visible for this function
# in this view; confirm it is attached to the app for Exception.
async def general_exception_handler(request, exc):
    """Render any unhandled exception as a 500 ``ErrorResponse`` JSON body.

    Args:
        request: The incoming request (unused).
        exc: The unhandled exception.

    Returns:
        A ``JSONResponse`` with status code 500.
    """
    # NOTE(review): str(exc) is sent to the client and may expose internal
    # details; consider logging it server-side and returning a generic message.
    return JSONResponse(
        status_code=500,
        content=ErrorResponse(
            error="Internal server error",
            detail=str(exc),
            timestamp=datetime.now().isoformat()
        ).dict()
    )
# -------------------- STARTUP EVENT --------------------
# NOTE(review): no startup-hook registration is visible for this function in
# this view; confirm it is attached to the app.
async def startup_event():
    """Log startup information"""
    print("API is ready to accept requests!")
    print("Visit /docs for interactive API documentation")