Kalpokoch committed
Commit aaa4ee3 · Parent: eb86f9c

updated backend

Files changed (3):
  1. audio_preprocessing.py +42 -24
  2. main.py +286 -456
  3. requirements.txt +26 -20
audio_preprocessing.py CHANGED

@@ -1,6 +1,7 @@
 """
 Audio Preprocessing Module for Respiratory Symptom Analysis
-Version without external resampling dependencies (resampy-free)
+Updated for 39% F1-Macro Model (128x431 mel-spectrograms)
+Version: 3.0.0
 """
 
 import librosa
@@ -16,13 +17,13 @@ from scipy import signal
 os.environ['NUMBA_CACHE_DIR'] = '/tmp'
 os.environ['NUMBA_DISABLE_JIT'] = '0'
 
-# Disable warnings
 warnings.filterwarnings('ignore')
 
+
 class RespiratoryAudioPreprocessor:
     """
-    Audio preprocessor without external resampling dependencies
-    Uses scipy.signal for resampling instead of resampy
+    Audio preprocessor matching your 39% F1-Macro training pipeline
+    Mel-spectrogram shape: (128, 431) to match training data
     """
 
     def __init__(self,
@@ -35,8 +36,8 @@ class RespiratoryAudioPreprocessor:
                  fmin: float = 0.0,
                  fmax: float = None,
                  power: float = 2.0,
-                 duration: float = 3.0):
-        """Initialize preprocessing parameters"""
+                 duration: float = 10.0):  # Changed from 3.0 to 10.0 to match training
+        """Initialize preprocessing parameters to match training"""
         self.target_sr = target_sr
         self.n_mels = n_mels
         self.n_fft = n_fft
@@ -49,8 +50,8 @@ class RespiratoryAudioPreprocessor:
         self.duration = duration
         self.target_length = int(target_sr * duration)
 
-        # Expected output shape
-        self.expected_shape = (1, 1, 128, 251)
+        # Expected output shape - UPDATED to match training (128, 431)
+        self.expected_shape = (1, 1, 128, 431)
 
         # Pre-warm librosa
         self._warmup_librosa()
@@ -190,14 +191,14 @@ class RespiratoryAudioPreprocessor:
             raise RuntimeError(f"Failed to load audio: {str(e)}")
 
     def extract_mel_spectrogram(self, audio_data: np.ndarray) -> np.ndarray:
-        """Extract mel spectrogram without resampling dependencies"""
+        """Extract mel spectrogram matching training configuration"""
         try:
             # Ensure proper format
             audio_data = np.asarray(audio_data, dtype=np.float32)
             if len(audio_data.shape) > 1:
                 audio_data = audio_data.flatten()
 
-            # Extract mel spectrogram
+            # Extract mel spectrogram with exact training parameters
             try:
                 mel_spec = librosa.feature.melspectrogram(
                     y=audio_data,
@@ -232,7 +233,7 @@ class RespiratoryAudioPreprocessor:
             raise RuntimeError(f"Failed to extract mel spectrogram: {str(e)}")
 
     def normalize_spectrogram(self, mel_spec: np.ndarray) -> np.ndarray:
-        """Normalize spectrogram"""
+        """Normalize spectrogram to match training"""
         try:
             mean = np.mean(mel_spec)
             std = np.std(mel_spec)
@@ -242,21 +243,30 @@ class RespiratoryAudioPreprocessor:
             else:
                 normalized = (mel_spec - mean) / (std + 1e-8)
 
+            # Clip to prevent extreme values
             normalized = np.clip(normalized, -5.0, 5.0)
             return normalized
 
         except Exception as e:
             raise RuntimeError(f"Failed to normalize spectrogram: {str(e)}")
 
-    def resize_spectrogram(self, mel_spec: np.ndarray, target_width: int = 251) -> np.ndarray:
-        """Resize spectrogram to target dimensions"""
+    def resize_spectrogram(self, mel_spec: np.ndarray, target_width: int = 431) -> np.ndarray:
+        """
+        Resize spectrogram to target dimensions (128, 431) to match training
+        """
         try:
             current_height, current_width = mel_spec.shape
 
+            # Handle height (should be 128 already)
+            if current_height != 128:
+                print(f"⚠️ Unexpected height: {current_height}, expected 128")
+
+            # Handle width
             if current_width == target_width:
                 return mel_spec
 
             if current_width < target_width:
+                # Pad to target width
                 pad_width = target_width - current_width
                 mel_spec = np.pad(
                     mel_spec,
@@ -265,6 +275,7 @@ class RespiratoryAudioPreprocessor:
                     constant_values=0
                 )
             else:
+                # Crop to target width
                 mel_spec = mel_spec[:, :target_width]
 
             return mel_spec
@@ -273,26 +284,31 @@ class RespiratoryAudioPreprocessor:
             raise RuntimeError(f"Failed to resize spectrogram: {str(e)}")
 
     def preprocess_audio(self, audio_input: Union[str, np.ndarray, tuple]) -> torch.Tensor:
-        """Complete preprocessing pipeline"""
+        """
+        Complete preprocessing pipeline matching your training
+        Output: (1, 1, 128, 431) tensor
+        """
         try:
             # Load audio
             audio_data = self.load_and_normalize_audio(audio_input)
 
-            # Extract features
+            # Extract mel-spectrogram
             mel_spec = self.extract_mel_spectrogram(audio_data)
 
             # Normalize
             mel_spec_norm = self.normalize_spectrogram(mel_spec)
 
-            # Resize
-            mel_spec_resized = self.resize_spectrogram(mel_spec_norm)
+            # Resize to (128, 431)
+            mel_spec_resized = self.resize_spectrogram(mel_spec_norm, target_width=431)
 
-            # Convert to tensor
+            # Convert to tensor (1, 1, 128, 431)
             tensor_input = torch.FloatTensor(mel_spec_resized)
             tensor_input = tensor_input.unsqueeze(0).unsqueeze(0)
 
-            # Fix shape if needed
+            # Verify shape
             if tensor_input.shape != self.expected_shape:
+                print(f"⚠️ Shape mismatch: got {tensor_input.shape}, expected {self.expected_shape}")
+                # Force resize using interpolation as last resort
                 tensor_input = torch.nn.functional.interpolate(
                     tensor_input,
                     size=self.expected_shape[2:],
@@ -306,7 +322,7 @@ class RespiratoryAudioPreprocessor:
             raise RuntimeError(f"Preprocessing failed: {str(e)}")
 
     def get_preprocessing_info(self) -> Dict:
-        """Get preprocessing info"""
+        """Get preprocessing configuration info"""
         return {
             'target_sr': self.target_sr,
             'n_mels': self.n_mels,
@@ -314,11 +330,13 @@ class RespiratoryAudioPreprocessor:
             'hop_length': self.hop_length,
             'duration': self.duration,
             'output_shape': self.expected_shape,
-            'resampling_method': 'scipy.signal'
+            'resampling_method': 'scipy.signal',
+            'normalization': 'z-score (mean=0, std=1)',
+            'db_scale': True
         }
 
     def validate_audio_file(self, audio_path: str) -> Tuple[bool, str]:
-        """Validate audio file"""
+        """Validate audio file before processing"""
         try:
             if not audio_path:
                 return False, "No audio file provided"
@@ -328,9 +346,9 @@ class RespiratoryAudioPreprocessor:
             duration = info.duration
 
             if duration < 0.5:
-                return False, f"Audio too short ({duration:.1f}s)"
+                return False, f"Audio too short ({duration:.1f}s). Minimum: 0.5s"
             if duration > 30.0:
-                return False, f"Audio too long ({duration:.1f}s)"
+                return False, f"Audio too long ({duration:.1f}s). Maximum: 30s"
 
             return True, "Audio file is valid"
 
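A quick way to sanity-check the new (1, 1, 128, 431) target shape: with librosa's centered STFT, the frame count is 1 + floor(n_samples / hop_length). The sample rate and hop length are not shown in this diff, so the values below (target_sr=22050, hop_length=512) are assumptions; under them, 10 s of audio yields exactly 431 frames. A minimal sketch:

# Sanity check for the updated (1, 1, 128, 431) output shape.
# ASSUMPTIONS (not shown in this diff): target_sr=22050, hop_length=512.
import numpy as np
from audio_preprocessing import RespiratoryAudioPreprocessor

target_sr, hop_length, duration = 22050, 512, 10.0
n_samples = int(target_sr * duration)      # 220500 samples for 10 s
n_frames = 1 + n_samples // hop_length     # centered STFT: 431 frames
assert n_frames == 431

# End-to-end check through the pipeline (an ndarray input is accepted
# per the preprocess_audio signature shown in the diff)
pre = RespiratoryAudioPreprocessor()
x = pre.preprocess_audio(np.random.randn(n_samples).astype(np.float32))
assert tuple(x.shape) == (1, 1, 128, 431)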
main.py CHANGED

@@ -1,7 +1,8 @@
 """
 FastAPI Backend for Respiratory Symptom Analysis
-Updated with proper model weight loading and health classification system
+Updated for 39% F1-Macro Model (4 symptoms, no CBAM)
 Deployed on HuggingFace Spaces for use with Netlify frontend
+Version: 3.0.0
 """
 
 from fastapi import FastAPI, File, UploadFile, HTTPException
@@ -17,142 +18,132 @@ from pathlib import Path
 from typing import Dict, List, Any
 import time
 import warnings
-import copy
 
 # Import your preprocessing module
 from audio_preprocessing import RespiratoryAudioPreprocessor
 
 warnings.filterwarnings('ignore')
 
-def convert_numpy_types(obj):
-    """Convert numpy types to native Python types for JSON serialization"""
-    if isinstance(obj, np.integer):
-        return int(obj)
-    elif isinstance(obj, np.floating):
-        return float(obj)
-    elif isinstance(obj, np.ndarray):
-        return obj.tolist()
-    elif isinstance(obj, dict):
-        return {key: convert_numpy_types(value) for key, value in obj.items()}
-    elif isinstance(obj, list):
-        return [convert_numpy_types(item) for item in obj]
-    return obj
-
-# =================== MODEL ARCHITECTURE (Recreated for Loading) ===================
-class PurePyTorchInferenceModel(nn.Module):
+# =================== YOUR EXACT MODEL ARCHITECTURE ===================
+class LightweightMultiSymptomClassifier(nn.Module):
     """
-    Pure PyTorch model for inference - recreated to fix loading issues
+    Exact model architecture from your 39% F1-Macro training
+    4 symptoms: fever, cold, fatigue, cough
+    No CBAM, simplified CNN architecture
     """
-    def __init__(self, target_symptoms, confidence_thresholds):
+    def __init__(self, num_classes=4, dropout=0.5):
         super().__init__()
+        self.num_classes = num_classes
 
-        self.target_symptoms = target_symptoms
-        self.num_symptoms = len(target_symptoms)
-
-        # Enhanced feature extractor (matching your training)
-        self.feature_extractor = nn.Sequential(
-            # Block 1: Fine-grained frequency analysis
-            nn.Conv2d(1, 32, kernel_size=(7, 3), stride=(2, 1), padding=(3, 1)),
+        # Convolutional backbone
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(1, 32, kernel_size=3, padding=1),
             nn.BatchNorm2d(32),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d((2, 2)),
-            nn.Dropout2d(0.1),
-
-            # Block 2: Temporal pattern capture
-            nn.Conv2d(32, 64, kernel_size=(3, 7), stride=(1, 2), padding=(1, 3)),
+            nn.ReLU(),
+            nn.Conv2d(32, 32, kernel_size=3, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.MaxPool2d(2)
+        )
+
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(32, 64, kernel_size=3, padding=1),
             nn.BatchNorm2d(64),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d((2, 2)),
-            nn.Dropout2d(0.15),
-
-            # Block 3: Mixed spatio-temporal patterns
-            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.MaxPool2d(2)
+        )
+
+        self.conv3 = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, padding=1),
             nn.BatchNorm2d(128),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d((2, 2)),
-            nn.Dropout2d(0.2),
-
-            # Block 4: High-level feature extraction
-            nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
-            nn.BatchNorm2d(256),
-            nn.ReLU(inplace=True),
-
-            # Block 5: Deep feature refinement
-            nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
-            nn.BatchNorm2d(512),
-            nn.ReLU(inplace=True),
-            nn.AdaptiveAvgPool2d((1, 1))
+            nn.ReLU(),
+            nn.Conv2d(128, 128, kernel_size=3, padding=1),
+            nn.BatchNorm2d(128),
+            nn.ReLU(),
+            nn.MaxPool2d(2)
         )
 
-        # Shared feature processing
-        self.shared_features = nn.Sequential(
-            nn.Flatten(),
-            nn.Dropout(0.4),
-            nn.Linear(512, 512),
-            nn.BatchNorm1d(512),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.3)
+        self.conv4 = nn.Sequential(
+            nn.Conv2d(128, 256, kernel_size=3, padding=1),
+            nn.BatchNorm2d(256),
+            nn.ReLU(),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.BatchNorm2d(256),
+            nn.ReLU(),
+            nn.AdaptiveAvgPool2d((1, 1))
         )
 
-        # Simple attention mechanism (no MultiheadAttention to avoid TorchScript issues)
-        self.attention = nn.Sequential(
-            nn.Linear(512, 256),
-            nn.ReLU(inplace=True),
-            nn.Linear(256, 512),
-            nn.Sigmoid()
+        # Shared feature layer
+        self.shared_fc = nn.Sequential(
+            nn.Linear(256, 256),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Dropout(dropout)
         )
 
-        # Individual symptom-specific heads
+        # Individual symptom heads
         self.symptom_heads = nn.ModuleList([
-            nn.Sequential(
-                nn.Linear(512, 128),
-                nn.ReLU(inplace=True),
-                nn.Dropout(0.2),
-                nn.Linear(128, 64),
-                nn.ReLU(inplace=True),
-                nn.Linear(64, 1)
-            ) for _ in range(self.num_symptoms)
+            nn.Linear(128, 1) for _ in range(num_classes)
         ])
-
-        # Convert thresholds to tensor
-        self.register_buffer('threshold_tensor',
-            torch.tensor([confidence_thresholds[symptom]
-                for symptom in target_symptoms], dtype=torch.float32))
 
     def forward(self, x):
-        """Forward pass for inference"""
-        # Extract features
-        features = self.feature_extractor(x)
-        shared_features = self.shared_features(features)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
 
-        # Apply simple attention
-        attention_weights = self.attention(shared_features)
-        attended_features = shared_features * attention_weights
+        x = x.view(x.size(0), -1)
+        shared_features = self.shared_fc(x)
 
-        # Individual symptom predictions
-        symptom_logits = []
+        outputs = []
         for head in self.symptom_heads:
-            logit = head(attended_features)
-            symptom_logits.append(logit)
-        symptom_logits = torch.cat(symptom_logits, dim=1)
+            outputs.append(head(shared_features))
+
+        logits = torch.cat(outputs, dim=1)
+        return logits
+
+
+class OptimizedInferenceModel(nn.Module):
+    """
+    Inference wrapper with custom thresholds
+    """
+    def __init__(self, base_model, target_symptoms, confidence_thresholds):
+        super().__init__()
+        self.base_model = base_model
+        self.target_symptoms = target_symptoms
+
+        # Convert thresholds to tensor
+        self.register_buffer('threshold_tensor',
+            torch.tensor([confidence_thresholds[symptom]
+                for symptom in target_symptoms], dtype=torch.float32))
+
+    def forward(self, x):
+        # Get logits from base model
+        logits = self.base_model(x)
 
         # Convert to probabilities
-        symptom_probs = torch.sigmoid(symptom_logits)
+        probs = torch.sigmoid(logits)
 
-        # Apply thresholds
-        symptom_preds = (symptom_probs >= self.threshold_tensor).float()
+        # Apply custom thresholds
+        preds = (probs >= self.threshold_tensor).float()
 
         return {
-            'probabilities': symptom_probs,
-            'predictions': symptom_preds,
-            'logits': symptom_logits
+            'probabilities': probs,
+            'predictions': preds,
+            'logits': logits
         }
 
+
 # Initialize FastAPI app
 app = FastAPI(
-    title="🫁 Respiratory Symptom Analysis API",
-    description="AI-powered respiratory symptom detection from cough audio",
-    version="2.1.0",
+    title="🫁 Respiratory Symptom Analysis API v3.0",
+    description="AI-powered respiratory symptom detection (39% F1-Macro model)",
+    version="3.0.0",
     docs_url="/docs",
     redoc_url="/redoc"
 )
@@ -166,231 +157,159 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 class RespiratoryAnalysisService:
     """
-    Enhanced service class for respiratory symptom analysis with proper model loading
+    Service class for respiratory symptom analysis with 39% F1-Macro model
     """
 
-    def __init__(self, config_path: str = "optimized_model_cpu/model_config.json"):
+    def __init__(self, model_dir: str = "deployment_model"):
         """Initialize the service with model and configuration"""
-        self.config_path = config_path
+        self.model_dir = Path(model_dir)
         self.model = None
         self.config = None
         self.preprocessor = None
-        self.weights_loaded = False  # Track if real weights are loaded
-        self.neutral_threshold = 0.35  # Below this = neutral/healthy
+        self.weights_loaded = False
+        self.neutral_threshold = 0.35
 
         # Load configuration and model
         self.load_config()
         self.create_and_load_model()
         self.setup_preprocessor()
 
     def load_config(self):
         """Load configuration"""
+        config_path = self.model_dir / "model_config.json"
+
         try:
-            if Path(self.config_path).exists():
-                with open(self.config_path, 'r') as f:
+            if config_path.exists():
+                with open(config_path, 'r') as f:
                     self.config = json.load(f)
-                print(f"✅ Configuration loaded from {self.config_path}")
+                print(f"✅ Configuration loaded from {config_path}")
             else:
-                # Default configuration if file not found
+                # Default configuration for 4-symptom model
                 self.config = {
-                    'target_symptoms': ['fever', 'cold', 'sorethroat', 'lossofsmell', 'fatigue', 'cough'],
+                    'target_symptoms': ['fever', 'cold', 'fatigue', 'cough'],
                     'symptom_display_names': {
                         'fever': 'Fever',
                         'cold': 'Cold/Runny Nose',
-                        'sorethroat': 'Sore Throat',
-                        'lossofsmell': 'Loss of Smell',
                         'fatigue': 'Fatigue',
                         'cough': 'Persistent Cough'
                     },
                     'confidence_thresholds': {
-                        'fever': 0.42, 'cold': 0.39, 'sorethroat': 0.45,
-                        'lossofsmell': 0.52, 'fatigue': 0.43, 'cough': 0.35
+                        'fever': 0.5,
+                        'cold': 0.5,
+                        'fatigue': 0.5,
+                        'cough': 0.5
                     },
                     'symptom_colors': {
-                        'fever': '#FF6B6B', 'cold': '#4ECDC4', 'sorethroat': '#45B7D1',
-                        'lossofsmell': '#96CEB4', 'fatigue': '#FFEAA7', 'cough': '#DDA0DD'
+                        'fever': '#FF6B6B',
+                        'cold': '#4ECDC4',
+                        'fatigue': '#FFEAA7',
+                        'cough': '#DDA0DD'
                    },
-                    'model_version': '2.1',
-                    'optimization_settings': {'torch_threads': 4}
+                    'model_version': '3.0_39percent_f1',
+                    'num_classes': 4,
+                    'dropout': 0.5
                 }
                 print("⚠️ Using default configuration")
 
         except Exception as e:
             raise RuntimeError(f"Failed to load config: {str(e)}")
 
     def create_and_load_model(self):
-        """Create model and try to load weights from available files with priority order"""
+        """Create model and load weights"""
         try:
-            # Create model with correct architecture
-            self.model = PurePyTorchInferenceModel(
-                target_symptoms=self.config['target_symptoms'],
-                confidence_thresholds=self.config['confidence_thresholds']
+            # Create base model
+            base_model = LightweightMultiSymptomClassifier(
+                num_classes=self.config['num_classes'],
+                dropout=self.config['dropout']
             )
 
             print("🔍 Searching for model weight files...")
 
-            # PRIORITY ORDER: Try different model files with detailed logging
+            # Priority order for loading weights
             weight_files_to_try = [
-                # Highest priority - state dicts (most compatible)
-                ("optimized_model_cpu/model_pytorch_state_dict.pt", "PyTorch State Dict", "state_dict"),
-                ("optimized_model_cpu/model_quantized_state_dict.pt", "Quantized State Dict", "state_dict"),
-
-                # Medium priority - full models
-                ("optimized_model_cpu/model_pytorch.pt", "Full PyTorch Model", "full_model"),
-                ("optimized_model_cpu/model_quantized.pt", "Quantized PyTorch Model", "full_model"),
-
-                # Lower priority - TorchScript (compatibility issues)
-                ("optimized_model_cpu/model_torchscript.pt", "TorchScript Model", "torchscript"),
+                (self.model_dir / "model_base.pt", "Base Model"),
+                (self.model_dir / "model_inference.pt", "Inference Model"),
+                (self.model_dir / "model_quantized.pt", "Quantized Model"),
+                (self.model_dir / "model_torchscript.pt", "TorchScript Model"),
+                (self.model_dir / "best_model.pt", "Best Checkpoint"),
             ]
 
-            for weight_file, model_type, load_type in weight_files_to_try:
-                if Path(weight_file).exists():
-                    file_size = Path(weight_file).stat().st_size / (1024*1024)  # Size in MB
+            for weight_file, model_type in weight_files_to_try:
+                if weight_file.exists():
+                    file_size = weight_file.stat().st_size / (1024*1024)
                     print(f"📁 Found {model_type}: {weight_file} ({file_size:.1f}MB)")
 
                     try:
-                        if load_type == "state_dict":
-                            success = self._load_state_dict(weight_file, model_type)
-                        elif load_type == "full_model":
-                            success = self._load_full_model(weight_file, model_type)
-                        elif load_type == "torchscript":
-                            success = self._load_torchscript_model(weight_file, model_type)
+                        checkpoint = torch.load(weight_file, map_location='cpu', weights_only=False)
+
+                        # Handle different checkpoint formats
+                        if isinstance(checkpoint, dict):
+                            if 'model_state_dict' in checkpoint:
+                                state_dict = checkpoint['model_state_dict']
+                            elif 'state_dict' in checkpoint:
+                                state_dict = checkpoint['state_dict']
+                            else:
+                                state_dict = checkpoint
                         else:
-                            success = False
-
-                        if success:
+                            # TorchScript or full model
+                            if hasattr(checkpoint, 'state_dict'):
+                                state_dict = checkpoint.state_dict()
+                            else:
+                                # Use as TorchScript model directly
+                                self.model = checkpoint
+                                self.model.eval()
+                                self.weights_loaded = True
+                                print(f"✅ Loaded {model_type} (TorchScript)")
+                                return
+
+                        # Load state dict
+                        missing, unexpected = base_model.load_state_dict(state_dict, strict=False)
+
+                        loaded_keys = len(state_dict) - len(missing)
+                        total_keys = len(base_model.state_dict())
+                        load_percentage = (loaded_keys / total_keys) * 100
+
+                        print(f"   📊 Loaded {loaded_keys}/{total_keys} parameters ({load_percentage:.1f}%)")
+
+                        if load_percentage > 50:
                             self.weights_loaded = True
-                            print(f"✅ Successfully loaded {model_type}")
                             break
 
                     except Exception as e:
                         print(f"⚠️ Failed to load {model_type}: {str(e)}")
                         continue
-                else:
-                    print(f"❌ Not found: {weight_file}")
 
             if not self.weights_loaded:
                 print("\n❌ WARNING: Using random model weights!")
-                print("❌ All predictions will be random (~50% confidence)")
-                print("❌ Please check your model files in optimized_model_cpu/")
-                print("❌ Expected files:")
-                for file_path, _, _ in weight_files_to_try:
-                    print(f"   - {file_path}")
+                print("❌ All predictions will be random")
+                print(f"❌ Expected model files in: {self.model_dir}/")
            else:
                 print(f"✅ Model ready with trained weights")
 
-            # Set model to evaluation mode
+            # Wrap in inference model with thresholds
+            self.model = OptimizedInferenceModel(
+                base_model,
+                self.config['target_symptoms'],
+                self.config['confidence_thresholds']
+            )
             self.model.eval()
 
-            # Set CPU optimization
-            torch.set_num_threads(self.config['optimization_settings'].get('torch_threads', 4))
+            # CPU optimization
+            torch.set_num_threads(4)
 
         except Exception as e:
             raise RuntimeError(f"Failed to create/load model: {str(e)}")
 
-    def _load_state_dict(self, weight_file: str, model_type: str) -> bool:
-        """Load model from state dict file"""
-        try:
-            checkpoint = torch.load(weight_file, map_location='cpu')
-
-            # Handle different checkpoint formats
-            if isinstance(checkpoint, dict):
-                if 'state_dict' in checkpoint:
-                    state_dict = checkpoint['state_dict']
-                elif 'model_state_dict' in checkpoint:
-                    state_dict = checkpoint['model_state_dict']
-                else:
-                    state_dict = checkpoint
-            else:
-                state_dict = checkpoint
-
-            # Remove any incompatible keys
-            filtered_state_dict = {}
-            for key, value in state_dict.items():
-                # Skip keys that might cause issues
-                if any(skip in key for skip in ['symptom_attention', 'covid_classifier', 'aux_']):
-                    print(f"   Skipping incompatible key: {key}")
-                    continue
-                filtered_state_dict[key] = value
-
-            # Load weights
-            missing_keys, unexpected_keys = self.model.load_state_dict(filtered_state_dict, strict=False)
-
-            # Check if enough weights were loaded
-            loaded_keys = len(filtered_state_dict) - len(missing_keys)
-            total_keys = len(self.model.state_dict())
-            load_percentage = (loaded_keys / total_keys) * 100
-
-            print(f"   📊 Loaded {loaded_keys}/{total_keys} parameters ({load_percentage:.1f}%)")
-
-            if missing_keys:
-                print(f"   ⚠️ Missing keys: {len(missing_keys)} (using random initialization)")
-            if unexpected_keys:
-                print(f"   ⚠️ Unexpected keys: {len(unexpected_keys)} (ignored)")
-
-            # Consider successful if we loaded most parameters
-            return load_percentage > 50
-
-        except Exception as e:
-            print(f"   ❌ State dict loading failed: {str(e)}")
-            return False
-
-    def _load_full_model(self, weight_file: str, model_type: str) -> bool:
-        """Load full model file"""
-        try:
-            loaded_model = torch.load(weight_file, map_location='cpu')
-
-            if hasattr(loaded_model, 'state_dict'):
-                # Extract state dict from full model
-                state_dict = loaded_model.state_dict()
-                return self._load_state_dict_direct(state_dict)
-            else:
-                # Try to use as state dict directly
-                return self._load_state_dict_direct(loaded_model)
-
-        except Exception as e:
-            print(f"   ❌ Full model loading failed: {str(e)}")
-            return False
-
-    def _load_torchscript_model(self, weight_file: str, model_type: str) -> bool:
-        """Load TorchScript model (with known compatibility issues)"""
-        try:
-            scripted_model = torch.jit.load(weight_file, map_location='cpu')
-            scripted_model.eval()
-
-            # Replace the model entirely with TorchScript version
-            self.model = scripted_model
-            print(f"   ✅ Using TorchScript model directly")
-            return True
-
-        except Exception as e:
-            print(f"   ❌ TorchScript loading failed: {str(e)}")
-            return False
-
-    def _load_state_dict_direct(self, state_dict: Dict) -> bool:
-        """Helper to load state dict directly"""
-        try:
-            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
-            loaded_keys = len(state_dict) - len(missing_keys)
-            total_keys = len(self.model.state_dict())
-            load_percentage = (loaded_keys / total_keys) * 100
-
-            print(f"   📊 Loaded {loaded_keys}/{total_keys} parameters ({load_percentage:.1f}%)")
-            return load_percentage > 50
-
-        except Exception as e:
-            print(f"   ❌ Direct state dict loading failed: {str(e)}")
-            return False
-
     def setup_preprocessor(self):
         """Initialize audio preprocessor"""
         self.preprocessor = RespiratoryAudioPreprocessor()
         print("✅ Audio preprocessor initialized")
 
     def predict_symptoms(self, audio_file_path: str) -> Dict[str, Any]:
-        """Predict respiratory symptoms with enhanced threshold logic and health classification"""
+        """Predict respiratory symptoms"""
         try:
             start_time = time.time()
 
@@ -405,215 +324,172 @@ class RespiratoryAnalysisService:
             inference_time = time.time() - inference_start
 
             # Parse outputs
-            probabilities = outputs['probabilities'].squeeze().detach().cpu().numpy().astype(float)
+            probabilities = outputs['probabilities'].squeeze().detach().cpu().numpy()
 
-            # ENHANCED THRESHOLD LOGIC with neutral detection
-            detected_symptoms = []
+            # Convert numpy types to Python types
+            probabilities = probabilities.astype(float).tolist()
+
+            # Detect symptoms
+            detected_symptoms = []
             for i, symptom in enumerate(self.config['target_symptoms']):
                 prob = float(probabilities[i])
-                symptom_threshold = self.config['confidence_thresholds'][symptom]
+                threshold = float(self.config['confidence_thresholds'][symptom])
+                effective_threshold = max(threshold, self.neutral_threshold)
 
-                # Apply dual threshold system:
-                # 1. Must be above symptom-specific threshold
-                # 2. Must be above neutral threshold to avoid false positives
-                effective_threshold = max(symptom_threshold, self.neutral_threshold)
-                is_detected = prob >= effective_threshold
-
-                if is_detected:
+                if prob >= effective_threshold:
                     detected_symptoms.append({
                         'symptom': symptom,
                         'display_name': self.config['symptom_display_names'][symptom],
-                        'confidence': float(prob),
+                        'confidence': prob,
                         'color': self.config['symptom_colors'][symptom],
-                        'threshold_used': float(effective_threshold)
+                        'threshold_used': effective_threshold
                     })
 
-            # DETERMINE OVERALL HEALTH STATUS
-            max_confidence = np.max(probabilities)
+            # Determine health status
+            max_confidence = max(probabilities)
 
             if not detected_symptoms:
                 if max_confidence < self.neutral_threshold:
                     health_status = "healthy"
                     status_message = "No symptoms detected - appears healthy"
                 else:
                     health_status = "inconclusive"
                     status_message = "Some patterns detected but below confidence threshold"
             else:
                 health_status = "symptoms_detected"
                 status_message = f"{len(detected_symptoms)} symptom(s) detected"
 
-            # Format results with enhanced health classification
-            results = self.format_results_enhanced(
-                probabilities, detected_symptoms, health_status, status_message, max_confidence
-            )
-
-            # Add comprehensive processing info
-            results['processing_info'] = {
-                'preprocessing_time_ms': round(preprocessing_time * 1000, 1),
-                'inference_time_ms': round(inference_time * 1000, 1),
-                'total_time_ms': round((preprocessing_time + inference_time) * 1000, 1),
-                'model_weights_loaded': self.weights_loaded,
-                'neutral_threshold': self.neutral_threshold,
-                'max_confidence': round(max_confidence, 3)
+            # Format results
+            results = {
+                'detected_symptoms': detected_symptoms,
+                'all_symptoms': {},
+                'summary': {
+                    'total_detected': len(detected_symptoms),
+                    'highest_confidence': max([s['confidence'] for s in detected_symptoms], default=0.0),
+                    'max_overall_confidence': float(max_confidence),
+                    'status': health_status,
+                    'status_message': status_message,
+                    'neutral_threshold': float(self.neutral_threshold),
+                    'weights_status': 'trained' if self.weights_loaded else 'random'
+                },
+                'recommendations': self._get_recommendations(health_status, detected_symptoms),
+                'health_classification': health_status,
+                'processing_info': {
+                    'preprocessing_time_ms': round(preprocessing_time * 1000, 1),
+                    'inference_time_ms': round(inference_time * 1000, 1),
+                    'total_time_ms': round((preprocessing_time + inference_time) * 1000, 1),
+                    'model_weights_loaded': self.weights_loaded,
+                    'model_version': '3.0_39percent_f1'
+                }
             }
 
+            # Add all symptoms details
+            for i, symptom in enumerate(self.config['target_symptoms']):
+                prob = float(probabilities[i])
+                threshold = float(self.config['confidence_thresholds'][symptom])
+                effective_threshold = max(threshold, self.neutral_threshold)
+
+                results['all_symptoms'][symptom] = {
+                    'display_name': self.config['symptom_display_names'][symptom],
+                    'confidence': prob,
+                    'detected': prob >= effective_threshold,
+                    'original_threshold': threshold,
+                    'effective_threshold': effective_threshold,
+                    'color': self.config['symptom_colors'][symptom]
+                }
+
             return results
 
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
 
-    def format_results_enhanced(self, probabilities, detected_symptoms, health_status, status_message, max_confidence):
-        """Enhanced results formatting with health classification"""
-
-        results = {
-            'detected_symptoms': detected_symptoms,
-            'all_symptoms': {},
-            'summary': {},
-            'recommendations': [],
-            'health_classification': health_status
-        }
-
-        # Process all symptoms with enhanced threshold information
-        for i, symptom in enumerate(self.config['target_symptoms']):
-            prob = float(probabilities[i])  # ✅ Convert to Python float
-            original_threshold = float(self.config['confidence_thresholds'][symptom])  # ✅ Convert
-            effective_threshold = float(max(original_threshold, self.neutral_threshold))
-            detected = prob >= effective_threshold
-
-            results['all_symptoms'][symptom] = {
-                'display_name': self.config['symptom_display_names'][symptom],
-                'confidence': prob,
-                'detected': bool(detected),
-                'original_threshold': original_threshold,
-                'effective_threshold': effective_threshold,
-                'neutral_threshold': float(self.neutral_threshold),
-                'color': self.config['symptom_colors'][symptom]
-            }
-
-        # Enhanced summary with health classification
-        results['summary'] = {
-            'total_detected': int(len(detected_symptoms)),
-            'highest_confidence': float(max([s['confidence'] for s in detected_symptoms], default=0.0)),
-            'max_overall_confidence': float(max_confidence),
-            'status': str(health_status),
-            'status_message': str(status_message),
-            'neutral_threshold': float(self.neutral_threshold),
-            'weights_status': 'trained' if self.weights_loaded else 'random'
-        }
-
-        # ✅ ENHANCED RECOMMENDATIONS based on health status
+    def _get_recommendations(self, health_status, detected_symptoms):
+        """Generate recommendations based on health status"""
+        recommendations = []
+
+        if not self.weights_loaded:
+            recommendations.append("⚠️ DEVELOPMENT MODE: Model using random weights - results not valid")
+
         if health_status == "healthy":
-            results['recommendations'] = [
+            recommendations.extend([
                 "✅ No significant respiratory symptoms detected",
                 "Your cough patterns appear normal and healthy",
                 "Continue maintaining good respiratory health practices",
                 "This screening is for informational purposes only"
-            ]
+            ])
         elif health_status == "inconclusive":
-            results['recommendations'] = [
+            recommendations.extend([
                 "⚠️ Some respiratory patterns detected but below confidence threshold",
                 "Consider monitoring your symptoms over the next few days",
                 "If symptoms persist or worsen, consult a healthcare provider",
                 "This AI screening should not replace professional medical advice"
-            ]
+            ])
         elif len(detected_symptoms) == 1:
             symptom_name = detected_symptoms[0]['display_name']
             confidence = detected_symptoms[0]['confidence']
-            results['recommendations'] = [
+            recommendations.extend([
                 f"🔍 Detected: {symptom_name} (confidence: {confidence:.1%})",
-                "Monitor this symptom and note any changes or progression",
-                "Consider consulting a healthcare provider if symptoms persist or worsen",
+                "Monitor this symptom and note any changes",
+                "Consider consulting a healthcare provider if symptoms persist",
                 "This AI screening should not replace professional medical advice"
-            ]
+            ])
         else:
             symptom_names = [s['display_name'] for s in detected_symptoms]
-            results['recommendations'] = [
+            recommendations.extend([
                 f"🚨 Multiple symptoms detected: {', '.join(symptom_names)}",
                 "Multiple symptoms may indicate a need for medical attention",
-                "Please consult a healthcare provider for proper evaluation and diagnosis",
+                "Please consult a healthcare provider for proper evaluation",
                 "This AI screening should not replace professional medical advice"
-            ]
+            ])
 
-        # Add model status warning if using random weights
-        if not self.weights_loaded:
-            results['recommendations'].insert(0,
-                "⚠️ DEVELOPMENT MODE: Model using random weights - results are not medically valid"
-            )
+        return recommendations
 
-        def convert_numpy_types(obj):
-            """Convert any remaining numpy types to Python types"""
-            if hasattr(obj, 'item'):  # numpy scalars
-                return obj.item()
-            elif isinstance(obj, np.integer):
-                return int(obj)
-            elif isinstance(obj, np.floating):
-                return float(obj)
-            elif isinstance(obj, np.ndarray):
-                return obj.tolist()
-            elif isinstance(obj, dict):
-                return {key: convert_numpy_types(value) for key, value in obj.items()}
-            elif isinstance(obj, list):
-                return [convert_numpy_types(item) for item in obj]
-            elif isinstance(obj, bool):
-                return bool(obj)
-            return obj
-
-        return convert_numpy_types(results)
 
-# Initialize service with enhanced error handling
-print("🚀 Initializing Enhanced Respiratory Analysis Service...")
+# Initialize service
+print("🚀 Initializing Respiratory Analysis Service v3.0...")
 try:
     service = RespiratoryAnalysisService()
     print("✅ Service initialized successfully!")
-    print(f"   Model weights loaded: {'Yes' if service.weights_loaded else 'No (using random weights)'}")
-    print(f"   Neutral threshold: {service.neutral_threshold}")
+    print(f"   Model: 39% F1-Macro (4 symptoms)")
+    print(f"   Weights loaded: {'Yes' if service.weights_loaded else 'No'}")
 except Exception as e:
     print(f"❌ Service initialization failed: {str(e)}")
     service = None
 
+
 # =================== API ROUTES ===================
 
 @app.get("/")
 async def root():
-    """Root endpoint with enhanced API information"""
+    """Root endpoint"""
     if service is None:
-        return {
-            "service": "Respiratory Symptom Analysis API",
-            "version": "2.1.0",
-            "status": "error - service not initialized"
-        }
+        return {"service": "Respiratory Symptom Analysis API", "version": "3.0.0", "status": "error"}
 
     return {
         "service": "Respiratory Symptom Analysis API",
-        "version": "2.1.0",
+        "version": "3.0.0",
+        "model_version": "39% F1-Macro (4 symptoms)",
         "status": "active",
         "model_status": "trained_weights" if service.weights_loaded else "random_weights",
-        "health_classification": ["healthy", "symptoms_detected", "inconclusive"],
-        "neutral_threshold": service.neutral_threshold,
+        "supported_symptoms": service.config['target_symptoms'],
         "endpoints": {
             "analyze": "/analyze",
             "health": "/health",
             "info": "/info",
             "docs": "/docs"
-        },
-        "supported_symptoms": service.config['target_symptoms'],
-        "model_info": {
-            "version": service.config['model_version'],
-            "optimization": "CPU-optimized with health classification"
         }
     }
 
+
 @app.get("/health")
 async def health_check():
-    """Enhanced health check with detailed model status"""
+    """Health check endpoint"""
     model_files_status = {
-        "pytorch_state_dict": Path("optimized_model_cpu/model_pytorch_state_dict.pt").exists(),
-        "quantized_state_dict": Path("optimized_model_cpu/model_quantized_state_dict.pt").exists(),
-        "pytorch_full": Path("optimized_model_cpu/model_pytorch.pt").exists(),
-        "quantized_full": Path("optimized_model_cpu/model_quantized.pt").exists(),
-        "torchscript": Path("optimized_model_cpu/model_torchscript.pt").exists(),
-        "config": Path("optimized_model_cpu/model_config.json").exists()
+        "model_base": (Path("deployment_model") / "model_base.pt").exists(),
+        "model_inference": (Path("deployment_model") / "model_inference.pt").exists(),
+        "model_quantized": (Path("deployment_model") / "model_quantized.pt").exists(),
+        "model_torchscript": (Path("deployment_model") / "model_torchscript.pt").exists(),
+        "config": (Path("deployment_model") / "model_config.json").exists()
     }
 
     return {
@@ -621,93 +497,72 @@ async def health_check():
         "timestamp": time.time(),
         "service_ready": service is not None,
         "model_loaded": service.model is not None if service else False,
-        "config_loaded": service.config is not None if service else False,
         "model_weights_status": "trained" if (service and service.weights_loaded) else "random",
-        "neutral_threshold": service.neutral_threshold if service else None,
-        "health_classification_enabled": True,
         "model_files_available": model_files_status,
-        "files_found": sum(model_files_status.values()),
-        "critical_files_missing": not (model_files_status["config"] and
-            any([model_files_status["pytorch_state_dict"],
-                 model_files_status["quantized_state_dict"],
-                 model_files_status["pytorch_full"]]))
+        "api_version": "3.0.0"
     }
 
+
 @app.get("/info")
 async def get_info():
-    """Get comprehensive model and service information"""
+    """Get model information"""
     if service is None:
         return {"error": "Service not initialized"}
 
     return {
         "model_info": {
-            "version": service.config.get('model_version', '2.1'),
+            "version": "3.0_39percent_f1",
+            "architecture": "LightweightMultiSymptomClassifier (no CBAM)",
            "target_symptoms": service.config['target_symptoms'],
            "symptom_display_names": service.config['symptom_display_names'],
            "confidence_thresholds": service.config['confidence_thresholds'],
            "weights_loaded": service.weights_loaded,
-            "neutral_threshold": service.neutral_threshold,
-            "health_classifications": ["healthy", "symptoms_detected", "inconclusive"]
+            "neutral_threshold": service.neutral_threshold
        },
        "preprocessing_info": service.preprocessor.get_preprocessing_info(),
        "supported_formats": ["wav", "mp3", "flac", "ogg", "m4a", "webm"],
        "max_duration": "30 seconds",
        "max_file_size": "10MB",
-        "api_version": "2.1.0",
-        "features": {
-            "health_classification": True,
-            "neutral_detection": True,
-            "dual_threshold_system": True,
-            "trained_weights": service.weights_loaded
-        }
+        "api_version": "3.0.0"
     }
 
+
 @app.post("/analyze")
 async def analyze_audio(audio_file: UploadFile = File(...)):
     """
-    Enhanced audio analysis with health classification
+    Analyze audio file for respiratory symptoms
 
-    Returns:
-    - Detected symptoms with confidence scores
-    - Health classification (healthy/symptoms_detected/inconclusive)
-    - Enhanced recommendations based on health status
-    - Model weight status for debugging
+    Returns detected symptoms with confidence scores and health classification
     """
     if service is None:
         raise HTTPException(status_code=503, detail="Service not available")
 
-    # Validate file type (including WebM for browser recordings)
-    allowed_types = [
-        'audio/wav', 'audio/mpeg', 'audio/mp3', 'audio/flac',
-        'audio/ogg', 'audio/x-m4a', 'audio/mp4', 'audio/webm'
-    ]
+    # Validate file type
+    allowed_types = ['audio/wav', 'audio/mpeg', 'audio/mp3', 'audio/flac',
+                     'audio/ogg', 'audio/x-m4a', 'audio/mp4', 'audio/webm']
 
     if audio_file.content_type not in allowed_types:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Unsupported format: {audio_file.content_type}. Supported: {', '.join(allowed_types)}"
-        )
+        raise HTTPException(status_code=400,
+                            detail=f"Unsupported format: {audio_file.content_type}")
 
     # Validate file size
     content = await audio_file.read()
-    max_size = 10 * 1024 * 1024  # 10MB
-    if len(content) > max_size:
-        raise HTTPException(status_code=400, detail="File too large. Maximum size: 10MB")
+    if len(content) > 10 * 1024 * 1024:  # 10MB
+        raise HTTPException(status_code=400, detail="File too large. Maximum: 10MB")
 
     try:
         # Save uploaded file temporarily
         file_extension = audio_file.filename.split('.')[-1] if audio_file.filename else 'wav'
         with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_extension}") as temp_file:
             temp_file.write(content)
             temp_file_path = temp_file.name
 
-        # Analyze audio with enhanced health classification
+        # Analyze audio
        results = service.predict_symptoms(temp_file_path)
 
-        # Clean up temporary file
+        # Clean up
        os.unlink(temp_file_path)
 
-        # Return enhanced results
        return JSONResponse(
            status_code=200,
            content={
@@ -718,45 +573,20 @@ async def analyze_audio(audio_file: UploadFile = File(...)):
                "file_size_bytes": len(content),
                "content_type": audio_file.content_type,
                "timestamp": time.time(),
-                "api_version": "2.1.0"
+                "api_version": "3.0.0"
            }
        }
    )

-    except HTTPException:
-        raise
    except Exception as e:
-        # Clean up temporary file if exists
        if 'temp_file_path' in locals():
            try:
                os.unlink(temp_file_path)
            except:
                pass
-
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")

-# Global exception handler
-@app.exception_handler(Exception)
-async def global_exception_handler(request, exc):
-    """Global exception handler with detailed error information"""
-    return JSONResponse(
-        status_code=500,
-        content={
-            "success": False,
-            "error": "Internal server error",
-            "detail": str(exc),
-            "model_status": "trained_weights" if (service and service.weights_loaded) else "random_weights",
-            "timestamp": time.time()
-        }
-    )

 if __name__ == "__main__":
     import uvicorn
-
-    # Run the API server
-    uvicorn.run(
-        "main:app",
-        host="0.0.0.0",
-        port=7860,
-        reload=False
-    )
+    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
 
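Before deploying, the new architecture and its inference wrapper can be smoke-tested offline with a dummy (1, 1, 128, 431) input; the uniform 0.5 thresholds below mirror the fallback config in this diff. Note that importing main also runs the module-level service initialization (it prints warnings if no weight files are present, but the try/except keeps it from raising). A minimal sketch:

# Offline smoke test for the 4-symptom model (dummy input, no weights needed).
import torch
from main import LightweightMultiSymptomClassifier, OptimizedInferenceModel

symptoms = ['fever', 'cold', 'fatigue', 'cough']
base = LightweightMultiSymptomClassifier(num_classes=4, dropout=0.5)
model = OptimizedInferenceModel(base, symptoms, {s: 0.5 for s in symptoms})
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 1, 128, 431))  # dummy mel-spectrogram batch

print(out['probabilities'].shape)  # torch.Size([1, 4])
print(out['predictions'])          # 0/1 per symptom after thresholding

Once the Space is running, the endpoint itself can be exercised with, e.g., curl -X POST -F "audio_file=@cough.wav;type=audio/wav" http://localhost:7860/analyze.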
requirements.txt CHANGED

@@ -1,30 +1,36 @@
 # FastAPI and web server dependencies
-fastapi>=0.104.0
-uvicorn[standard]>=0.24.0
-python-multipart>=0.0.6
+fastapi==0.109.0
+uvicorn[standard]==0.27.0
+python-multipart==0.0.6
 
-# PyTorch ecosystem (matching your training environment)
-torch>=1.13.0,<2.0.0
-torchvision>=0.14.0
-torchaudio>=0.13.0
+# PyTorch ecosystem - CPU-only for HuggingFace Spaces
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.1.0+cpu
+torchvision==0.16.0+cpu
+torchaudio==2.1.0+cpu
 
-# Audio processing (with all dependencies)
-librosa>=0.9.2
-soundfile>=0.12.1
-resampy>=0.4.2
-audioread>=3.0.0
+# Audio processing core libraries
+librosa==0.10.1
+soundfile==0.12.1
+audioread==3.0.1
 
 # Core scientific computing
-numpy>=1.21.0,<2.0.0
-scipy>=1.7.0
-numba>=0.56.0
+numpy==1.24.3
+scipy==1.11.4
+numba==0.58.1
 
-# Additional audio processing dependencies
-llvmlite>=0.39.0
-pooch>=1.6.0
+# Audio processing dependencies
+llvmlite==0.41.1
+pooch==1.8.0
+joblib==1.3.2
+decorator==5.1.1
+lazy-loader==0.3
+msgpack==1.0.7
+cffi==1.16.0
+pycparser==2.21
 
 # Data handling
-pandas>=1.3.0
+typing-extensions==4.9.0
 
 # System utilities
-packaging>=21.0
+packaging==23.2
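Since --extra-index-url is a pip option embedded in the file, the pinned +cpu wheels only resolve when the file is installed as a whole (HuggingFace Spaces does this automatically at build time); for a local environment the equivalent is:

pip install -r requirements.txt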