Kalpokoch commited on
Commit
77589df
·
verified ·
1 Parent(s): ed9b2d0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +349 -129
main.py CHANGED
@@ -1,7 +1,7 @@
1
  """
2
  FastAPI Backend for Respiratory Symptom Analysis
 
3
  Deployed on HuggingFace Spaces for use with Netlify frontend
4
- Fixed for model loading compatibility issues
5
  """
6
 
7
  from fastapi import FastAPI, File, UploadFile, HTTPException
@@ -138,7 +138,7 @@ class PurePyTorchInferenceModel(nn.Module):
138
  app = FastAPI(
139
  title="🫁 Respiratory Symptom Analysis API",
140
  description="AI-powered respiratory symptom detection from cough audio",
141
- version="2.0.0",
142
  docs_url="/docs",
143
  redoc_url="/redoc"
144
  )
@@ -154,7 +154,7 @@ app.add_middleware(
154
 
155
  class RespiratoryAnalysisService:
156
  """
157
- Service class for respiratory symptom analysis
158
  """
159
 
160
  def __init__(self, config_path: str = "optimized_model_cpu/model_config.json"):
@@ -163,6 +163,8 @@ class RespiratoryAnalysisService:
163
  self.model = None
164
  self.config = None
165
  self.preprocessor = None
 
 
166
 
167
  # Load configuration and model
168
  self.load_config()
@@ -196,7 +198,7 @@ class RespiratoryAnalysisService:
196
  'fever': '#FF6B6B', 'cold': '#4ECDC4', 'sorethroat': '#45B7D1',
197
  'lossofsmell': '#96CEB4', 'fatigue': '#FFEAA7', 'cough': '#DDA0DD'
198
  },
199
- 'model_version': '2.0',
200
  'optimization_settings': {'torch_threads': 4}
201
  }
202
  print("⚠️ Using default configuration")
@@ -205,7 +207,7 @@ class RespiratoryAnalysisService:
205
  raise RuntimeError(f"Failed to load config: {str(e)}")
206
 
207
  def create_and_load_model(self):
208
- """Create model and try to load weights from available files"""
209
  try:
210
  # Create model with correct architecture
211
  self.model = PurePyTorchInferenceModel(
@@ -213,54 +215,57 @@ class RespiratoryAnalysisService:
213
  confidence_thresholds=self.config['confidence_thresholds']
214
  )
215
 
216
- # Try to load state dict from available files
217
- state_dict_files = [
218
- "optimized_model_cpu/model_pytorch_state_dict.pt",
219
- "optimized_model_cpu/model_quantized_state_dict.pt"
 
 
 
 
 
 
 
 
 
 
220
  ]
221
 
222
- model_loaded = False
223
- for state_dict_file in state_dict_files:
224
- if Path(state_dict_file).exists():
 
 
225
  try:
226
- state_dict = torch.load(state_dict_file, map_location='cpu')
227
-
228
- # Handle different state dict formats
229
- if 'state_dict' in state_dict:
230
- state_dict = state_dict['state_dict']
231
-
232
- # Filter out incompatible keys (like attention layers)
233
- filtered_state_dict = {}
234
- for key, value in state_dict.items():
235
- # Skip attention-related keys that might cause issues
236
- if 'symptom_attention' in key:
237
- continue
238
- # Skip COVID classifier if present
239
- if 'covid_classifier' in key:
240
- continue
241
- filtered_state_dict[key] = value
242
-
243
- # Load with strict=False to ignore missing/extra keys
244
- missing_keys, unexpected_keys = self.model.load_state_dict(filtered_state_dict, strict=False)
245
-
246
- if missing_keys:
247
- print(f"⚠️ Missing keys (will use random initialization): {missing_keys}")
248
- if unexpected_keys:
249
- print(f"⚠️ Unexpected keys (ignored): {unexpected_keys}")
250
-
251
- print(f"✅ Model weights loaded from {state_dict_file}")
252
- model_loaded = True
253
- break
254
-
255
  except Exception as e:
256
- print(f"⚠️ Failed to load {state_dict_file}: {str(e)}")
257
  continue
258
  else:
259
- print(f"⚠️ State dict file not found: {state_dict_file}")
260
 
261
- if not model_loaded:
262
- print("⚠️ No compatible model weights found, using random initialization")
263
- print("⚠️ Model will give random predictions - this is for testing only!")
 
 
 
 
 
 
264
 
265
  # Set model to evaluation mode
266
  self.model.eval()
@@ -271,13 +276,107 @@ class RespiratoryAnalysisService:
271
  except Exception as e:
272
  raise RuntimeError(f"Failed to create/load model: {str(e)}")
273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  def setup_preprocessor(self):
275
  """Initialize audio preprocessor"""
276
  self.preprocessor = RespiratoryAudioPreprocessor()
277
  print("✅ Audio preprocessor initialized")
278
 
279
  def predict_symptoms(self, audio_file_path: str) -> Dict[str, Any]:
280
- """Predict respiratory symptoms from audio file"""
281
  try:
282
  start_time = time.time()
283
 
@@ -289,22 +388,60 @@ class RespiratoryAnalysisService:
289
  inference_start = time.time()
290
  with torch.no_grad():
291
  outputs = self.model(tensor_input)
292
-
293
  inference_time = time.time() - inference_start
294
 
295
  # Parse outputs
296
  probabilities = outputs['probabilities'].squeeze().numpy()
297
- predictions = outputs['predictions'].squeeze().numpy()
298
 
299
- # Format results
300
- results = self.format_results(probabilities, predictions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
- # Add timing info
303
  results['processing_info'] = {
304
  'preprocessing_time_ms': round(preprocessing_time * 1000, 1),
305
  'inference_time_ms': round(inference_time * 1000, 1),
306
  'total_time_ms': round((preprocessing_time + inference_time) * 1000, 1),
307
- 'model_status': 'loaded' if hasattr(self, '_weights_loaded') else 'random_weights'
 
 
308
  }
309
 
310
  return results
@@ -312,176 +449,233 @@ class RespiratoryAnalysisService:
312
  except Exception as e:
313
  raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
314
 
315
- def format_results(self, probabilities: np.ndarray, predictions: np.ndarray) -> Dict[str, Any]:
316
- """Format prediction results for API response"""
 
317
  results = {
318
- 'detected_symptoms': [],
319
  'all_symptoms': {},
320
  'summary': {},
321
- 'recommendations': []
 
322
  }
323
 
324
- # Process each symptom
325
  for i, symptom in enumerate(self.config['target_symptoms']):
326
  prob = float(probabilities[i])
327
- pred = bool(predictions[i])
328
- display_name = self.config['symptom_display_names'][symptom]
329
- threshold = self.config['confidence_thresholds'][symptom]
330
 
331
- # All symptoms with details
332
  results['all_symptoms'][symptom] = {
333
- 'display_name': display_name,
334
  'confidence': prob,
335
- 'detected': pred,
336
- 'threshold': threshold,
 
 
337
  'color': self.config['symptom_colors'][symptom]
338
  }
339
-
340
- # Detected symptoms only
341
- if pred:
342
- results['detected_symptoms'].append({
343
- 'symptom': symptom,
344
- 'display_name': display_name,
345
- 'confidence': prob,
346
- 'color': self.config['symptom_colors'][symptom]
347
- })
348
-
349
- # Sort detected symptoms by confidence
350
- results['detected_symptoms'].sort(key=lambda x: x['confidence'], reverse=True)
351
 
352
- # Generate summary
353
  results['summary'] = {
354
- 'total_detected': len(results['detected_symptoms']),
355
- 'highest_confidence': results['detected_symptoms'][0]['confidence'] if results['detected_symptoms'] else 0.0,
356
- 'status': 'symptoms_detected' if results['detected_symptoms'] else 'no_symptoms'
 
 
 
 
357
  }
358
 
359
- # Generate recommendations
360
- if len(results['detected_symptoms']) == 0:
 
 
 
 
 
 
 
361
  results['recommendations'] = [
362
- "No significant respiratory symptoms detected.",
363
- "Continue monitoring your health.",
364
- "This screening is for informational purposes only."
 
365
  ]
366
- elif len(results['detected_symptoms']) == 1:
367
- symptom_name = results['detected_symptoms'][0]['display_name']
 
368
  results['recommendations'] = [
369
- f"Detected: {symptom_name}",
370
- "Consider monitoring symptoms and consult healthcare provider if symptoms persist.",
371
- "This AI screening should not replace professional medical advice."
 
372
  ]
373
  else:
374
- symptom_names = [s['display_name'] for s in results['detected_symptoms']]
375
  results['recommendations'] = [
376
- f"Multiple symptoms detected: {', '.join(symptom_names)}",
377
- "Please consult a healthcare provider for proper evaluation.",
378
- "This AI screening should not replace professional medical advice."
 
379
  ]
380
 
 
 
 
 
 
 
381
  return results
382
 
383
- # Initialize service
384
- print("🚀 Initializing Respiratory Analysis Service...")
385
  try:
386
  service = RespiratoryAnalysisService()
387
  print("✅ Service initialized successfully!")
 
 
388
  except Exception as e:
389
  print(f"❌ Service initialization failed: {str(e)}")
390
  service = None
391
 
392
- # API Routes
 
393
  @app.get("/")
394
  async def root():
395
- """Root endpoint with API information"""
396
  if service is None:
397
  return {
398
  "service": "Respiratory Symptom Analysis API",
399
- "version": "2.0.0",
400
  "status": "error - service not initialized"
401
  }
402
 
403
  return {
404
- "service": "Respiratory Symptom Analysis API",
405
- "version": "2.0.0",
406
  "status": "active",
 
 
 
407
  "endpoints": {
408
  "analyze": "/analyze",
409
- "health": "/health",
410
  "info": "/info",
411
  "docs": "/docs"
412
  },
413
- "supported_symptoms": list(service.config['target_symptoms']),
414
  "model_info": {
415
  "version": service.config['model_version'],
416
- "optimization": "CPU-optimized"
417
  }
418
  }
419
 
420
  @app.get("/health")
421
  async def health_check():
422
- """Health check endpoint"""
 
 
 
 
 
 
 
 
 
423
  return {
424
  "status": "healthy" if service is not None else "unhealthy",
425
  "timestamp": time.time(),
 
426
  "model_loaded": service.model is not None if service else False,
427
  "config_loaded": service.config is not None if service else False,
428
- "files_available": {
429
- "config": Path("optimized_model_cpu/model_config.json").exists(),
430
- "pytorch_state": Path("optimized_model_cpu/model_pytorch_state_dict.pt").exists(),
431
- "quantized_state": Path("optimized_model_cpu/model_quantized_state_dict.pt").exists()
432
- }
 
 
 
 
433
  }
434
 
435
  @app.get("/info")
436
  async def get_info():
437
- """Get model and service information"""
438
  if service is None:
439
  return {"error": "Service not initialized"}
440
 
441
  return {
442
  "model_info": {
443
- "version": service.config.get('model_version', '2.0'),
444
  "target_symptoms": service.config['target_symptoms'],
445
  "symptom_display_names": service.config['symptom_display_names'],
446
- "confidence_thresholds": service.config['confidence_thresholds']
 
 
 
447
  },
448
  "preprocessing_info": service.preprocessor.get_preprocessing_info(),
449
- "supported_formats": ["wav", "mp3", "flac", "ogg", "m4a"],
450
- "api_version": "2.0.0"
 
 
 
 
 
 
 
 
451
  }
452
 
453
  @app.post("/analyze")
454
  async def analyze_audio(audio_file: UploadFile = File(...)):
455
- """Analyze audio file for respiratory symptoms"""
 
 
 
 
 
 
 
 
456
  if service is None:
457
  raise HTTPException(status_code=503, detail="Service not available")
458
 
459
- # Validate file
460
- allowed_types = ['audio/wav', 'audio/mpeg', 'audio/flac', 'audio/ogg', 'audio/x-m4a', 'audio/mp4']
 
 
 
 
461
  if audio_file.content_type not in allowed_types:
462
  raise HTTPException(
463
- status_code=400,
464
- detail=f"Unsupported format: {audio_file.content_type}"
465
  )
466
 
467
- # Validate size
468
  content = await audio_file.read()
469
- if len(content) > 10 * 1024 * 1024: # 10MB
470
- raise HTTPException(status_code=400, detail="File too large")
 
471
 
472
  try:
473
- # Save temporarily
474
  file_extension = audio_file.filename.split('.')[-1] if audio_file.filename else 'wav'
475
  with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_extension}") as temp_file:
476
  temp_file.write(content)
477
  temp_file_path = temp_file.name
478
 
479
- # Analyze
480
  results = service.predict_symptoms(temp_file_path)
481
 
482
- # Cleanup
483
  os.unlink(temp_file_path)
484
 
 
485
  return JSONResponse(
486
  status_code=200,
487
  content={
@@ -490,7 +684,9 @@ async def analyze_audio(audio_file: UploadFile = File(...)):
490
  "metadata": {
491
  "filename": audio_file.filename,
492
  "file_size_bytes": len(content),
493
- "timestamp": time.time()
 
 
494
  }
495
  }
496
  )
@@ -498,13 +694,37 @@ async def analyze_audio(audio_file: UploadFile = File(...)):
498
  except HTTPException:
499
  raise
500
  except Exception as e:
 
501
  if 'temp_file_path' in locals():
502
  try:
503
  os.unlink(temp_file_path)
504
  except:
505
  pass
 
506
  raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  if __name__ == "__main__":
509
  import uvicorn
510
- uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
 
 
 
 
 
 
 
 
1
  """
2
  FastAPI Backend for Respiratory Symptom Analysis
3
+ Updated with proper model weight loading and health classification system
4
  Deployed on HuggingFace Spaces for use with Netlify frontend
 
5
  """
6
 
7
  from fastapi import FastAPI, File, UploadFile, HTTPException
 
138
  app = FastAPI(
139
  title="🫁 Respiratory Symptom Analysis API",
140
  description="AI-powered respiratory symptom detection from cough audio",
141
+ version="2.1.0",
142
  docs_url="/docs",
143
  redoc_url="/redoc"
144
  )
 
154
 
155
  class RespiratoryAnalysisService:
156
  """
157
+ Enhanced service class for respiratory symptom analysis with proper model loading
158
  """
159
 
160
  def __init__(self, config_path: str = "optimized_model_cpu/model_config.json"):
 
163
  self.model = None
164
  self.config = None
165
  self.preprocessor = None
166
+ self.weights_loaded = False # Track if real weights are loaded
167
+ self.neutral_threshold = 0.35 # Below this = neutral/healthy
168
 
169
  # Load configuration and model
170
  self.load_config()
 
198
  'fever': '#FF6B6B', 'cold': '#4ECDC4', 'sorethroat': '#45B7D1',
199
  'lossofsmell': '#96CEB4', 'fatigue': '#FFEAA7', 'cough': '#DDA0DD'
200
  },
201
+ 'model_version': '2.1',
202
  'optimization_settings': {'torch_threads': 4}
203
  }
204
  print("⚠️ Using default configuration")
 
207
  raise RuntimeError(f"Failed to load config: {str(e)}")
208
 
209
  def create_and_load_model(self):
210
+ """Create model and try to load weights from available files with priority order"""
211
  try:
212
  # Create model with correct architecture
213
  self.model = PurePyTorchInferenceModel(
 
215
  confidence_thresholds=self.config['confidence_thresholds']
216
  )
217
 
218
+ print("🔍 Searching for model weight files...")
219
+
220
+ # ✅ PRIORITY ORDER: Try different model files with detailed logging
221
+ weight_files_to_try = [
222
+ # Highest priority - state dicts (most compatible)
223
+ ("optimized_model_cpu/model_pytorch_state_dict.pt", "PyTorch State Dict", "state_dict"),
224
+ ("optimized_model_cpu/model_quantized_state_dict.pt", "Quantized State Dict", "state_dict"),
225
+
226
+ # Medium priority - full models
227
+ ("optimized_model_cpu/model_pytorch.pt", "Full PyTorch Model", "full_model"),
228
+ ("optimized_model_cpu/model_quantized.pt", "Quantized PyTorch Model", "full_model"),
229
+
230
+ # Lower priority - TorchScript (compatibility issues)
231
+ ("optimized_model_cpu/model_torchscript.pt", "TorchScript Model", "torchscript"),
232
  ]
233
 
234
+ for weight_file, model_type, load_type in weight_files_to_try:
235
+ if Path(weight_file).exists():
236
+ file_size = Path(weight_file).stat().st_size / (1024*1024) # Size in MB
237
+ print(f"📁 Found {model_type}: {weight_file} ({file_size:.1f}MB)")
238
+
239
  try:
240
+ if load_type == "state_dict":
241
+ success = self._load_state_dict(weight_file, model_type)
242
+ elif load_type == "full_model":
243
+ success = self._load_full_model(weight_file, model_type)
244
+ elif load_type == "torchscript":
245
+ success = self._load_torchscript_model(weight_file, model_type)
246
+ else:
247
+ success = False
248
+
249
+ if success:
250
+ self.weights_loaded = True
251
+ print(f"✅ Successfully loaded {model_type}")
252
+ break
253
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  except Exception as e:
255
+ print(f"⚠️ Failed to load {model_type}: {str(e)}")
256
  continue
257
  else:
258
+ print(f" Not found: {weight_file}")
259
 
260
+ if not self.weights_loaded:
261
+ print("\n❌ WARNING: Using random model weights!")
262
+ print(" All predictions will be random (~50% confidence)")
263
+ print("❌ Please check your model files in optimized_model_cpu/")
264
+ print("❌ Expected files:")
265
+ for file_path, _, _ in weight_files_to_try:
266
+ print(f" - {file_path}")
267
+ else:
268
+ print(f"✅ Model ready with trained weights")
269
 
270
  # Set model to evaluation mode
271
  self.model.eval()
 
276
  except Exception as e:
277
  raise RuntimeError(f"Failed to create/load model: {str(e)}")
278
 
279
+ def _load_state_dict(self, weight_file: str, model_type: str) -> bool:
280
+ """Load model from state dict file"""
281
+ try:
282
+ checkpoint = torch.load(weight_file, map_location='cpu')
283
+
284
+ # Handle different checkpoint formats
285
+ if isinstance(checkpoint, dict):
286
+ if 'state_dict' in checkpoint:
287
+ state_dict = checkpoint['state_dict']
288
+ elif 'model_state_dict' in checkpoint:
289
+ state_dict = checkpoint['model_state_dict']
290
+ else:
291
+ state_dict = checkpoint
292
+ else:
293
+ state_dict = checkpoint
294
+
295
+ # Remove any incompatible keys
296
+ filtered_state_dict = {}
297
+ for key, value in state_dict.items():
298
+ # Skip keys that might cause issues
299
+ if any(skip in key for skip in ['symptom_attention', 'covid_classifier', 'aux_']):
300
+ print(f" Skipping incompatible key: {key}")
301
+ continue
302
+ filtered_state_dict[key] = value
303
+
304
+ # Load weights
305
+ missing_keys, unexpected_keys = self.model.load_state_dict(filtered_state_dict, strict=False)
306
+
307
+ # Check if enough weights were loaded
308
+ loaded_keys = len(filtered_state_dict) - len(missing_keys)
309
+ total_keys = len(self.model.state_dict())
310
+ load_percentage = (loaded_keys / total_keys) * 100
311
+
312
+ print(f" 📊 Loaded {loaded_keys}/{total_keys} parameters ({load_percentage:.1f}%)")
313
+
314
+ if missing_keys:
315
+ print(f" ⚠️ Missing keys: {len(missing_keys)} (using random initialization)")
316
+ if unexpected_keys:
317
+ print(f" ⚠️ Unexpected keys: {len(unexpected_keys)} (ignored)")
318
+
319
+ # Consider successful if we loaded most parameters
320
+ return load_percentage > 50
321
+
322
+ except Exception as e:
323
+ print(f" ❌ State dict loading failed: {str(e)}")
324
+ return False
325
+
326
+ def _load_full_model(self, weight_file: str, model_type: str) -> bool:
327
+ """Load full model file"""
328
+ try:
329
+ loaded_model = torch.load(weight_file, map_location='cpu')
330
+
331
+ if hasattr(loaded_model, 'state_dict'):
332
+ # Extract state dict from full model
333
+ state_dict = loaded_model.state_dict()
334
+ return self._load_state_dict_direct(state_dict)
335
+ else:
336
+ # Try to use as state dict directly
337
+ return self._load_state_dict_direct(loaded_model)
338
+
339
+ except Exception as e:
340
+ print(f" ❌ Full model loading failed: {str(e)}")
341
+ return False
342
+
343
+ def _load_torchscript_model(self, weight_file: str, model_type: str) -> bool:
344
+ """Load TorchScript model (with known compatibility issues)"""
345
+ try:
346
+ scripted_model = torch.jit.load(weight_file, map_location='cpu')
347
+ scripted_model.eval()
348
+
349
+ # Replace the model entirely with TorchScript version
350
+ self.model = scripted_model
351
+ print(f" ✅ Using TorchScript model directly")
352
+ return True
353
+
354
+ except Exception as e:
355
+ print(f" ❌ TorchScript loading failed: {str(e)}")
356
+ return False
357
+
358
+ def _load_state_dict_direct(self, state_dict: Dict) -> bool:
359
+ """Helper to load state dict directly"""
360
+ try:
361
+ missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
362
+ loaded_keys = len(state_dict) - len(missing_keys)
363
+ total_keys = len(self.model.state_dict())
364
+ load_percentage = (loaded_keys / total_keys) * 100
365
+
366
+ print(f" 📊 Loaded {loaded_keys}/{total_keys} parameters ({load_percentage:.1f}%)")
367
+ return load_percentage > 50
368
+
369
+ except Exception as e:
370
+ print(f" ❌ Direct state dict loading failed: {str(e)}")
371
+ return False
372
+
373
  def setup_preprocessor(self):
374
  """Initialize audio preprocessor"""
375
  self.preprocessor = RespiratoryAudioPreprocessor()
376
  print("✅ Audio preprocessor initialized")
377
 
378
  def predict_symptoms(self, audio_file_path: str) -> Dict[str, Any]:
379
+ """Predict respiratory symptoms with enhanced threshold logic and health classification"""
380
  try:
381
  start_time = time.time()
382
 
 
388
  inference_start = time.time()
389
  with torch.no_grad():
390
  outputs = self.model(tensor_input)
 
391
  inference_time = time.time() - inference_start
392
 
393
  # Parse outputs
394
  probabilities = outputs['probabilities'].squeeze().numpy()
 
395
 
396
+ # ENHANCED THRESHOLD LOGIC with neutral detection
397
+ detected_symptoms = []
398
+
399
+ for i, symptom in enumerate(self.config['target_symptoms']):
400
+ prob = float(probabilities[i])
401
+ symptom_threshold = self.config['confidence_thresholds'][symptom]
402
+
403
+ # Apply dual threshold system:
404
+ # 1. Must be above symptom-specific threshold
405
+ # 2. Must be above neutral threshold to avoid false positives
406
+ effective_threshold = max(symptom_threshold, self.neutral_threshold)
407
+ is_detected = prob >= effective_threshold
408
+
409
+ if is_detected:
410
+ detected_symptoms.append({
411
+ 'symptom': symptom,
412
+ 'display_name': self.config['symptom_display_names'][symptom],
413
+ 'confidence': prob,
414
+ 'color': self.config['symptom_colors'][symptom],
415
+ 'threshold_used': effective_threshold
416
+ })
417
+
418
+ # ✅ DETERMINE OVERALL HEALTH STATUS
419
+ max_confidence = np.max(probabilities)
420
+
421
+ if not detected_symptoms:
422
+ if max_confidence < self.neutral_threshold:
423
+ health_status = "healthy"
424
+ status_message = "No symptoms detected - appears healthy"
425
+ else:
426
+ health_status = "inconclusive"
427
+ status_message = "Some patterns detected but below confidence threshold"
428
+ else:
429
+ health_status = "symptoms_detected"
430
+ status_message = f"{len(detected_symptoms)} symptom(s) detected"
431
+
432
+ # Format results with enhanced health classification
433
+ results = self.format_results_enhanced(
434
+ probabilities, detected_symptoms, health_status, status_message, max_confidence
435
+ )
436
 
437
+ # Add comprehensive processing info
438
  results['processing_info'] = {
439
  'preprocessing_time_ms': round(preprocessing_time * 1000, 1),
440
  'inference_time_ms': round(inference_time * 1000, 1),
441
  'total_time_ms': round((preprocessing_time + inference_time) * 1000, 1),
442
+ 'model_weights_loaded': self.weights_loaded,
443
+ 'neutral_threshold': self.neutral_threshold,
444
+ 'max_confidence': round(max_confidence, 3)
445
  }
446
 
447
  return results
 
449
  except Exception as e:
450
  raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
451
 
452
+ def format_results_enhanced(self, probabilities, detected_symptoms, health_status, status_message, max_confidence):
453
+ """Enhanced results formatting with health classification"""
454
+
455
  results = {
456
+ 'detected_symptoms': detected_symptoms,
457
  'all_symptoms': {},
458
  'summary': {},
459
+ 'recommendations': [],
460
+ 'health_classification': health_status
461
  }
462
 
463
+ # Process all symptoms with enhanced threshold information
464
  for i, symptom in enumerate(self.config['target_symptoms']):
465
  prob = float(probabilities[i])
466
+ original_threshold = self.config['confidence_thresholds'][symptom]
467
+ effective_threshold = max(original_threshold, self.neutral_threshold)
468
+ detected = prob >= effective_threshold
469
 
 
470
  results['all_symptoms'][symptom] = {
471
+ 'display_name': self.config['symptom_display_names'][symptom],
472
  'confidence': prob,
473
+ 'detected': detected,
474
+ 'original_threshold': original_threshold,
475
+ 'effective_threshold': effective_threshold,
476
+ 'neutral_threshold': self.neutral_threshold,
477
  'color': self.config['symptom_colors'][symptom]
478
  }
 
 
 
 
 
 
 
 
 
 
 
 
479
 
480
+ # Enhanced summary with health classification
481
  results['summary'] = {
482
+ 'total_detected': len(detected_symptoms),
483
+ 'highest_confidence': max([s['confidence'] for s in detected_symptoms], default=0.0),
484
+ 'max_overall_confidence': max_confidence,
485
+ 'status': health_status,
486
+ 'status_message': status_message,
487
+ 'neutral_threshold': self.neutral_threshold,
488
+ 'weights_status': 'trained' if self.weights_loaded else 'random'
489
  }
490
 
491
+ # ENHANCED RECOMMENDATIONS based on health status
492
+ if health_status == "healthy":
493
+ results['recommendations'] = [
494
+ "✅ No significant respiratory symptoms detected",
495
+ "Your cough patterns appear normal and healthy",
496
+ "Continue maintaining good respiratory health practices",
497
+ "This screening is for informational purposes only"
498
+ ]
499
+ elif health_status == "inconclusive":
500
  results['recommendations'] = [
501
+ "⚠️ Some respiratory patterns detected but below confidence threshold",
502
+ "Consider monitoring your symptoms over the next few days",
503
+ "If symptoms persist or worsen, consult a healthcare provider",
504
+ "This AI screening should not replace professional medical advice"
505
  ]
506
+ elif len(detected_symptoms) == 1:
507
+ symptom_name = detected_symptoms[0]['display_name']
508
+ confidence = detected_symptoms[0]['confidence']
509
  results['recommendations'] = [
510
+ f"🔍 Detected: {symptom_name} (confidence: {confidence:.1%})",
511
+ "Monitor this symptom and note any changes or progression",
512
+ "Consider consulting a healthcare provider if symptoms persist or worsen",
513
+ "This AI screening should not replace professional medical advice"
514
  ]
515
  else:
516
+ symptom_names = [s['display_name'] for s in detected_symptoms]
517
  results['recommendations'] = [
518
+ f"🚨 Multiple symptoms detected: {', '.join(symptom_names)}",
519
+ "Multiple symptoms may indicate a need for medical attention",
520
+ "Please consult a healthcare provider for proper evaluation and diagnosis",
521
+ "This AI screening should not replace professional medical advice"
522
  ]
523
 
524
+ # Add model status warning if using random weights
525
+ if not self.weights_loaded:
526
+ results['recommendations'].insert(0,
527
+ "⚠️ DEVELOPMENT MODE: Model using random weights - results are not medically valid"
528
+ )
529
+
530
  return results
531
 
532
+ # Initialize service with enhanced error handling
533
+ print("🚀 Initializing Enhanced Respiratory Analysis Service...")
534
  try:
535
  service = RespiratoryAnalysisService()
536
  print("✅ Service initialized successfully!")
537
+ print(f" Model weights loaded: {'Yes' if service.weights_loaded else 'No (using random weights)'}")
538
+ print(f" Neutral threshold: {service.neutral_threshold}")
539
  except Exception as e:
540
  print(f"❌ Service initialization failed: {str(e)}")
541
  service = None
542
 
543
+ # =================== API ROUTES ===================
544
+
545
  @app.get("/")
546
  async def root():
547
+ """Root endpoint with enhanced API information"""
548
  if service is None:
549
  return {
550
  "service": "Respiratory Symptom Analysis API",
551
+ "version": "2.1.0",
552
  "status": "error - service not initialized"
553
  }
554
 
555
  return {
556
+ "service": "Respiratory Symptom Analysis API",
557
+ "version": "2.1.0",
558
  "status": "active",
559
+ "model_status": "trained_weights" if service.weights_loaded else "random_weights",
560
+ "health_classification": ["healthy", "symptoms_detected", "inconclusive"],
561
+ "neutral_threshold": service.neutral_threshold,
562
  "endpoints": {
563
  "analyze": "/analyze",
564
+ "health": "/health",
565
  "info": "/info",
566
  "docs": "/docs"
567
  },
568
+ "supported_symptoms": service.config['target_symptoms'],
569
  "model_info": {
570
  "version": service.config['model_version'],
571
+ "optimization": "CPU-optimized with health classification"
572
  }
573
  }
574
 
575
  @app.get("/health")
576
  async def health_check():
577
+ """Enhanced health check with detailed model status"""
578
+ model_files_status = {
579
+ "pytorch_state_dict": Path("optimized_model_cpu/model_pytorch_state_dict.pt").exists(),
580
+ "quantized_state_dict": Path("optimized_model_cpu/model_quantized_state_dict.pt").exists(),
581
+ "pytorch_full": Path("optimized_model_cpu/model_pytorch.pt").exists(),
582
+ "quantized_full": Path("optimized_model_cpu/model_quantized.pt").exists(),
583
+ "torchscript": Path("optimized_model_cpu/model_torchscript.pt").exists(),
584
+ "config": Path("optimized_model_cpu/model_config.json").exists()
585
+ }
586
+
587
  return {
588
  "status": "healthy" if service is not None else "unhealthy",
589
  "timestamp": time.time(),
590
+ "service_ready": service is not None,
591
  "model_loaded": service.model is not None if service else False,
592
  "config_loaded": service.config is not None if service else False,
593
+ "model_weights_status": "trained" if (service and service.weights_loaded) else "random",
594
+ "neutral_threshold": service.neutral_threshold if service else None,
595
+ "health_classification_enabled": True,
596
+ "model_files_available": model_files_status,
597
+ "files_found": sum(model_files_status.values()),
598
+ "critical_files_missing": not (model_files_status["config"] and
599
+ any([model_files_status["pytorch_state_dict"],
600
+ model_files_status["quantized_state_dict"],
601
+ model_files_status["pytorch_full"]]))
602
  }
603
 
604
  @app.get("/info")
605
  async def get_info():
606
+ """Get comprehensive model and service information"""
607
  if service is None:
608
  return {"error": "Service not initialized"}
609
 
610
  return {
611
  "model_info": {
612
+ "version": service.config.get('model_version', '2.1'),
613
  "target_symptoms": service.config['target_symptoms'],
614
  "symptom_display_names": service.config['symptom_display_names'],
615
+ "confidence_thresholds": service.config['confidence_thresholds'],
616
+ "weights_loaded": service.weights_loaded,
617
+ "neutral_threshold": service.neutral_threshold,
618
+ "health_classifications": ["healthy", "symptoms_detected", "inconclusive"]
619
  },
620
  "preprocessing_info": service.preprocessor.get_preprocessing_info(),
621
+ "supported_formats": ["wav", "mp3", "flac", "ogg", "m4a", "webm"],
622
+ "max_duration": "30 seconds",
623
+ "max_file_size": "10MB",
624
+ "api_version": "2.1.0",
625
+ "features": {
626
+ "health_classification": True,
627
+ "neutral_detection": True,
628
+ "dual_threshold_system": True,
629
+ "trained_weights": service.weights_loaded
630
+ }
631
  }
632
 
633
  @app.post("/analyze")
634
  async def analyze_audio(audio_file: UploadFile = File(...)):
635
+ """
636
+ Enhanced audio analysis with health classification
637
+
638
+ Returns:
639
+ - Detected symptoms with confidence scores
640
+ - Health classification (healthy/symptoms_detected/inconclusive)
641
+ - Enhanced recommendations based on health status
642
+ - Model weight status for debugging
643
+ """
644
  if service is None:
645
  raise HTTPException(status_code=503, detail="Service not available")
646
 
647
+ # Validate file type (including WebM for browser recordings)
648
+ allowed_types = [
649
+ 'audio/wav', 'audio/mpeg', 'audio/mp3', 'audio/flac',
650
+ 'audio/ogg', 'audio/x-m4a', 'audio/mp4', 'audio/webm'
651
+ ]
652
+
653
  if audio_file.content_type not in allowed_types:
654
  raise HTTPException(
655
+ status_code=400,
656
+ detail=f"Unsupported format: {audio_file.content_type}. Supported: {', '.join(allowed_types)}"
657
  )
658
 
659
+ # Validate file size
660
  content = await audio_file.read()
661
+ max_size = 10 * 1024 * 1024 # 10MB
662
+ if len(content) > max_size:
663
+ raise HTTPException(status_code=400, detail="File too large. Maximum size: 10MB")
664
 
665
  try:
666
+ # Save uploaded file temporarily
667
  file_extension = audio_file.filename.split('.')[-1] if audio_file.filename else 'wav'
668
  with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_extension}") as temp_file:
669
  temp_file.write(content)
670
  temp_file_path = temp_file.name
671
 
672
+ # Analyze audio with enhanced health classification
673
  results = service.predict_symptoms(temp_file_path)
674
 
675
+ # Clean up temporary file
676
  os.unlink(temp_file_path)
677
 
678
+ # Return enhanced results
679
  return JSONResponse(
680
  status_code=200,
681
  content={
 
684
  "metadata": {
685
  "filename": audio_file.filename,
686
  "file_size_bytes": len(content),
687
+ "content_type": audio_file.content_type,
688
+ "timestamp": time.time(),
689
+ "api_version": "2.1.0"
690
  }
691
  }
692
  )
 
694
  except HTTPException:
695
  raise
696
  except Exception as e:
697
+ # Clean up temporary file if exists
698
  if 'temp_file_path' in locals():
699
  try:
700
  os.unlink(temp_file_path)
701
  except:
702
  pass
703
+
704
  raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
705
 
706
+ # Global exception handler
707
+ @app.exception_handler(Exception)
708
+ async def global_exception_handler(request, exc):
709
+ """Global exception handler with detailed error information"""
710
+ return JSONResponse(
711
+ status_code=500,
712
+ content={
713
+ "success": False,
714
+ "error": "Internal server error",
715
+ "detail": str(exc),
716
+ "model_status": "trained_weights" if (service and service.weights_loaded) else "random_weights",
717
+ "timestamp": time.time()
718
+ }
719
+ )
720
+
721
  if __name__ == "__main__":
722
  import uvicorn
723
+
724
+ # Run the API server
725
+ uvicorn.run(
726
+ "main:app",
727
+ host="0.0.0.0",
728
+ port=7860,
729
+ reload=False
730
+ )