"""
FastAPI Backend for Respiratory Symptom Analysis
Updated for 39% F1-Macro Model (4 symptoms, no CBAM)
Deployed on HuggingFace Spaces for use with Netlify frontend
Version: 3.0.0
"""

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import torch
import torch.nn as nn
import json
import numpy as np
import tempfile
import os
from pathlib import Path
from typing import Dict, List, Any
import time
import warnings

# Import your preprocessing module
from audio_preprocessing import RespiratoryAudioPreprocessor

warnings.filterwarnings('ignore')

# =================== CONSTANTS ===================

API_VERSION = "3.0.0"
MODEL_VERSION = "3.0_39percent_f1"
DEFAULT_MODEL_DIR = "deployment_model"

# UI colours per symptom, used when the config file does not provide them.
DEFAULT_SYMPTOM_COLORS = {
    'fever': '#FF6B6B',
    'cold': '#4ECDC4',
    'fatigue': '#FFEAA7',
    'cough': '#DDA0DD',
}

# Weight files tried in priority order: (file name inside the model dir, label for logs).
WEIGHT_FILE_CANDIDATES = (
    ("model_state_dict.pt", "Base Model State Dict"),
    ("model_quantized_state_dict.pt", "Quantized State Dict"),
    ("best_model.pt", "Best Checkpoint"),
)

# A checkpoint is accepted only if it fills more than this percentage of the model's keys.
MIN_LOAD_PERCENTAGE = 90

MAX_UPLOAD_BYTES = 10 * 1024 * 1024  # 10MB

# Accepted upload MIME types (parameters stripped) -> extension used for the temporary
# file when the client-supplied filename has no usable extension.
CONTENT_TYPE_EXTENSIONS = {
    'audio/wav': '.wav',
    'audio/x-wav': '.wav',
    'audio/wave': '.wav',
    'audio/vnd.wave': '.wav',
    'audio/mpeg': '.mp3',
    'audio/mp3': '.mp3',
    'audio/flac': '.flac',
    'audio/x-flac': '.flac',
    'audio/ogg': '.ogg',
    'audio/x-m4a': '.m4a',
    'audio/mp4': '.m4a',
    'audio/webm': '.webm',
}


# =================== YOUR EXACT MODEL ARCHITECTURE ===================

class LightweightMultiSymptomClassifier(nn.Module):
    """
    Exact model architecture from the 39% F1-Macro training run.

    4 symptoms: fever, cold, fatigue, cough. No CBAM, simplified CNN architecture.
    The submodule attribute names (conv1..conv4, shared_fc, symptom_heads) define the
    state-dict keys that saved weights are matched against, so they must not change.

    Args:
        num_classes: number of independent symptom heads (one logit each).
        dropout: dropout probability used in the shared fully connected layers.
    """

    def __init__(self, num_classes=4, dropout=0.5):
        super().__init__()
        self.num_classes = num_classes

        # Convolutional backbone: four double-conv stages, 1 -> 256 channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # Global average pooling makes the head independent of the input H/W
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        # Shared feature layer
        self.shared_fc = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Individual symptom heads (one binary logit per symptom)
        self.symptom_heads = nn.ModuleList([
            nn.Linear(128, 1) for _ in range(num_classes)
        ])

    def forward(self, x):
        """
        Return raw logits of shape (batch, num_classes).

        Assumes x is a batched single-channel tensor, presumably (batch, 1, H, W);
        TODO confirm against RespiratoryAudioPreprocessor's output.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(x.size(0), -1)  # (batch, 256, 1, 1) -> (batch, 256)
        shared_features = self.shared_fc(x)
        # Each head yields (batch, 1); concatenate into (batch, num_classes)
        outputs = [head(shared_features) for head in self.symptom_heads]
        return torch.cat(outputs, dim=1)


class OptimizedInferenceModel(nn.Module):
    """
    Inference wrapper that turns logits into probabilities and thresholded predictions.

    Args:
        base_model: the underlying classifier producing (batch, num_classes) logits.
        target_symptoms: symptom names, in the same order as the model's outputs.
        confidence_thresholds: mapping symptom name -> probability threshold.
    """

    def __init__(self, base_model, target_symptoms, confidence_thresholds):
        super().__init__()
        self.base_model = base_model
        self.target_symptoms = target_symptoms
        # Registered as a buffer so it follows the module across devices and into state_dict()
        self.register_buffer(
            'threshold_tensor',
            torch.tensor(
                [confidence_thresholds[symptom] for symptom in target_symptoms],
                dtype=torch.float32,
            ),
        )

    def forward(self, x):
        """Return a dict with 'probabilities', 'predictions' (0/1 floats) and 'logits'."""
        logits = self.base_model(x)
        probs = torch.sigmoid(logits)
        preds = (probs >= self.threshold_tensor).float()
        return {
            'probabilities': probs,
            'predictions': preds,
            'logits': logits,
        }


# Initialize FastAPI app
app = FastAPI(
    title="🫁 Respiratory Symptom Analysis API v3.0",
    description="AI-powered respiratory symptom detection (39% F1-Macro model)",
    version=API_VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
)

# Add CORS middleware.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any site make
# credentialed requests; consider restricting allow_origins to the Netlify domain.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class RespiratoryAnalysisService:
    """
    Loads the configuration, the model weights and the audio preprocessor, and runs
    symptom prediction for uploaded audio files.
    """

    def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
        """
        Initialize the service with model and configuration.

        Args:
            model_dir: directory holding model_config.json and the weight files.

        Raises:
            RuntimeError: if the configuration or the model cannot be set up.
        """
        self.model_dir = Path(model_dir)
        self.model = None
        self.config = None
        self.preprocessor = None
        self.weights_loaded = False
        # Probability floor applied on top of every per-symptom threshold
        self.neutral_threshold = 0.35

        self.load_config()
        self.create_and_load_model()
        self.setup_preprocessor()

    def load_config(self):
        """
        Load model_config.json from the model dir, or fall back to a built-in
        4-symptom default configuration if the file does not exist.

        Missing 'symptom_colors', 'num_classes' and 'dropout' entries are filled in.

        Raises:
            RuntimeError: if the config file exists but cannot be read or parsed.
        """
        config_path = self.model_dir / "model_config.json"
        try:
            if config_path.exists():
                with open(config_path, 'r', encoding='utf-8') as f:
                    self.config = json.load(f)
                if 'symptom_colors' not in self.config:
                    self.config['symptom_colors'] = dict(DEFAULT_SYMPTOM_COLORS)
                    print("⚠️ Added missing symptom_colors to config")
                # Architecture hyper-parameters are required by create_and_load_model();
                # older config files may omit them.
                self.config.setdefault('num_classes', len(self.config['target_symptoms']))
                self.config.setdefault('dropout', 0.5)
                print(f"✅ Configuration loaded from {config_path}")
            else:
                # Default configuration for 4-symptom model
                self.config = {
                    'target_symptoms': ['fever', 'cold', 'fatigue', 'cough'],
                    'symptom_display_names': {
                        'fever': 'Fever',
                        'cold': 'Cold/Runny Nose',
                        'fatigue': 'Fatigue',
                        'cough': 'Persistent Cough',
                    },
                    'confidence_thresholds': {
                        'fever': 0.5,
                        'cold': 0.5,
                        'fatigue': 0.5,
                        'cough': 0.5,
                    },
                    'symptom_colors': dict(DEFAULT_SYMPTOM_COLORS),
                    'model_version': MODEL_VERSION,
                    'num_classes': 4,
                    'dropout': 0.5,
                }
                print("⚠️ Using default configuration")
        except Exception as e:
            raise RuntimeError(f"Failed to load config: {str(e)}") from e

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the state dict held by a loaded checkpoint, or None for an unknown format."""
        if isinstance(checkpoint, nn.Module):
            return checkpoint.state_dict()
        if not isinstance(checkpoint, dict):
            return None
        if 'model_state_dict' in checkpoint:
            return checkpoint['model_state_dict']
        if 'state_dict' in checkpoint:
            return checkpoint['state_dict']
        # Assume it's a pure state dict
        return checkpoint

    def create_and_load_model(self):
        """
        Build the classifier, load the first acceptable weight file from
        WEIGHT_FILE_CANDIDATES, and wrap it in OptimizedInferenceModel.

        A file is accepted when it fills more than MIN_LOAD_PERCENTAGE percent of
        the model's state-dict keys. If no file qualifies, the model keeps its random
        initialization and self.weights_loaded stays False.

        Raises:
            RuntimeError: if the model cannot be constructed or wrapped.
        """
        try:
            base_model = LightweightMultiSymptomClassifier(
                num_classes=self.config['num_classes'],
                dropout=self.config['dropout'],
            )
            # load_state_dict(strict=False) copies matching tensors even when a file is
            # later rejected (or raises on a shape mismatch), so keep a snapshot to roll back to.
            initial_state = {k: v.clone() for k, v in base_model.state_dict().items()}
            total_keys = len(initial_state)

            print("🔍 Searching for model weight files...")

            for file_name, model_type in WEIGHT_FILE_CANDIDATES:
                weight_file = self.model_dir / file_name
                if not weight_file.exists():
                    continue

                file_size = weight_file.stat().st_size / (1024 * 1024)
                print(f"📁 Found {model_type}: {weight_file} ({file_size:.1f}MB)")

                try:
                    # weights_only=False: full training checkpoints may contain non-tensor
                    # objects. These files ship with the deployment, never from user input.
                    checkpoint = torch.load(weight_file, map_location='cpu', weights_only=False)
                    state_dict = self._extract_state_dict(checkpoint)
                    if state_dict is None:
                        print("⚠️ Unexpected checkpoint format, skipping...")
                        continue

                    missing, unexpected = base_model.load_state_dict(state_dict, strict=False)
                    # Model keys that actually received a value from the file. Counting
                    # len(state_dict) instead would let unexpected extra keys inflate this.
                    loaded_keys = total_keys - len(missing)
                    load_percentage = (loaded_keys / total_keys) * 100

                    print(f"   📊 Loaded {loaded_keys}/{total_keys} parameters ({load_percentage:.1f}%)")
                    if missing:
                        print(f"   ⚠️ Missing keys: {len(missing)}")
                    if unexpected:
                        print(f"   ⚠️ Unexpected keys: {len(unexpected)}")

                    if load_percentage > MIN_LOAD_PERCENTAGE:
                        self.weights_loaded = True
                        print(f"✅ Successfully loaded {model_type}")
                        break
                    print(f"⚠️ Only {load_percentage:.1f}% loaded, trying next file...")
                except Exception as e:
                    print(f"⚠️ Failed to load {model_type}: {str(e)}")

                # This file was rejected: undo any partial load before trying the next one
                base_model.load_state_dict(initial_state)

            if not self.weights_loaded:
                print("\n❌ WARNING: Using random model weights!")
                print("❌ All predictions will be random")
                print(f"❌ Expected model files in: {self.model_dir}/")
                print("❌ Required files:")
                print("   - model_state_dict.pt (recommended)")
                print("   - model_quantized_state_dict.pt (alternative)")
                print("   - best_model.pt (alternative)")
            else:
                print("✅ Model ready with trained weights")

            # Wrap in inference model with thresholds
            self.model = OptimizedInferenceModel(
                base_model,
                self.config['target_symptoms'],
                self.config['confidence_thresholds'],
            )
            self.model.eval()

            # CPU optimization (process-wide setting)
            torch.set_num_threads(4)
        except Exception as e:
            raise RuntimeError(f"Failed to create/load model: {str(e)}") from e

    def setup_preprocessor(self):
        """Initialize audio preprocessor."""
        self.preprocessor = RespiratoryAudioPreprocessor()
        print("✅ Audio preprocessor initialized")

    def predict_symptoms(self, audio_file_path: str) -> Dict[str, Any]:
        """
        Preprocess one audio file, run the model and build the API result payload.

        A symptom counts as detected when its probability reaches
        max(configured threshold, neutral_threshold).

        Args:
            audio_file_path: path of the audio file on local disk.

        Returns:
            JSON-serializable dict with 'detected_symptoms', 'all_symptoms', 'summary',
            'recommendations', 'health_classification' and 'processing_info'.

        Raises:
            HTTPException: status 500 wrapping any preprocessing or inference error.
        """
        try:
            start_time = time.time()
            tensor_input = self.preprocessor.preprocess_audio(audio_file_path)
            preprocessing_time = time.time() - start_time

            inference_start = time.time()
            with torch.no_grad():
                outputs = self.model(tensor_input)
            inference_time = time.time() - inference_start

            # Remove only the batch dimension: a bare squeeze() would also collapse a
            # single-class output to a 0-d value and break the per-symptom indexing below.
            probabilities = (
                outputs['probabilities'].squeeze(0).detach().cpu().numpy().astype(float).tolist()
            )

            detected_symptoms = []
            all_symptoms = {}
            for i, symptom in enumerate(self.config['target_symptoms']):
                prob = float(probabilities[i])
                threshold = float(self.config['confidence_thresholds'][symptom])
                # The neutral threshold is a floor: a lenient per-symptom threshold
                # cannot report signals weaker than neutral_threshold.
                effective_threshold = max(threshold, self.neutral_threshold)
                detected = prob >= effective_threshold
                display_name = self.config['symptom_display_names'][symptom]
                color = self.config['symptom_colors'][symptom]

                all_symptoms[symptom] = {
                    'display_name': display_name,
                    'confidence': prob,
                    'detected': detected,
                    'original_threshold': threshold,
                    'effective_threshold': effective_threshold,
                    'color': color,
                }
                if detected:
                    detected_symptoms.append({
                        'symptom': symptom,
                        'display_name': display_name,
                        'confidence': prob,
                        'color': color,
                        'threshold_used': effective_threshold,
                    })

            # Determine health status
            max_confidence = max(probabilities)
            if detected_symptoms:
                health_status = "symptoms_detected"
                status_message = f"{len(detected_symptoms)} symptom(s) detected"
            elif max_confidence < self.neutral_threshold:
                health_status = "healthy"
                status_message = "No symptoms detected - appears healthy"
            else:
                health_status = "inconclusive"
                status_message = "Some patterns detected but below confidence threshold"

            return {
                'detected_symptoms': detected_symptoms,
                'all_symptoms': all_symptoms,
                'summary': {
                    'total_detected': len(detected_symptoms),
                    'highest_confidence': max(
                        (s['confidence'] for s in detected_symptoms), default=0.0
                    ),
                    'max_overall_confidence': float(max_confidence),
                    'status': health_status,
                    'status_message': status_message,
                    'neutral_threshold': float(self.neutral_threshold),
                    'weights_status': 'trained' if self.weights_loaded else 'random',
                },
                'recommendations': self._get_recommendations(health_status, detected_symptoms),
                'health_classification': health_status,
                'processing_info': {
                    'preprocessing_time_ms': round(preprocessing_time * 1000, 1),
                    'inference_time_ms': round(inference_time * 1000, 1),
                    'total_time_ms': round((preprocessing_time + inference_time) * 1000, 1),
                    'model_weights_loaded': self.weights_loaded,
                    'model_version': MODEL_VERSION,
                },
            }
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e

    def _get_recommendations(self, health_status, detected_symptoms):
        """
        Return user-facing recommendation strings for a health status.

        Args:
            health_status: "healthy", "inconclusive" or "symptoms_detected".
            detected_symptoms: detected-symptom dicts as built by predict_symptoms().
        """
        recommendations = []

        if not self.weights_loaded:
            recommendations.append("⚠️ DEVELOPMENT MODE: Model using random weights - results not valid")

        if health_status == "healthy":
            recommendations.extend([
                "✅ No significant respiratory symptoms detected",
                "Your cough patterns appear normal and healthy",
                "Continue maintaining good respiratory health practices",
                "This screening is for informational purposes only",
            ])
        elif health_status == "inconclusive":
            recommendations.extend([
                "⚠️ Some respiratory patterns detected but below confidence threshold",
                "Consider monitoring your symptoms over the next few days",
                "If symptoms persist or worsen, consult a healthcare provider",
                "This AI screening should not replace professional medical advice",
            ])
        elif len(detected_symptoms) == 1:
            symptom_name = detected_symptoms[0]['display_name']
            confidence = detected_symptoms[0]['confidence']
            recommendations.extend([
                f"🔍 Detected: {symptom_name} (confidence: {confidence:.1%})",
                "Monitor this symptom and note any changes",
                "Consider consulting a healthcare provider if symptoms persist",
                "This AI screening should not replace professional medical advice",
            ])
        else:
            symptom_names = [s['display_name'] for s in detected_symptoms]
            recommendations.extend([
                f"🚨 Multiple symptoms detected: {', '.join(symptom_names)}",
                "Multiple symptoms may indicate a need for medical attention",
                "Please consult a healthcare provider for proper evaluation",
                "This AI screening should not replace professional medical advice",
            ])

        return recommendations


# Initialize service (at import time; routes report "unavailable" if this fails)
print("🚀 Initializing Respiratory Analysis Service v3.0...")
try:
    service = RespiratoryAnalysisService()
    print("✅ Service initialized successfully!")
    print("   Model: 39% F1-Macro (4 symptoms)")
    print(f"   Weights loaded: {'Yes' if service.weights_loaded else 'No'}")
except Exception as e:
    print(f"❌ Service initialization failed: {str(e)}")
    service = None


# =================== API ROUTES ===================

@app.get("/")
async def root():
    """Root endpoint: service status and available endpoints."""
    if service is None:
        return {"service": "Respiratory Symptom Analysis API", "version": API_VERSION, "status": "error"}

    return {
        "service": "Respiratory Symptom Analysis API",
        "version": API_VERSION,
        "model_version": "39% F1-Macro (4 symptoms)",
        "status": "active",
        "model_status": "trained_weights" if service.weights_loaded else "random_weights",
        "supported_symptoms": service.config['target_symptoms'],
        "endpoints": {
            "analyze": "/analyze",
            "health": "/health",
            "info": "/info",
            "docs": "/docs",
        },
    }


@app.get("/health")
async def health_check():
    """Health check endpoint, including which model files are present on disk."""
    model_dir = service.model_dir if service is not None else Path(DEFAULT_MODEL_DIR)

    model_files_status = {
        # Legacy keys kept for client compatibility; the loader no longer reads these files.
        "model_base": (model_dir / "model_base.pt").exists(),
        "model_inference": (model_dir / "model_inference.pt").exists(),
        "model_quantized": (model_dir / "model_quantized.pt").exists(),
        "model_torchscript": (model_dir / "model_torchscript.pt").exists(),
        "config": (model_dir / "model_config.json").exists(),
    }
    # Files create_and_load_model() actually tries, keyed by file stem
    for file_name, _ in WEIGHT_FILE_CANDIDATES:
        model_files_status[Path(file_name).stem] = (model_dir / file_name).exists()

    return {
        "status": "healthy" if service is not None else "unhealthy",
        "timestamp": time.time(),
        "service_ready": service is not None,
        "model_loaded": service.model is not None if service else False,
        "model_weights_status": "trained" if (service and service.weights_loaded) else "random",
        "model_files_available": model_files_status,
        "api_version": API_VERSION,
    }


@app.get("/info")
async def get_info():
    """Get model and preprocessing information."""
    if service is None:
        return {"error": "Service not initialized"}

    return {
        "model_info": {
            "version": MODEL_VERSION,
            "architecture": "LightweightMultiSymptomClassifier (no CBAM)",
            "target_symptoms": service.config['target_symptoms'],
            "symptom_display_names": service.config['symptom_display_names'],
            "confidence_thresholds": service.config['confidence_thresholds'],
            "weights_loaded": service.weights_loaded,
            "neutral_threshold": service.neutral_threshold,
        },
        "preprocessing_info": service.preprocessor.get_preprocessing_info(),
        "supported_formats": ["wav", "mp3", "flac", "ogg", "m4a", "webm"],
        "max_duration": "30 seconds",
        "max_file_size": "10MB",
        "api_version": API_VERSION,
    }


@app.post("/analyze")
async def analyze_audio(audio_file: UploadFile = File(...)):
    """
    Analyze an uploaded audio file for respiratory symptoms.

    Returns detected symptoms with confidence scores and health classification.

    Raises:
        HTTPException: 503 if the service failed to initialize, 400 for an unsupported
            content type or a file over 10MB, 500 if analysis fails.
    """
    if service is None:
        raise HTTPException(status_code=503, detail="Service not available")

    # Validate file type. Browsers' MediaRecorder uploads usually carry parameters
    # (e.g. "audio/webm;codecs=opus"), so compare only the base MIME type.
    base_content_type = (audio_file.content_type or "").split(';', 1)[0].strip().lower()
    if base_content_type not in CONTENT_TYPE_EXTENSIONS:
        raise HTTPException(status_code=400, detail=f"Unsupported format: {audio_file.content_type}")

    # Validate file size; read at most one byte past the limit so oversized uploads
    # are rejected without being buffered entirely in memory.
    content = await audio_file.read(MAX_UPLOAD_BYTES + 1)
    if len(content) > MAX_UPLOAD_BYTES:
        raise HTTPException(status_code=400, detail="File too large. Maximum: 10MB")

    # The filename is client-controlled: keep only a short alphanumeric extension,
    # otherwise derive one from the validated content type.
    suffix = Path(audio_file.filename or "").suffix.lower()
    extension = suffix[1:]
    if not (extension and extension.isascii() and extension.isalnum()):
        suffix = CONTENT_TYPE_EXTENSIONS[base_content_type]

    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path first so the file is removed even if the write fails
            temp_file_path = temp_file.name
            temp_file.write(content)

        # NOTE(review): this CPU-bound call runs on the event loop and blocks other
        # requests while it executes; offloading it to a thread pool requires
        # confirming the preprocessor is thread-safe.
        results = service.predict_symptoms(temp_file_path)

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "data": results,
                "metadata": {
                    "filename": audio_file.filename,
                    "file_size_bytes": len(content),
                    "content_type": audio_file.content_type,
                    "timestamp": time.time(),
                    "api_version": API_VERSION,
                },
            },
        )
    except HTTPException:
        # Already carries a status and detail (e.g. from predict_symptoms); don't re-wrap
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") from e
    finally:
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                # Best-effort cleanup; the response outcome must not depend on it
                pass


if __name__ == "__main__":
    import uvicorn
    # Pass the app object directly so this works whatever the module's file name is
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)