# EMOTIA / backend / main.py
# Uploaded via huggingface_hub by Manav2op (commit 25d0747, verified)
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import torch
import cv2
import numpy as np
import io
from PIL import Image
import librosa
import asyncio
from typing import List, Dict, Optional
import time
import logging
import sys
import os
# Add parent directory to path for model imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from models.vision import VisionEmotionModel
from models.audio import AudioEmotionModel
from models.text import TextIntentModel
from models.fusion import MultiModalFusion
# Configure module-level logging for the service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application instance; title/description surface in the OpenAPI docs.
app = FastAPI(title="EMOTIA API", description="Multi-Modal Emotion & Intent Intelligence API")

# CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers per the CORS spec and is unsafe — restrict
# allow_origins before deploying to production.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, specify allowed origins
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

# Global model instances shared by all request handlers.
# Prefer GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

# Initialize models (load from checkpoints in production).
# NOTE(review): with the load_state_dict calls below commented out, these
# models serve freshly-initialized (untrained) weights.
vision_model = VisionEmotionModel().to(device)
audio_model = AudioEmotionModel().to(device)
text_model = TextIntentModel().to(device)
fusion_model = MultiModalFusion().to(device)

# Load trained weights (placeholder)
# vision_model.load_state_dict(torch.load('models/checkpoints/vision.pth'))
# audio_model.load_state_dict(torch.load('models/checkpoints/audio.pth'))
# text_model.load_state_dict(torch.load('models/checkpoints/text.pth'))
# fusion_model.load_state_dict(torch.load('models/checkpoints/fusion.pth'))

# Switch every model to inference mode (disables dropout / batch-norm updates).
vision_model.eval()
audio_model.eval()
text_model.eval()
fusion_model.eval()

# Human-readable class labels, indexed by model output position.
# NOTE(review): ordering presumably matches each model's training label
# order — confirm against the training code.
emotion_labels = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
intent_labels = ['agreement', 'confusion', 'hesitation', 'confidence', 'neutral']
@app.get("/")
async def root():
    """Return the service banner for the API root endpoint."""
    banner = "EMOTIA Multi-Modal Emotion & Intent Intelligence API"
    return {"message": banner}
@app.post("/analyze/frame")
async def analyze_frame(
    image: UploadFile = File(...),
    audio: Optional[UploadFile] = File(None),
    text: Optional[str] = None
):
    """
    Analyze a single frame with optional audio and text.

    Args:
        image: uploaded image; at least one face must be detectable.
        audio: optional audio clip (decoded at 16 kHz, first 3 s used).
        text: optional transcript/utterance string.

    Returns:
        Dict with per-class emotion and intent probabilities plus the
        dominant label of each, engagement and confidence scores,
        per-modality contribution weights, and wall-clock processing time.

    Raises:
        HTTPException 400: no faces detected in the image.
        HTTPException 500: any other processing failure.
    """
    start_time = time.time()
    try:
        # ---- Vision ----
        image_data = await image.read()
        image_pil = Image.open(io.BytesIO(image_data))
        image_np = np.array(image_pil)

        faces = vision_model.detect_faces(image_np)
        if not faces:
            raise HTTPException(status_code=400, detail="No faces detected in image")

        # Everything below is pure inference: run under no_grad so no
        # autograd graph is built (the original only wrapped the fusion
        # call, wasting memory on the backbone forward passes).
        with torch.no_grad():
            vision_logits, vision_conf = vision_model.extract_features(faces)
            # CLS-token embedding per face, averaged across all detected faces.
            vision_features = vision_model.vit(pixel_values=torch.stack([
                vision_model.transform(face) for face in faces
            ]).to(device)).last_hidden_state[:, 0, :].mean(dim=0)

            # ---- Audio (optional) ----
            audio_features = None
            if audio:
                audio_data = await audio.read()
                audio_np, _ = librosa.load(io.BytesIO(audio_data), sr=16000, duration=3.0)
                audio_tensor = torch.tensor(audio_np, dtype=torch.float32).to(device)
                audio_logits, audio_stress = audio_model(audio_tensor.unsqueeze(0))
                audio_features = audio_model.wav2vec(audio_tensor.unsqueeze(0)).last_hidden_state.mean(dim=1)

            # ---- Text (optional) ----
            text_features = None
            if text:
                input_ids, attention_mask = text_model.preprocess_text(text)
                input_ids = input_ids.to(device).unsqueeze(0)
                attention_mask = attention_mask.to(device).unsqueeze(0)
                intent_logits, sentiment_logits, text_conf = text_model(input_ids, attention_mask)
                text_features = text_model.bert(input_ids, attention_mask).pooler_output

            # Zero vectors stand in for missing modalities.
            # NOTE(review): the dims (128 audio / 768 text) must match the
            # fusion model's expected feature sizes — confirm against
            # MultiModalFusion.
            if audio_features is None:
                audio_features = torch.zeros(1, 128).to(device)
            if text_features is None:
                text_features = torch.zeros(1, 768).to(device)

            # ---- Fusion ----
            results = fusion_model(
                vision_features.unsqueeze(0),
                audio_features,
                text_features
            )

        # Convert raw logits to probabilities and build a JSON-friendly payload.
        emotion_probs = torch.softmax(results['emotion'], dim=1)[0].cpu().numpy()
        intent_probs = torch.softmax(results['intent'], dim=1)[0].cpu().numpy()
        return {
            "emotion": {
                "predictions": {emotion_labels[i]: float(prob) for i, prob in enumerate(emotion_probs)},
                "dominant": emotion_labels[int(np.argmax(emotion_probs))]
            },
            "intent": {
                "predictions": {intent_labels[i]: float(prob) for i, prob in enumerate(intent_probs)},
                "dominant": intent_labels[int(np.argmax(intent_probs))]
            },
            "engagement": float(results['engagement'].cpu().numpy()),
            "confidence": float(results['confidence'].cpu().numpy()),
            "modality_contributions": {
                "vision": float(results['contributions'][0].cpu().numpy()),
                "audio": float(results['contributions'][1].cpu().numpy()),
                "text": float(results['contributions'][2].cpu().numpy())
            },
            "processing_time": time.time() - start_time
        }
    except HTTPException:
        # Re-raise FastAPI errors unchanged (the original blanket handler
        # converted the 400 "no faces" response into a generic 500).
        raise
    except Exception as e:
        # logger.exception records the full traceback, unlike logger.error.
        logger.exception("Error processing frame")
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
@app.post("/analyze/stream")
async def analyze_stream(data: Dict):
    """
    Analyze streaming video/audio/text data.

    Expects base64 encoded frames and audio chunks. This endpoint is a
    stub; a production implementation would consume WebRTC streams.
    """
    return {"message": "Streaming analysis not yet implemented"}
@app.get("/health")
async def health_check():
    """Liveness probe: report service status and the compute device in use."""
    payload = {"status": "healthy", "device": str(device)}
    return payload
if __name__ == "__main__":
    # Development entry point; in production run via an ASGI server command.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)