| """ | |
| FastAPI Backend for Respiratory Symptom Analysis | |
| Updated for 39% F1-Macro Model (4 symptoms, no CBAM) | |
| Deployed on HuggingFace Spaces for use with Netlify frontend | |
| Version: 3.0.0 | |
| """ | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| import torch | |
| import torch.nn as nn | |
| import json | |
| import numpy as np | |
| import tempfile | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Any | |
| import time | |
| import warnings | |
| # Import your preprocessing module | |
| from audio_preprocessing import RespiratoryAudioPreprocessor | |
| warnings.filterwarnings('ignore') | |
# =================== YOUR EXACT MODEL ARCHITECTURE ===================
class LightweightMultiSymptomClassifier(nn.Module):
    """
    Exact model architecture from the 39% F1-Macro training run.
    4 symptoms: fever, cold, fatigue, cough.
    No CBAM, simplified CNN architecture.
    """
    def __init__(self, num_classes=4, dropout=0.5):
        super().__init__()
        self.num_classes = num_classes

        # Convolutional backbone
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        # Shared feature layers
        self.shared_fc = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # One independent binary head per symptom
        self.symptom_heads = nn.ModuleList([
            nn.Linear(128, 1) for _ in range(num_classes)
        ])

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 256)
        shared_features = self.shared_fc(x)
        outputs = []
        for head in self.symptom_heads:
            outputs.append(head(shared_features))
        logits = torch.cat(outputs, dim=1)
        return logits
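# Quick shape sanity check (a minimal sketch; the 64x64 spectrogram size below
# is an illustrative assumption, since AdaptiveAvgPool2d makes the backbone
# agnostic to the exact input resolution):
#
#   model = LightweightMultiSymptomClassifier()
#   dummy = torch.randn(2, 1, 64, 64)    # (batch, channels, mel_bins, frames)
#   assert model(dummy).shape == (2, 4)  # one logit per symptom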
class OptimizedInferenceModel(nn.Module):
    """
    Inference wrapper that applies per-symptom confidence thresholds
    """
    def __init__(self, base_model, target_symptoms, confidence_thresholds):
        super().__init__()
        self.base_model = base_model
        self.target_symptoms = target_symptoms
        # Register thresholds as a buffer so they travel with the module
        self.register_buffer(
            'threshold_tensor',
            torch.tensor([confidence_thresholds[symptom]
                          for symptom in target_symptoms], dtype=torch.float32)
        )

    def forward(self, x):
        # Get logits from the base model
        logits = self.base_model(x)
        # Independent sigmoid per symptom (multi-label, not softmax)
        probs = torch.sigmoid(logits)
        # Apply custom per-symptom thresholds
        preds = (probs >= self.threshold_tensor).float()
        return {
            'probabilities': probs,
            'predictions': preds,
            'logits': logits
        }
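# Minimal usage sketch (hypothetical thresholds, for illustration only):
#
#   wrapper = OptimizedInferenceModel(
#       LightweightMultiSymptomClassifier(),
#       ['fever', 'cold', 'fatigue', 'cough'],
#       {'fever': 0.5, 'cold': 0.5, 'fatigue': 0.5, 'cough': 0.5},
#   )
#   wrapper.eval()
#   with torch.no_grad():
#       out = wrapper(torch.randn(1, 1, 64, 64))
#   # out['probabilities'] holds sigmoid scores; out['predictions'] is 0/1 per symptom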
# Initialize FastAPI app
app = FastAPI(
    title="🫁 Respiratory Symptom Analysis API v3.0",
    description="AI-powered respiratory symptom detection (39% F1-Macro model)",
    version="3.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
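# Note: per the CORS spec, browsers reject credentialed responses that carry
# `Access-Control-Allow-Origin: *`, so the wildcard above may not work if the
# frontend ever sends cookies or auth headers. In that case, list the Netlify
# origin explicitly instead (hypothetical URL):
#
#   allow_origins=["https://your-app.netlify.app"]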
class RespiratoryAnalysisService:
    """
    Service class for respiratory symptom analysis with the 39% F1-Macro model
    """
    def __init__(self, model_dir: str = "deployment_model"):
        """Initialize the service with model and configuration"""
        self.model_dir = Path(model_dir)
        self.model = None
        self.config = None
        self.preprocessor = None
        self.weights_loaded = False
        self.neutral_threshold = 0.35

        # Load configuration and model
        self.load_config()
        self.create_and_load_model()
        self.setup_preprocessor()

    def load_config(self):
        """Load configuration from disk, falling back to defaults"""
        config_path = self.model_dir / "model_config.json"
        try:
            if config_path.exists():
                with open(config_path, 'r') as f:
                    self.config = json.load(f)
                if 'symptom_colors' not in self.config:
                    self.config['symptom_colors'] = {
                        'fever': '#FF6B6B',
                        'cold': '#4ECDC4',
                        'fatigue': '#FFEAA7',
                        'cough': '#DDA0DD'
                    }
                    print("⚠️ Added missing symptom_colors to config")
                print(f"✅ Configuration loaded from {config_path}")
            else:
                # Default configuration for the 4-symptom model
                self.config = {
                    'target_symptoms': ['fever', 'cold', 'fatigue', 'cough'],
                    'symptom_display_names': {
                        'fever': 'Fever',
                        'cold': 'Cold/Runny Nose',
                        'fatigue': 'Fatigue',
                        'cough': 'Persistent Cough'
                    },
                    'confidence_thresholds': {
                        'fever': 0.5,
                        'cold': 0.5,
                        'fatigue': 0.5,
                        'cough': 0.5
                    },
                    'symptom_colors': {
                        'fever': '#FF6B6B',
                        'cold': '#4ECDC4',
                        'fatigue': '#FFEAA7',
                        'cough': '#DDA0DD'
                    },
                    'model_version': '3.0_39percent_f1',
                    'num_classes': 4,
                    'dropout': 0.5
                }
                print("⚠️ Using default configuration")
        except Exception as e:
            raise RuntimeError(f"Failed to load config: {str(e)}")
    def create_and_load_model(self):
        """Create the model and load trained weights"""
        try:
            # Create base model
            base_model = LightweightMultiSymptomClassifier(
                num_classes=self.config['num_classes'],
                dropout=self.config['dropout']
            )

            print("🔍 Searching for model weight files...")

            # Priority order for loading weights
            weight_files_to_try = [
                (self.model_dir / "model_state_dict.pt", "Base Model State Dict"),
                (self.model_dir / "model_quantized_state_dict.pt", "Quantized State Dict"),
                (self.model_dir / "best_model.pt", "Best Checkpoint"),
            ]

            for weight_file, model_type in weight_files_to_try:
                if weight_file.exists():
                    file_size = weight_file.stat().st_size / (1024 * 1024)
                    print(f"📁 Found {model_type}: {weight_file} ({file_size:.1f}MB)")
                    try:
                        # weights_only=False is needed for full checkpoint dicts;
                        # only load checkpoints from trusted sources.
                        checkpoint = torch.load(weight_file, map_location='cpu', weights_only=False)

                        # Handle different checkpoint formats
                        if isinstance(checkpoint, dict):
                            if 'model_state_dict' in checkpoint:
                                state_dict = checkpoint['model_state_dict']
                            elif 'state_dict' in checkpoint:
                                state_dict = checkpoint['state_dict']
                            else:
                                # Assume it's a plain state dict
                                state_dict = checkpoint
                        else:
                            print("⚠️ Unexpected checkpoint format, skipping...")
                            continue

                        # Load state dict (strict=False tolerates partial matches)
                        missing, unexpected = base_model.load_state_dict(state_dict, strict=False)
                        total_keys = len(base_model.state_dict())
                        loaded_keys = total_keys - len(missing)
                        load_percentage = (loaded_keys / total_keys) * 100

                        print(f"   📊 Loaded {loaded_keys}/{total_keys} parameters ({load_percentage:.1f}%)")
                        if missing:
                            print(f"   ⚠️ Missing keys: {len(missing)}")
                        if unexpected:
                            print(f"   ⚠️ Unexpected keys: {len(unexpected)}")

                        if load_percentage > 90:  # Require at least a 90% match
                            self.weights_loaded = True
                            print(f"✅ Successfully loaded {model_type}")
                            break
                        else:
                            print(f"⚠️ Only {load_percentage:.1f}% loaded, trying next file...")
                    except Exception as e:
                        print(f"⚠️ Failed to load {model_type}: {str(e)}")
                        continue

            if not self.weights_loaded:
                print("\n❌ WARNING: Using random model weights!")
                print("❌ All predictions will be random")
                print(f"❌ Expected model files in: {self.model_dir}/")
                print("❌ Required files:")
                print("   - model_state_dict.pt (recommended)")
                print("   - model_quantized_state_dict.pt (alternative)")
                print("   - best_model.pt (alternative)")
            else:
                print("✅ Model ready with trained weights")

            # Wrap in inference model with thresholds
            self.model = OptimizedInferenceModel(
                base_model,
                self.config['target_symptoms'],
                self.config['confidence_thresholds']
            )
            self.model.eval()

            # CPU optimization
            torch.set_num_threads(4)
        except Exception as e:
            raise RuntimeError(f"Failed to create/load model: {str(e)}")
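    # How the expected weight file would typically be produced at the end of
    # training (a sketch; `trained_model` is assumed to be the fitted
    # LightweightMultiSymptomClassifier and is not defined in this service):
    #
    #   torch.save(trained_model.state_dict(),
    #              "deployment_model/model_state_dict.pt")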
    def setup_preprocessor(self):
        """Initialize the audio preprocessor"""
        self.preprocessor = RespiratoryAudioPreprocessor()
        print("✅ Audio preprocessor initialized")
    def predict_symptoms(self, audio_file_path: str) -> Dict[str, Any]:
        """Predict respiratory symptoms from an audio file"""
        try:
            start_time = time.time()

            # Preprocess audio into a model-ready tensor
            tensor_input = self.preprocessor.preprocess_audio(audio_file_path)
            preprocessing_time = time.time() - start_time

            # Run inference
            inference_start = time.time()
            with torch.no_grad():
                outputs = self.model(tensor_input)
            inference_time = time.time() - inference_start

            # Parse outputs and convert numpy types to plain Python floats
            probabilities = outputs['probabilities'].squeeze().detach().cpu().numpy()
            probabilities = probabilities.astype(float).tolist()

            # Detect symptoms: a symptom counts as detected when its probability
            # clears both its own threshold and the global neutral threshold
            detected_symptoms = []
            for i, symptom in enumerate(self.config['target_symptoms']):
                prob = float(probabilities[i])
                threshold = float(self.config['confidence_thresholds'][symptom])
                effective_threshold = max(threshold, self.neutral_threshold)
                if prob >= effective_threshold:
                    detected_symptoms.append({
                        'symptom': symptom,
                        'display_name': self.config['symptom_display_names'][symptom],
                        'confidence': prob,
                        'color': self.config['symptom_colors'][symptom],
                        'threshold_used': effective_threshold
                    })

            # Determine health status
            max_confidence = max(probabilities)
            if not detected_symptoms:
                if max_confidence < self.neutral_threshold:
                    health_status = "healthy"
                    status_message = "No symptoms detected - appears healthy"
                else:
                    health_status = "inconclusive"
                    status_message = "Some patterns detected but below confidence threshold"
            else:
                health_status = "symptoms_detected"
                status_message = f"{len(detected_symptoms)} symptom(s) detected"

            # Format results
            results = {
                'detected_symptoms': detected_symptoms,
                'all_symptoms': {},
                'summary': {
                    'total_detected': len(detected_symptoms),
                    'highest_confidence': max([s['confidence'] for s in detected_symptoms], default=0.0),
                    'max_overall_confidence': float(max_confidence),
                    'status': health_status,
                    'status_message': status_message,
                    'neutral_threshold': float(self.neutral_threshold),
                    'weights_status': 'trained' if self.weights_loaded else 'random'
                },
                'recommendations': self._get_recommendations(health_status, detected_symptoms),
                'health_classification': health_status,
                'processing_info': {
                    'preprocessing_time_ms': round(preprocessing_time * 1000, 1),
                    'inference_time_ms': round(inference_time * 1000, 1),
                    'total_time_ms': round((preprocessing_time + inference_time) * 1000, 1),
                    'model_weights_loaded': self.weights_loaded,
                    'model_version': '3.0_39percent_f1'
                }
            }

            # Add per-symptom details
            for i, symptom in enumerate(self.config['target_symptoms']):
                prob = float(probabilities[i])
                threshold = float(self.config['confidence_thresholds'][symptom])
                effective_threshold = max(threshold, self.neutral_threshold)
                results['all_symptoms'][symptom] = {
                    'display_name': self.config['symptom_display_names'][symptom],
                    'confidence': prob,
                    'detected': prob >= effective_threshold,
                    'original_threshold': threshold,
                    'effective_threshold': effective_threshold,
                    'color': self.config['symptom_colors'][symptom]
                }
            return results
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
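    # Shape of the returned dict (abridged; values shown are illustrative):
    #
    #   {
    #     "detected_symptoms": [{"symptom": "cough", "confidence": 0.71, ...}],
    #     "all_symptoms": {"fever": {...}, "cold": {...}, ...},
    #     "summary": {"total_detected": 1, "status": "symptoms_detected", ...},
    #     "recommendations": ["..."],
    #     "health_classification": "symptoms_detected",
    #     "processing_info": {"total_time_ms": 142.0, ...}
    #   }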
    def _get_recommendations(self, health_status, detected_symptoms):
        """Generate recommendations based on health status"""
        recommendations = []
        if not self.weights_loaded:
            recommendations.append("⚠️ DEVELOPMENT MODE: Model using random weights - results not valid")
        if health_status == "healthy":
            recommendations.extend([
                "✅ No significant respiratory symptoms detected",
                "Your cough patterns appear normal and healthy",
                "Continue maintaining good respiratory health practices",
                "This screening is for informational purposes only"
            ])
        elif health_status == "inconclusive":
            recommendations.extend([
                "⚠️ Some respiratory patterns detected but below confidence threshold",
                "Consider monitoring your symptoms over the next few days",
                "If symptoms persist or worsen, consult a healthcare provider",
                "This AI screening should not replace professional medical advice"
            ])
        elif len(detected_symptoms) == 1:
            symptom_name = detected_symptoms[0]['display_name']
            confidence = detected_symptoms[0]['confidence']
            recommendations.extend([
                f"🔍 Detected: {symptom_name} (confidence: {confidence:.1%})",
                "Monitor this symptom and note any changes",
                "Consider consulting a healthcare provider if symptoms persist",
                "This AI screening should not replace professional medical advice"
            ])
        else:
            symptom_names = [s['display_name'] for s in detected_symptoms]
            recommendations.extend([
                f"🚨 Multiple symptoms detected: {', '.join(symptom_names)}",
                "Multiple symptoms may indicate a need for medical attention",
                "Please consult a healthcare provider for proper evaluation",
                "This AI screening should not replace professional medical advice"
            ])
        return recommendations
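# Local smoke test for the service layer (a sketch; assumes trained weights
# under deployment_model/ and a short recording at the hypothetical path below):
#
#   svc = RespiratoryAnalysisService()
#   report = svc.predict_symptoms("samples/cough.wav")
#   print(report['summary']['status'], report['summary']['status_message'])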
# Initialize service at import time so the Space is ready to serve requests
print("🚀 Initializing Respiratory Analysis Service v3.0...")
try:
    service = RespiratoryAnalysisService()
    print("✅ Service initialized successfully!")
    print("   Model: 39% F1-Macro (4 symptoms)")
    print(f"   Weights loaded: {'Yes' if service.weights_loaded else 'No'}")
except Exception as e:
    print(f"❌ Service initialization failed: {str(e)}")
    service = None
# =================== API ROUTES ===================
@app.get("/")
async def root():
    """Root endpoint"""
    if service is None:
        return {"service": "Respiratory Symptom Analysis API", "version": "3.0.0", "status": "error"}
    return {
        "service": "Respiratory Symptom Analysis API",
        "version": "3.0.0",
        "model_version": "39% F1-Macro (4 symptoms)",
        "status": "active",
        "model_status": "trained_weights" if service.weights_loaded else "random_weights",
        "supported_symptoms": service.config['target_symptoms'],
        "endpoints": {
            "analyze": "/analyze",
            "health": "/health",
            "info": "/info",
            "docs": "/docs"
        }
    }
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Report the same files create_and_load_model() actually looks for
    model_dir = Path("deployment_model")
    model_files_status = {
        "model_state_dict": (model_dir / "model_state_dict.pt").exists(),
        "model_quantized_state_dict": (model_dir / "model_quantized_state_dict.pt").exists(),
        "best_model": (model_dir / "best_model.pt").exists(),
        "config": (model_dir / "model_config.json").exists()
    }
    return {
        "status": "healthy" if service is not None else "unhealthy",
        "timestamp": time.time(),
        "service_ready": service is not None,
        "model_loaded": service.model is not None if service else False,
        "model_weights_status": "trained" if (service and service.weights_loaded) else "random",
        "model_files_available": model_files_status,
        "api_version": "3.0.0"
    }
@app.get("/info")
async def get_info():
    """Get model information"""
    if service is None:
        return {"error": "Service not initialized"}
    return {
        "model_info": {
            "version": "3.0_39percent_f1",
            "architecture": "LightweightMultiSymptomClassifier (no CBAM)",
            "target_symptoms": service.config['target_symptoms'],
            "symptom_display_names": service.config['symptom_display_names'],
            "confidence_thresholds": service.config['confidence_thresholds'],
            "weights_loaded": service.weights_loaded,
            "neutral_threshold": service.neutral_threshold
        },
        "preprocessing_info": service.preprocessor.get_preprocessing_info(),
        "supported_formats": ["wav", "mp3", "flac", "ogg", "m4a", "webm"],
        "max_duration": "30 seconds",
        "max_file_size": "10MB",
        "api_version": "3.0.0"
    }
@app.post("/analyze")
async def analyze_audio(audio_file: UploadFile = File(...)):
    """
    Analyze an uploaded audio file for respiratory symptoms.
    Returns detected symptoms with confidence scores and a health classification.
    """
    if service is None:
        raise HTTPException(status_code=503, detail="Service not available")

    # Validate file type
    allowed_types = ['audio/wav', 'audio/mpeg', 'audio/mp3', 'audio/flac',
                     'audio/ogg', 'audio/x-m4a', 'audio/mp4', 'audio/webm']
    if audio_file.content_type not in allowed_types:
        raise HTTPException(status_code=400,
                            detail=f"Unsupported format: {audio_file.content_type}")

    # Validate file size
    content = await audio_file.read()
    if len(content) > 10 * 1024 * 1024:  # 10MB
        raise HTTPException(status_code=400, detail="File too large. Maximum: 10MB")

    # Save the upload to a temporary file, analyze it, and always clean up
    file_extension = audio_file.filename.split('.')[-1] if audio_file.filename else 'wav'
    with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_extension}") as temp_file:
        temp_file.write(content)
        temp_file_path = temp_file.name
    try:
        results = service.predict_symptoms(temp_file_path)
        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "data": results,
                "metadata": {
                    "filename": audio_file.filename,
                    "file_size_bytes": len(content),
                    "content_type": audio_file.content_type,
                    "timestamp": time.time(),
                    "api_version": "3.0.0"
                }
            }
        )
    except HTTPException:
        # Preserve the detail raised by predict_symptoms
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
    finally:
        try:
            os.unlink(temp_file_path)
        except OSError:
            pass
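# Example request against a running instance (hypothetical host and file path):
#
#   curl -X POST "http://localhost:7860/analyze" \
#        -F "audio_file=@samples/cough.wav;type=audio/wav"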
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
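# Alternatively, start the server from a shell (assuming this file is saved as
# main.py, which the "main:app" import string above implies):
#
#   uvicorn main:app --host 0.0.0.0 --port 7860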