Kalpokoch commited on
Commit
00826e1
·
1 Parent(s): 4964982

added main.py

Browse files
Files changed (2) hide show
  1. audio_preprocessing.py +230 -0
  2. main.py +413 -0
audio_preprocessing.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio Preprocessing Module for Respiratory Analysis
3
+ Matches the exact preprocessing used during training
4
+ """
5
+
6
+ import librosa
7
+ import numpy as np
8
+ import torch
9
+ import warnings
10
+ from typing import Union, Tuple, Dict
11
+ import soundfile as sf
12
+
13
+ warnings.filterwarnings('ignore')
14
+
15
+ class RespiratoryAudioPreprocessor:
16
+ """
17
+ Audio preprocessor that matches training pipeline exactly
18
+ Converts raw audio files to mel-spectrograms for model inference
19
+ """
20
+
21
+ def __init__(self,
22
+ target_sr: int = 22050,
23
+ n_mels: int = 128,
24
+ n_fft: int = 2048,
25
+ hop_length: int = 512,
26
+ win_length: int = None,
27
+ window: str = 'hann',
28
+ fmin: float = 0.0,
29
+ fmax: float = None,
30
+ power: float = 2.0,
31
+ duration: float = 3.0): # 3 seconds max duration
32
+ """
33
+ Initialize preprocessing parameters to match training
34
+ """
35
+ self.target_sr = target_sr
36
+ self.n_mels = n_mels
37
+ self.n_fft = n_fft
38
+ self.hop_length = hop_length
39
+ self.win_length = win_length
40
+ self.window = window
41
+ self.fmin = fmin
42
+ self.fmax = fmax or target_sr // 2
43
+ self.power = power
44
+ self.duration = duration
45
+ self.target_length = int(target_sr * duration) # 3 seconds worth of samples
46
+
47
+ # Expected output shape for your model
48
+ self.expected_shape = (1, 1, 128, 251) # (batch, channels, height, width)
49
+
50
+ def load_and_normalize_audio(self, audio_input: Union[str, np.ndarray, tuple]) -> np.ndarray:
51
+ """
52
+ Load audio file and normalize
53
+ """
54
+ try:
55
+ # Handle different input types
56
+ if isinstance(audio_input, str):
57
+ # File path
58
+ audio_data, sr = librosa.load(audio_input, sr=self.target_sr, duration=self.duration)
59
+ elif isinstance(audio_input, tuple):
60
+ # (sample_rate, audio_array) from Gradio
61
+ sr, audio_data = audio_input
62
+
63
+ # Convert to float if needed
64
+ if audio_data.dtype != np.float32:
65
+ if audio_data.dtype == np.int16:
66
+ audio_data = audio_data.astype(np.float32) / 32767.0
67
+ elif audio_data.dtype == np.int32:
68
+ audio_data = audio_data.astype(np.float32) / 2147483647.0
69
+ else:
70
+ audio_data = audio_data.astype(np.float32)
71
+
72
+ # Resample if needed
73
+ if sr != self.target_sr:
74
+ audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=self.target_sr)
75
+
76
+ # Trim to duration
77
+ if len(audio_data) > self.target_length:
78
+ audio_data = audio_data[:self.target_length]
79
+
80
+ elif isinstance(audio_input, np.ndarray):
81
+ # Raw audio array
82
+ audio_data = audio_input
83
+ if len(audio_data) > self.target_length:
84
+ audio_data = audio_data[:self.target_length]
85
+ else:
86
+ raise ValueError(f"Unsupported audio input type: {type(audio_input)}")
87
+
88
+ # Pad if too short
89
+ if len(audio_data) < self.target_length:
90
+ audio_data = np.pad(audio_data, (0, self.target_length - len(audio_data)),
91
+ mode='constant', constant_values=0)
92
+
93
+ return audio_data
94
+
95
+ except Exception as e:
96
+ raise RuntimeError(f"Failed to load audio: {str(e)}")
97
+
98
+ def extract_mel_spectrogram(self, audio_data: np.ndarray) -> np.ndarray:
99
+ """
100
+ Extract mel spectrogram features matching training preprocessing
101
+ """
102
+ try:
103
+ # Extract mel spectrogram
104
+ mel_spec = librosa.feature.melspectrogram(
105
+ y=audio_data,
106
+ sr=self.target_sr,
107
+ n_mels=self.n_mels,
108
+ n_fft=self.n_fft,
109
+ hop_length=self.hop_length,
110
+ win_length=self.win_length,
111
+ window=self.window,
112
+ fmin=self.fmin,
113
+ fmax=self.fmax,
114
+ power=self.power
115
+ )
116
+
117
+ # Convert to log scale (dB)
118
+ mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
119
+
120
+ return mel_spec_db
121
+
122
+ except Exception as e:
123
+ raise RuntimeError(f"Failed to extract mel spectrogram: {str(e)}")
124
+
125
+ def normalize_spectrogram(self, mel_spec: np.ndarray) -> np.ndarray:
126
+ """
127
+ Normalize mel spectrogram to match training normalization
128
+ This matches the normalization used in your training pipeline
129
+ """
130
+ # Mean and std normalization
131
+ mean = np.mean(mel_spec)
132
+ std = np.std(mel_spec)
133
+
134
+ if std == 0:
135
+ normalized = mel_spec - mean
136
+ else:
137
+ normalized = (mel_spec - mean) / std
138
+
139
+ # Clamp values to reasonable range (matching training)
140
+ normalized = np.clip(normalized, -5.0, 5.0)
141
+
142
+ return normalized
143
+
144
+ def resize_spectrogram(self, mel_spec: np.ndarray, target_width: int = 251) -> np.ndarray:
145
+ """
146
+ Resize spectrogram to target dimensions
147
+ """
148
+ current_height, current_width = mel_spec.shape
149
+
150
+ if current_width == target_width:
151
+ return mel_spec
152
+
153
+ # Use librosa's time stretching for width adjustment
154
+ if current_width < target_width:
155
+ # Pad if too narrow
156
+ pad_width = target_width - current_width
157
+ mel_spec = np.pad(mel_spec, ((0, 0), (0, pad_width)), mode='constant', constant_values=0)
158
+ else:
159
+ # Truncate if too wide
160
+ mel_spec = mel_spec[:, :target_width]
161
+
162
+ return mel_spec
163
+
164
+ def preprocess_audio(self, audio_input: Union[str, np.ndarray, tuple]) -> torch.Tensor:
165
+ """
166
+ Complete preprocessing pipeline from audio to model input tensor
167
+ """
168
+ try:
169
+ # Step 1: Load and normalize audio
170
+ audio_data = self.load_and_normalize_audio(audio_input)
171
+
172
+ # Step 2: Extract mel spectrogram
173
+ mel_spec = self.extract_mel_spectrogram(audio_data)
174
+
175
+ # Step 3: Normalize spectrogram
176
+ mel_spec_norm = self.normalize_spectrogram(mel_spec)
177
+
178
+ # Step 4: Resize to target dimensions
179
+ mel_spec_resized = self.resize_spectrogram(mel_spec_norm)
180
+
181
+ # Step 5: Convert to tensor and add batch + channel dimensions
182
+ tensor_input = torch.FloatTensor(mel_spec_resized)
183
+ tensor_input = tensor_input.unsqueeze(0).unsqueeze(0) # Add batch and channel dims
184
+
185
+ # Verify output shape
186
+ if tensor_input.shape != self.expected_shape:
187
+ raise RuntimeError(f"Output shape {tensor_input.shape} doesn't match expected {self.expected_shape}")
188
+
189
+ return tensor_input
190
+
191
+ except Exception as e:
192
+ raise RuntimeError(f"Preprocessing failed: {str(e)}")
193
+
194
+ def get_preprocessing_info(self) -> Dict:
195
+ """
196
+ Get preprocessing configuration info
197
+ """
198
+ return {
199
+ 'target_sr': self.target_sr,
200
+ 'n_mels': self.n_mels,
201
+ 'n_fft': self.n_fft,
202
+ 'hop_length': self.hop_length,
203
+ 'duration': self.duration,
204
+ 'output_shape': self.expected_shape
205
+ }
206
+
207
+ # Example usage and testing
208
+ if __name__ == "__main__":
209
+ # Initialize preprocessor
210
+ preprocessor = RespiratoryAudioPreprocessor()
211
+
212
+ # Test with dummy audio data
213
+ dummy_audio = np.random.randn(22050 * 2) # 2 seconds of audio
214
+
215
+ try:
216
+ # Preprocess
217
+ tensor_output = preprocessor.preprocess_audio(dummy_audio)
218
+ print(f"✅ Preprocessing successful!")
219
+ print(f"Output shape: {tensor_output.shape}")
220
+ print(f"Output dtype: {tensor_output.dtype}")
221
+ print(f"Output range: [{tensor_output.min():.3f}, {tensor_output.max():.3f}]")
222
+
223
+ # Display preprocessing info
224
+ info = preprocessor.get_preprocessing_info()
225
+ print("\n📋 Preprocessing Configuration:")
226
+ for key, value in info.items():
227
+ print(f" {key}: {value}")
228
+
229
+ except Exception as e:
230
+ print(f"❌ Preprocessing failed: {e}")
main.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI Backend for Respiratory Symptom Analysis
3
+ Deployed on HuggingFace Spaces for use with Netlify frontend
4
+ Updated for optimized_model_cpu folder structure
5
+ """
6
+
7
+ from fastapi import FastAPI, File, UploadFile, HTTPException
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from fastapi.responses import JSONResponse
10
+ import torch
11
+ import json
12
+ import numpy as np
13
+ import tempfile
14
+ import os
15
+ from pathlib import Path
16
+ from typing import Dict, List, Any
17
+ import time
18
+ import warnings
19
+
20
+ # Import your preprocessing module
21
+ from audio_preprocessing import RespiratoryAudioPreprocessor
22
+
23
+ warnings.filterwarnings('ignore')
24
+
25
+ # Initialize FastAPI app
26
+ app = FastAPI(
27
+ title="🫁 Respiratory Symptom Analysis API",
28
+ description="AI-powered respiratory symptom detection from cough audio",
29
+ version="2.0.0",
30
+ docs_url="/docs",
31
+ redoc_url="/redoc"
32
+ )
33
+
34
+ # Add CORS middleware for Netlify frontend
35
+ app.add_middleware(
36
+ CORSMiddleware,
37
+ allow_origins=["*"], # Configure this for your Netlify domain in production
38
+ allow_credentials=True,
39
+ allow_methods=["*"],
40
+ allow_headers=["*"],
41
+ )
42
+
43
+ class RespiratoryAnalysisService:
44
+ """
45
+ Service class for respiratory symptom analysis
46
+ """
47
+
48
+ def __init__(self,
49
+ model_path: str = "optimized_model_cpu/model_torchscript.pt",
50
+ config_path: str = "optimized_model_cpu/model_config.json"):
51
+ """Initialize the service with model and configuration"""
52
+ self.model_path = model_path
53
+ self.config_path = config_path
54
+ self.model = None
55
+ self.config = None
56
+ self.preprocessor = None
57
+
58
+ # Load model and configuration
59
+ self.load_model_and_config()
60
+ self.setup_preprocessor()
61
+
62
+ def load_model_and_config(self):
63
+ """Load the optimized model and configuration with fallback options"""
64
+ try:
65
+ # Load configuration
66
+ if Path(self.config_path).exists():
67
+ with open(self.config_path, 'r') as f:
68
+ self.config = json.load(f)
69
+ print(f"✅ Configuration loaded from {self.config_path}")
70
+ else:
71
+ raise FileNotFoundError(f"Config file not found: {self.config_path}")
72
+
73
+ # Try loading models in priority order
74
+ model_files_to_try = [
75
+ ("optimized_model_cpu/model_torchscript.pt", "TorchScript"),
76
+ ("optimized_model_cpu/model_quantized.pt", "Quantized PyTorch"),
77
+ ("optimized_model_cpu/model_pytorch.pt", "Regular PyTorch")
78
+ ]
79
+
80
+ model_loaded = False
81
+ for model_file, model_type in model_files_to_try:
82
+ if Path(model_file).exists():
83
+ try:
84
+ if "torchscript" in model_file.lower():
85
+ # Load TorchScript model
86
+ self.model = torch.jit.load(model_file, map_location='cpu')
87
+ print(f"✅ {model_type} model loaded from {model_file}")
88
+ else:
89
+ # Load regular PyTorch model
90
+ self.model = torch.load(model_file, map_location='cpu')
91
+ print(f"✅ {model_type} model loaded from {model_file}")
92
+
93
+ self.model.eval()
94
+ model_loaded = True
95
+ break
96
+
97
+ except Exception as e:
98
+ print(f"⚠️ Failed to load {model_type} model: {str(e)}")
99
+ continue
100
+ else:
101
+ print(f"⚠️ Model file not found: {model_file}")
102
+
103
+ if not model_loaded:
104
+ raise RuntimeError("Failed to load any model file")
105
+
106
+ # Set CPU optimization
107
+ if 'optimization_settings' in self.config:
108
+ torch.set_num_threads(self.config['optimization_settings'].get('torch_threads', 4))
109
+ else:
110
+ torch.set_num_threads(4) # Default
111
+
112
+ except Exception as e:
113
+ raise RuntimeError(f"Failed to load model/config: {str(e)}")
114
+
115
+ def setup_preprocessor(self):
116
+ """Initialize audio preprocessor"""
117
+ self.preprocessor = RespiratoryAudioPreprocessor()
118
+ print("✅ Audio preprocessor initialized")
119
+
120
+ def predict_symptoms(self, audio_file_path: str) -> Dict[str, Any]:
121
+ """Predict respiratory symptoms from audio file"""
122
+ try:
123
+ start_time = time.time()
124
+
125
+ # Preprocess audio
126
+ tensor_input = self.preprocessor.preprocess_audio(audio_file_path)
127
+ preprocessing_time = time.time() - start_time
128
+
129
+ # Run inference
130
+ inference_start = time.time()
131
+ with torch.no_grad():
132
+ outputs = self.model(tensor_input)
133
+
134
+ inference_time = time.time() - inference_start
135
+
136
+ # Parse outputs based on model type
137
+ if isinstance(outputs, dict):
138
+ # New model format with dictionary output
139
+ probabilities = outputs['probabilities'].squeeze().numpy()
140
+ predictions = outputs['predictions'].squeeze().numpy()
141
+ else:
142
+ # Handle legacy model formats
143
+ if isinstance(outputs, tuple):
144
+ logits = outputs[0].squeeze() # Take first output (symptom logits)
145
+ else:
146
+ logits = outputs.squeeze()
147
+
148
+ probabilities = torch.sigmoid(logits).numpy()
149
+
150
+ # Apply thresholds
151
+ threshold_tensor = torch.tensor([
152
+ self.config['confidence_thresholds'][symptom]
153
+ for symptom in self.config['target_symptoms']
154
+ ])
155
+ predictions = (torch.sigmoid(logits) >= threshold_tensor).float().numpy()
156
+
157
+ # Format results
158
+ results = self.format_results(probabilities, predictions)
159
+
160
+ # Add timing and model info
161
+ results['processing_info'] = {
162
+ 'preprocessing_time_ms': round(preprocessing_time * 1000, 1),
163
+ 'inference_time_ms': round(inference_time * 1000, 1),
164
+ 'total_time_ms': round((preprocessing_time + inference_time) * 1000, 1),
165
+ 'model_path': self.model_path,
166
+ 'model_type': type(self.model).__name__
167
+ }
168
+
169
+ return results
170
+
171
+ except Exception as e:
172
+ raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
173
+
174
+ def format_results(self, probabilities: np.ndarray, predictions: np.ndarray) -> Dict[str, Any]:
175
+ """Format prediction results for API response"""
176
+ results = {
177
+ 'detected_symptoms': [],
178
+ 'all_symptoms': {},
179
+ 'summary': {},
180
+ 'recommendations': []
181
+ }
182
+
183
+ # Process each symptom
184
+ for i, symptom in enumerate(self.config['target_symptoms']):
185
+ prob = float(probabilities[i])
186
+ pred = bool(predictions[i])
187
+ display_name = self.config['symptom_display_names'][symptom]
188
+ threshold = self.config['confidence_thresholds'][symptom]
189
+
190
+ # All symptoms with details
191
+ results['all_symptoms'][symptom] = {
192
+ 'display_name': display_name,
193
+ 'confidence': prob,
194
+ 'detected': pred,
195
+ 'threshold': threshold,
196
+ 'color': self.config['symptom_colors'][symptom]
197
+ }
198
+
199
+ # Detected symptoms only
200
+ if pred:
201
+ results['detected_symptoms'].append({
202
+ 'symptom': symptom,
203
+ 'display_name': display_name,
204
+ 'confidence': prob,
205
+ 'color': self.config['symptom_colors'][symptom]
206
+ })
207
+
208
+ # Sort detected symptoms by confidence
209
+ results['detected_symptoms'].sort(key=lambda x: x['confidence'], reverse=True)
210
+
211
+ # Generate summary
212
+ results['summary'] = {
213
+ 'total_detected': len(results['detected_symptoms']),
214
+ 'highest_confidence': results['detected_symptoms'][0]['confidence'] if results['detected_symptoms'] else 0.0,
215
+ 'status': 'symptoms_detected' if results['detected_symptoms'] else 'no_symptoms'
216
+ }
217
+
218
+ # Generate recommendations
219
+ if len(results['detected_symptoms']) == 0:
220
+ results['recommendations'] = [
221
+ "No significant respiratory symptoms detected.",
222
+ "Continue monitoring your health.",
223
+ "This screening is for informational purposes only."
224
+ ]
225
+ elif len(results['detected_symptoms']) == 1:
226
+ symptom_name = results['detected_symptoms'][0]['display_name']
227
+ results['recommendations'] = [
228
+ f"Detected: {symptom_name}",
229
+ "Consider monitoring symptoms and consult healthcare provider if symptoms persist.",
230
+ "This AI screening should not replace professional medical advice."
231
+ ]
232
+ else:
233
+ symptom_names = [s['display_name'] for s in results['detected_symptoms']]
234
+ results['recommendations'] = [
235
+ f"Multiple symptoms detected: {', '.join(symptom_names)}",
236
+ "Please consult a healthcare provider for proper evaluation.",
237
+ "This AI screening should not replace professional medical advice."
238
+ ]
239
+
240
+ return results
241
+
242
+ # Initialize service with error handling
243
+ print("🚀 Initializing Respiratory Analysis Service...")
244
+ try:
245
+ service = RespiratoryAnalysisService()
246
+ print("✅ Service initialized successfully!")
247
+ except Exception as e:
248
+ print(f"❌ Service initialization failed: {str(e)}")
249
+ # Create a dummy service for debugging
250
+ service = None
251
+
252
+ # API Routes
253
+ @app.get("/")
254
+ async def root():
255
+ """Root endpoint with API information"""
256
+ if service is None:
257
+ return {
258
+ "service": "Respiratory Symptom Analysis API",
259
+ "version": "2.0.0",
260
+ "status": "error - service not initialized",
261
+ "error": "Model loading failed"
262
+ }
263
+
264
+ return {
265
+ "service": "Respiratory Symptom Analysis API",
266
+ "version": "2.0.0",
267
+ "status": "active",
268
+ "endpoints": {
269
+ "analyze": "/analyze",
270
+ "health": "/health",
271
+ "info": "/info",
272
+ "docs": "/docs"
273
+ },
274
+ "supported_symptoms": list(service.config['target_symptoms']),
275
+ "model_info": {
276
+ "version": service.config['model_version'],
277
+ "optimization": "CPU-optimized with quantization",
278
+ "model_path": service.model_path
279
+ }
280
+ }
281
+
282
+ @app.get("/health")
283
+ async def health_check():
284
+ """Health check endpoint"""
285
+ return {
286
+ "status": "healthy" if service is not None else "unhealthy",
287
+ "timestamp": time.time(),
288
+ "model_loaded": service.model is not None if service else False,
289
+ "config_loaded": service.config is not None if service else False,
290
+ "model_files_available": {
291
+ "torchscript": Path("optimized_model_cpu/model_torchscript.pt").exists(),
292
+ "quantized": Path("optimized_model_cpu/model_quantized.pt").exists(),
293
+ "pytorch": Path("optimized_model_cpu/model_pytorch.pt").exists(),
294
+ "config": Path("optimized_model_cpu/model_config.json").exists()
295
+ }
296
+ }
297
+
298
+ @app.get("/info")
299
+ async def get_info():
300
+ """Get model and service information"""
301
+ if service is None:
302
+ return {"error": "Service not initialized"}
303
+
304
+ return {
305
+ "model_info": {
306
+ "version": service.config.get('model_version', '2.0'),
307
+ "target_symptoms": service.config['target_symptoms'],
308
+ "symptom_display_names": service.config['symptom_display_names'],
309
+ "confidence_thresholds": service.config['confidence_thresholds'],
310
+ "optimization_settings": service.config.get('optimization_settings', {})
311
+ },
312
+ "preprocessing_info": service.preprocessor.get_preprocessing_info(),
313
+ "supported_formats": ["wav", "mp3", "flac", "ogg", "m4a"],
314
+ "max_duration": "30 seconds",
315
+ "api_version": "2.0.0",
316
+ "model_path": service.model_path
317
+ }
318
+
319
+ @app.post("/analyze")
320
+ async def analyze_audio(audio_file: UploadFile = File(...)):
321
+ """
322
+ Analyze audio file for respiratory symptoms
323
+
324
+ Parameters:
325
+ - audio_file: Audio file (WAV, MP3, FLAC, etc.)
326
+
327
+ Returns:
328
+ - JSON response with symptom predictions and confidence scores
329
+ """
330
+ if service is None:
331
+ raise HTTPException(status_code=503, detail="Service not available - model loading failed")
332
+
333
+ # Validate file type
334
+ allowed_types = ['audio/wav', 'audio/mpeg', 'audio/flac', 'audio/ogg', 'audio/x-m4a', 'audio/mp4']
335
+ if audio_file.content_type not in allowed_types:
336
+ raise HTTPException(
337
+ status_code=400,
338
+ detail=f"Unsupported audio format: {audio_file.content_type}. Supported: {allowed_types}"
339
+ )
340
+
341
+ # Validate file size (max 10MB)
342
+ max_size = 10 * 1024 * 1024 # 10MB
343
+ content = await audio_file.read()
344
+ if len(content) > max_size:
345
+ raise HTTPException(
346
+ status_code=400,
347
+ detail="Audio file too large. Maximum size: 10MB"
348
+ )
349
+
350
+ try:
351
+ # Save uploaded file temporarily
352
+ file_extension = audio_file.filename.split('.')[-1] if audio_file.filename else 'wav'
353
+ with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_extension}") as temp_file:
354
+ temp_file.write(content)
355
+ temp_file_path = temp_file.name
356
+
357
+ # Analyze audio
358
+ results = service.predict_symptoms(temp_file_path)
359
+
360
+ # Clean up temporary file
361
+ os.unlink(temp_file_path)
362
+
363
+ # Return results
364
+ return JSONResponse(
365
+ status_code=200,
366
+ content={
367
+ "success": True,
368
+ "data": results,
369
+ "metadata": {
370
+ "filename": audio_file.filename,
371
+ "file_size_bytes": len(content),
372
+ "content_type": audio_file.content_type,
373
+ "timestamp": time.time()
374
+ }
375
+ }
376
+ )
377
+
378
+ except HTTPException:
379
+ # Re-raise HTTP exceptions
380
+ raise
381
+ except Exception as e:
382
+ # Clean up temporary file if it exists
383
+ if 'temp_file_path' in locals():
384
+ try:
385
+ os.unlink(temp_file_path)
386
+ except:
387
+ pass
388
+
389
+ raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
390
+
391
+ # Error handler
392
+ @app.exception_handler(Exception)
393
+ async def global_exception_handler(request, exc):
394
+ """Global exception handler"""
395
+ return JSONResponse(
396
+ status_code=500,
397
+ content={
398
+ "success": False,
399
+ "error": "Internal server error",
400
+ "detail": str(exc) if app.debug else "An unexpected error occurred"
401
+ }
402
+ )
403
+
404
+ if __name__ == "__main__":
405
+ import uvicorn
406
+
407
+ # Run the API server
408
+ uvicorn.run(
409
+ "main:app",
410
+ host="0.0.0.0",
411
+ port=7860, # HuggingFace Spaces default port
412
+ reload=False
413
+ )