# EMOTIA/backend/advanced/advanced_api.py
# Uploaded via huggingface_hub (commit 25d0747, verified).
import asyncio
import hashlib
import json
import logging
import os
import sys
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import cv2
import numpy as np
import prometheus_client as prom
import redis
import torch
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from prometheus_client import Counter, Histogram, Gauge

# Add parent directory to path for model imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
from models.advanced.advanced_fusion import AdvancedMultiModalFusion
from models.advanced.data_augmentation import AdvancedPreprocessingPipeline
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Prometheus metrics.
# Requests are counted per endpoint and outcome (started/cached/success/error).
REQUEST_COUNT = Counter('emotia_requests_total', 'Total requests', ['endpoint', 'status'])
INFERENCE_TIME = Histogram('emotia_inference_duration_seconds', 'Inference time', ['model'])
ACTIVE_CONNECTIONS = Gauge('emotia_active_websocket_connections', 'Active WebSocket connections')
# NOTE(review): 'accuracy' is used as a label value, so the client stringifies
# it; consider exposing accuracy as the gauge *value* instead of a label.
MODEL_VERSIONS = Gauge('emotia_model_versions', 'Model version info', ['version', 'accuracy'])

app = FastAPI(title="EMOTIA Advanced API", version="2.0.0")

# CORS middleware
# NOTE(review): wildcard origins together with allow_credentials=True is
# rejected by browsers for credentialed requests — tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global components
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

# Model registry for versioning: version string -> model/accuracy/metadata dict
# (populated by load_models()).
model_registry = {}
current_model_version = "v2.0.0"

# Redis for caching and session management.
# NOTE(review): host/port are hard-coded; consider environment configuration.
redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)

# Thread pool for async processing (used to keep model work off the event loop)
executor = ThreadPoolExecutor(max_workers=4)
# WebSocket connection manager
class ConnectionManager:
    """Registry of live WebSocket sessions plus per-session statistics."""

    def __init__(self):
        # session_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # session_id -> {'start_time', 'frames_processed', 'last_activity'}
        self.session_data: Dict[str, Dict] = {}

    async def connect(self, websocket: WebSocket, session_id: str):
        """Accept the handshake and register the session's bookkeeping."""
        await websocket.accept()
        self.active_connections[session_id] = websocket
        self.session_data[session_id] = {
            'start_time': time.time(),
            'frames_processed': 0,
            'last_activity': time.time(),
        }
        ACTIVE_CONNECTIONS.inc()
        logger.info(f"WebSocket connected: {session_id}")

    def disconnect(self, session_id: str):
        """Forget a session; a no-op if it was never (or already) removed."""
        if session_id not in self.active_connections:
            return
        del self.active_connections[session_id]
        del self.session_data[session_id]
        ACTIVE_CONNECTIONS.dec()
        logger.info(f"WebSocket disconnected: {session_id}")

    async def send_personal_message(self, message: str, session_id: str):
        """Deliver a text frame to one session, silently skipping unknown ids."""
        connection = self.active_connections.get(session_id)
        if connection is not None:
            await connection.send_text(message)

    async def broadcast(self, message: str):
        """Send a text frame to every connected client, one at a time."""
        for connection in self.active_connections.values():
            await connection.send_text(message)
# Process-wide singleton shared by every WebSocket handler.
manager = ConnectionManager()
# Load models
def load_models():
    """Instantiate the fusion model and register it under the current version."""
    global model_registry
    # Build the advanced fusion network on the configured device.
    fusion_model = AdvancedMultiModalFusion().to(device)
    # In production, restore trained weights from a checkpoint:
    # fusion_model.load_state_dict(torch.load('models/checkpoints/advanced_fusion.pth'))
    fusion_model.eval()
    entry = {
        'model': fusion_model,
        'accuracy': 0.85,  # Placeholder
        'created_at': time.time(),
        'preprocessing': AdvancedPreprocessingPipeline(),
    }
    model_registry[current_model_version] = entry
    MODEL_VERSIONS.labels(version=current_model_version, accuracy=0.85).set(1)
    logger.info(f"Loaded model version: {current_model_version}")
# Eager load at import time so the module is usable without the ASGI
# lifecycle (e.g. when imported directly by scripts or tests).
load_models()

@app.on_event("startup")
async def startup_event():
    """Initialize services on startup.

    The models were already loaded at import time above; skip reloading so a
    second copy of the network is not instantiated on every server start.
    """
    if not model_registry:
        load_models()
@app.get("/")
async def root():
    """Service index: identifies the API and lists the available endpoints."""
    endpoints = [
        "/analyze/frame",
        "/analyze/stream",
        "/ws/analyze/{session_id}",
        "/models/versions",
        "/health",
        "/metrics",
    ]
    return {
        "message": "EMOTIA Advanced Multi-Modal Emotion & Intent Intelligence API v2.0",
        "version": current_model_version,
        "endpoints": endpoints,
    }
@app.get("/models/versions")
async def get_model_versions():
    """Report every registered model version with its accuracy and load time."""
    return {
        version: {
            'accuracy': info['accuracy'],
            'created_at': info['created_at'],
        }
        for version, info in model_registry.items()
    }
@app.post("/analyze/frame")
async def analyze_frame(
    image_data: bytes = None,
    audio_data: bytes = None,
    text: str = None,
    model_version: str = None
):
    """Advanced single-frame multi-modal analysis with caching and metrics.

    Args:
        image_data: Encoded image bytes (any cv2-decodable format). Optional.
        audio_data: Raw audio sample bytes. Optional.
        text: Utterance text to analyze. Optional.
        model_version: Registry version to use. Defaults to the currently
            deployed version, resolved per request so /models/deploy takes
            effect immediately (a def-time default would freeze the version
            captured at import time).

    Returns:
        Dict with emotion/intent probability lists and dominant classes,
        engagement/confidence means with uncertainty, per-modality
        importance, processing time, and the model version used.

    Raises:
        HTTPException: 400 for an unknown model version; 500 on any other
            processing failure.
    """
    start_time = time.time()
    REQUEST_COUNT.labels(endpoint='/analyze/frame', status='started').inc()
    try:
        if model_version is None:
            model_version = current_model_version
        if model_version not in model_registry:
            raise HTTPException(status_code=400, detail=f"Model version {model_version} not found")
        model_info = model_registry[model_version]
        model = model_info['model']
        preprocessor = model_info['preprocessing']

        # Content-addressed cache key. Use sha256, not builtin hash():
        # hash() of bytes/str is salted per process, which would make cached
        # entries unreachable across workers and restarts.
        digest = hashlib.sha256()
        digest.update(image_data or b'')
        digest.update(b'\x00')
        digest.update(audio_data or b'')
        digest.update(b'\x00')
        digest.update((text or '').encode('utf-8'))
        cache_key = f"{digest.hexdigest()}:{model_version}"
        cached_result = redis_client.get(cache_key)
        if cached_result:
            REQUEST_COUNT.labels(endpoint='/analyze/frame', status='cached').inc()
            return json.loads(cached_result)

        # Process inputs — any modality may be absent; the fusion model
        # receives None for missing ones.
        vision_input = None
        audio_input = None
        text_input = None
        if image_data:
            # Decode and preprocess image
            nparr = np.frombuffer(image_data, np.uint8)
            image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            vision_input = preprocessor.preprocess_face(image)
            if vision_input is not None:
                vision_input = vision_input.unsqueeze(0).to(device)
        if audio_data:
            # NOTE(review): assumes the payload is raw little-endian float32
            # samples — confirm with the client encoding.
            audio_np = np.frombuffer(audio_data, dtype=np.float32)
            audio_input = preprocessor.preprocess_audio(audio_np)
            if audio_input is not None:
                audio_input = audio_input.unsqueeze(0).to(device)
        if text:
            tokenizer = getattr(model, 'clip_tokenizer', None)
            text_input = preprocessor.preprocess_text(text, tokenizer)
            # Guard: the preprocessor may decline the input, mirroring the
            # None checks for the other modalities.
            if text_input is not None:
                text_input = {k: v.to(device) for k, v in text_input.items()}

        # Run inference under no_grad, timing it per model version.
        with torch.no_grad():
            with INFERENCE_TIME.labels(model=model_version).time():
                outputs = model(
                    vision_input=vision_input,
                    audio_input=audio_input,
                    text_input=text_input
                )

        # Flatten tensors into JSON-serializable Python values.
        result = {
            'emotion': {
                'probabilities': torch.softmax(outputs['emotion_logits'], dim=1)[0].cpu().numpy().tolist(),
                'dominant': torch.argmax(outputs['emotion_logits'], dim=1)[0].item()
            },
            'intent': {
                'probabilities': torch.softmax(outputs['intent_logits'], dim=1)[0].cpu().numpy().tolist(),
                'dominant': torch.argmax(outputs['intent_logits'], dim=1)[0].item()
            },
            'engagement': {
                'mean': outputs['engagement_mean'][0].item(),
                'uncertainty': outputs['engagement_var'][0].item()
            },
            'confidence': {
                'mean': outputs['confidence_mean'][0].item(),
                'uncertainty': outputs['confidence_var'][0].item()
            },
            'modality_importance': outputs['modality_importance'][0].cpu().numpy().tolist(),
            'processing_time': time.time() - start_time,
            'model_version': model_version
        }

        redis_client.setex(cache_key, 3600, json.dumps(result))  # Cache for 1 hour
        REQUEST_COUNT.labels(endpoint='/analyze/frame', status='success').inc()
        return result
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) rather than
        # collapsing them into a generic 500.
        REQUEST_COUNT.labels(endpoint='/analyze/frame', status='error').inc()
        raise
    except Exception as e:
        REQUEST_COUNT.labels(endpoint='/analyze/frame', status='error').inc()
        logger.error(f"Analysis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
@app.websocket("/ws/analyze/{session_id}")
async def websocket_analyze(websocket: WebSocket, session_id: str):
    """Real-time streaming analysis via WebSocket.

    Loop: receive a JSON payload, run process_streaming_data in the shared
    thread pool (keeping model work off the event loop), send the JSON result
    back, and update the session's frame/activity counters. The session is
    unregistered from the manager on disconnect or error.
    """
    await manager.connect(websocket, session_id)
    try:
        while True:
            # Receive one JSON message from the client.
            data = await websocket.receive_json()
            # Offload processing to the thread pool so the event loop stays free.
            loop = asyncio.get_event_loop()
            result = await loop.run_in_executor(
                executor,
                process_streaming_data,
                data,
                session_id
            )
            # Send result back to this session only.
            await manager.send_personal_message(json.dumps(result), session_id)
            # Update session stats (connect() created this entry, so the
            # key exists while the loop runs).
            manager.session_data[session_id]['frames_processed'] += 1
            manager.session_data[session_id]['last_activity'] = time.time()
    except WebSocketDisconnect:
        manager.disconnect(session_id)
    except Exception as e:
        # Best-effort error report to the client before tearing down; the
        # send itself may fail if the socket is already gone.
        logger.error(f"WebSocket error for {session_id}: {str(e)}")
        await manager.send_personal_message(json.dumps({"error": str(e)}), session_id)
        manager.disconnect(session_id)
def process_streaming_data(data, session_id):
    """Process one streaming payload in a background thread.

    Placeholder implementation: returns static emotion/engagement values.
    A real implementation would run the currently deployed model from
    model_registry over the decoded payload, as /analyze/frame does.

    Args:
        data: JSON payload received from the WebSocket client (unused for now).
        session_id: Identifier of the originating WebSocket session.

    Returns:
        Dict with the session id, a wall-clock timestamp, and placeholder
        emotion/engagement scores.
    """
    # The original fetched the model from the registry but never used it;
    # that dead lookup is dropped until real inference is wired in.
    return {
        'session_id': session_id,
        'timestamp': time.time(),
        'emotion': {'dominant': 0},  # Placeholder
        'engagement': 0.5
    }
@app.get("/health")
async def health_check():
    """Advanced health check with system metrics.

    Returns a status document covering the deployed model version, compute
    device, live WebSocket count, known model versions and Redis
    connectivity. Redis failures are reported as redis_connected=False
    instead of letting ping() raise and turn the health probe into a 500.
    """
    try:
        redis_ok = bool(redis_client.ping()) if redis_client else False
    except (redis.RedisError, ConnectionError, OSError):
        # Redis being down makes the service degraded, not the probe broken.
        redis_ok = False
    return {
        "status": "healthy",
        "version": current_model_version,
        "device": str(device),
        "active_connections": len(manager.active_connections),
        "model_versions": list(model_registry.keys()),
        "redis_connected": redis_ok,
        "timestamp": time.time()
    }
@app.get("/metrics")
async def metrics():
    """Prometheus metrics endpoint.

    generate_latest() returns the complete exposition payload as bytes.
    StreamingResponse iterates its content, and iterating raw bytes yields
    ints (which cannot be sent), so the payload is wrapped in a one-element
    iterator. The canonical Prometheus content type replaces bare
    "text/plain" so scrapers parse the exposition format correctly.
    """
    payload = prom.generate_latest()
    return StreamingResponse(
        iter([payload]),
        media_type=prom.CONTENT_TYPE_LATEST
    )
@app.post("/models/deploy/{version}")
async def deploy_model(version: str, background_tasks: BackgroundTasks):
    """Switch the active model to an already-registered version (admin endpoint)."""
    global current_model_version
    if version not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model version {version} not found")
    current_model_version = version
    # Refresh the Prometheus gauge outside the request path.
    background_tasks.add_task(update_model_metrics, version)
    return {"message": f"Deployed model version {version}"}
def update_model_metrics(version):
    """Publish the Prometheus version gauge for a newly deployed model."""
    accuracy = model_registry[version]['accuracy']
    MODEL_VERSIONS.labels(version=version, accuracy=accuracy).set(1)
if __name__ == "__main__":
    import uvicorn
    # uvicorn requires the application as an import string (not an app
    # object) to spawn multiple worker processes; with an object and
    # workers > 1 it refuses to start.
    # NOTE(review): "advanced_api" assumes this file is run from its own
    # directory — confirm the module path in deployment.
    uvicorn.run(
        "advanced_api:app",
        host="0.0.0.0",
        port=8000,
        workers=4,  # Multiple workers for better performance
        loop="uvloop"  # Faster event loop (requires the uvloop package)
    )