trixy194t committed on
Commit
96339a8
·
verified ·
1 Parent(s): d00d112

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -61
app.py CHANGED
@@ -9,6 +9,8 @@ import torch.nn as nn
9
  from torchvision import models
10
  from fastapi import FastAPI, UploadFile, File, HTTPException
11
  from fastapi.middleware.cors import CORSMiddleware
 
 
12
 
13
  app = FastAPI()
14
 
@@ -27,19 +29,21 @@ def load_model():
27
  model = models.efficientnet_b0(weights=None)
28
  num_ftrs = model.classifier[1].in_features
29
  model.classifier[1] = nn.Linear(num_ftrs, 1)
 
30
  try:
31
  state_dict = torch.load(MODEL_PATH, map_location=device)
32
  model.load_state_dict(state_dict, strict=False)
33
- print("✅ PyTorch EfficientNet Loaded Successfully")
34
  except Exception as e:
35
  print(f"❌ Load Error: {e}")
 
36
  model.to(device)
37
  model.eval()
38
  return model
39
 
40
  model = load_model()
41
 
42
- # --- LOGIC FUNCTIONS ---
43
 
44
  def clean_audio_stream(y, sr=16000):
45
  y_denoised = nr.reduce_noise(y=y, sr=sr)
@@ -47,59 +51,185 @@ def clean_audio_stream(y, sr=16000):
47
  y_filtered = signal.filtfilt(b, a, y_denoised)
48
  return y_filtered
49
 
50
def detect_snoring_sliding_window(y_segment, sr):
    """
    Scan a voiced segment with a sliding window and classify each window
    with the module-level EfficientNet ``model``.

    Parameters
    ----------
    y_segment : 1-D audio samples for one voiced interval.
    sr : sample rate of ``y_segment`` in Hz.

    Returns
    -------
    (bool, float)
        ``found_snore`` — True if any window's sigmoid confidence exceeded
        THRESHOLD — plus the best confidence seen, rounded to 2 decimals.
    """
    WINDOW_SIZE = 3.0 # window length in seconds (note: 3 s, not 1 s as an older comment claimed)
    STEP_SIZE = 0.25 # 0.25 second steps for high resolution
    THRESHOLD = 0.62 # Strict sigmoid cut-off for calling a window a snore

    samples_window = int(WINDOW_SIZE * sr)
    samples_step = int(STEP_SIZE * sr)

    # Segment shorter than a single window: nothing to classify.
    if len(y_segment) < samples_window:
        return False, 0.0

    best_conf = 0.0
    found_snore = False

    # Sliding through the segment
    for i in range(0, len(y_segment) - samples_window, samples_step):
        chunk = y_segment[i : i + samples_window]

        # RMS Gate: skip near-silent windows to save inference time.
        if np.sqrt(np.mean(chunk**2)) < 0.002:
            continue

        # Pre-process: pad/trim to 16000 samples (1 s at 16 kHz).
        # NOTE(review): this truncates the 3 s window to 1 s — confirm intended.
        y_fixed = librosa.util.fix_length(chunk, size=16000)
        S = librosa.feature.melspectrogram(y=y_fixed, sr=16000, n_mels=128)
        S_db = librosa.power_to_db(S, ref=np.max)
        # Min-max normalise; epsilon guards against division by zero on flat input.
        S_norm = (S_db - S_db.min()) / (S_db.max() - S_db.min() + 1e-6)

        # Replicate the single-channel spectrogram to 3 channels for EfficientNet.
        input_tensor = torch.tensor(S_norm).float().unsqueeze(0).unsqueeze(0)
        input_tensor = input_tensor.repeat(1, 3, 1, 1).to(device)

        with torch.no_grad():
            output = model(input_tensor)
            conf = torch.sigmoid(output).item()

        if conf > best_conf:
            best_conf = conf

        if conf > THRESHOLD:
            found_snore = True

    return found_snore, round(best_conf, 2)
 
 
 
 
 
96
 
97
  def validate_sleep_recording(y, sr):
98
  duration = len(y) / sr
99
- if duration < 20: return False, "Audio too short"
100
  if np.sqrt(np.mean(y**2)) < 0.001: return False, "Audio is blank"
101
  return True, "Valid"
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  # --- API ENDPOINTS ---
104
 
105
  @app.post("/analyze")
@@ -116,50 +246,63 @@ async def analyze_audio(file: UploadFile = File(...)):
116
  return {"valid_recording": False, "reason": reason}
117
 
118
  y_clean = clean_audio_stream(y_orig, sr)
 
 
 
 
 
 
119
  intervals = librosa.effects.split(y_clean, top_db=25)
120
-
121
  annotations = []
122
  prev_end = 0
123
- snore_count = 0
124
  apnea_count = 0
125
 
126
  for start, end in intervals:
127
- # --- APNEA LOGIC ---
128
  gap_dur = (start - prev_end) / sr
129
  if 10.0 <= gap_dur <= 120.0:
130
  apnea_count += 1
131
- risk = "LOW" if gap_dur < 15.0 else ("MEDIUM" if gap_dur < 20.0 else "HIGH")
 
 
 
 
 
 
 
 
132
  annotations.append({
133
  "label": "APNEA",
134
  "start_sec": round(prev_end/sr, 2),
135
  "end_sec": round(start/sr, 2),
136
  "duration": round(gap_dur, 2),
137
- "risk_level": risk
138
- })
139
-
140
- # --- SNORING LOGIC (Using Sliding Window) ---
141
- seg = y_orig[start:end]
142
- is_snore, conf = detect_snoring_sliding_window(seg, sr)
143
- if is_snore:
144
- snore_count += 1
145
- annotations.append({
146
- "label": "SNORING",
147
- "start_sec": round(start/sr, 2),
148
- "end_sec": round(end/sr, 2),
149
- "duration": round((end-start)/sr, 2),
150
- "confidence": conf
151
  })
 
152
  prev_end = end
 
 
 
 
 
 
 
 
 
 
153
 
154
- # Stats logic
155
  duration_hours = (len(y_orig) / sr) / 3600
156
  ahi = apnea_count / duration_hours if duration_hours > 0 else 0
157
-
 
158
  overall_risk = ""
159
  if ahi >= 20: overall_risk = "HIGH"
160
  elif ahi >= 15: overall_risk = "MEDIUM"
161
  elif ahi >= 10: overall_risk = "LOW"
162
 
 
163
  return {
164
  "valid_recording": True,
165
  "snore_count": snore_count,
 
9
  from torchvision import models
10
  from fastapi import FastAPI, UploadFile, File, HTTPException
11
  from fastapi.middleware.cors import CORSMiddleware
12
+ from scipy.ndimage import gaussian_filter1d
13
+ from scipy.signal import find_peaks
14
 
15
  app = FastAPI()
16
 
 
29
  model = models.efficientnet_b0(weights=None)
30
  num_ftrs = model.classifier[1].in_features
31
  model.classifier[1] = nn.Linear(num_ftrs, 1)
32
+
33
  try:
34
  state_dict = torch.load(MODEL_PATH, map_location=device)
35
  model.load_state_dict(state_dict, strict=False)
36
+ print("✅ PyTorch EfficientNet Loaded")
37
  except Exception as e:
38
  print(f"❌ Load Error: {e}")
39
+
40
  model.to(device)
41
  model.eval()
42
  return model
43
 
44
  model = load_model()
45
 
46
+ # --- ORIGINAL LOGIC FUNCTIONS ---
47
 
48
  def clean_audio_stream(y, sr=16000):
49
  y_denoised = nr.reduce_noise(y=y, sr=sr)
 
51
  y_filtered = signal.filtfilt(b, a, y_denoised)
52
  return y_filtered
53
 
54
def is_snoring_sound_pytorch(y_segment, sr):
    """Classify one audio segment as snore / not-snore with the EfficientNet model.

    Returns a ``(is_snore, confidence)`` tuple. Any failure during
    preprocessing or inference is treated as "not a snore" with confidence 0.0
    so a single bad segment cannot abort the whole analysis.
    """
    try:
        # Loudness gate: segments quieter than this RMS are rejected without
        # running the model (threshold kept low so quiet snores still pass).
        if np.sqrt(np.mean(y_segment**2)) < 0.002:
            return False, 0.0

        # Bring the segment to the model's expected rate/length: 1 s @ 16 kHz.
        if sr != 16000:
            y_segment = librosa.resample(y_segment, orig_sr=sr, target_sr=16000)
        clip = librosa.util.fix_length(y_segment, size=16000)

        # Log-mel spectrogram, min-max normalised to [0, 1].
        # The 1e-6 epsilon prevents division by zero on a flat spectrogram.
        mel = librosa.feature.melspectrogram(y=clip, sr=16000, n_mels=128)
        mel_db = librosa.power_to_db(mel, ref=np.max)
        mel_norm = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)

        # (H, W) -> (1, 3, H, W): EfficientNet expects a 3-channel batch.
        batch = torch.tensor(mel_norm).float().unsqueeze(0).unsqueeze(0)
        batch = batch.repeat(1, 3, 1, 1)

        with torch.no_grad():
            confidence = torch.sigmoid(model(batch.to(device))).item()

        # 0.5 sigmoid cut-off: deliberately permissive compared to the old 0.7.
        return confidence > 0.5, round(confidence, 2)
    except Exception as e:
        print(f"Inference error: {e}")
        return False, 0.0
89
 
90
  def validate_sleep_recording(y, sr):
91
  duration = len(y) / sr
92
+ if duration < 20: return False, "Audio too short (< 20s)"
93
  if np.sqrt(np.mean(y**2)) < 0.001: return False, "Audio is blank"
94
  return True, "Valid"
95
 
96
+ # --- NEW ACCURATE SNORE DETECTION FUNCTIONS ---
97
+
98
def segment_audio(audio, sr, segment_duration=1.5, overlap=0.67):
    """Split ``audio`` into fixed-length, overlapping segments.

    Parameters
    ----------
    audio : 1-D sample sequence (array or list).
    sr : sample rate in Hz.
    segment_duration : length of each segment in seconds.
    overlap : fraction of a segment shared with its successor.

    Returns
    -------
    (segments, timestamps)
        Parallel lists of segments and their start times in seconds.
        Audio shorter than one segment yields two empty lists.
    """
    segment_samples = int(segment_duration * sr)
    # FIX: clamp the hop to at least one sample. Previously overlap >= ~1
    # produced hop_samples == 0 and range() raised
    # "ValueError: range() arg 3 must not be zero".
    hop_samples = max(1, int(segment_samples * (1 - overlap)))

    segments = []
    timestamps = []

    for start in range(0, len(audio) - segment_samples + 1, hop_samples):
        segments.append(audio[start:start + segment_samples])
        timestamps.append(start / sr)

    return segments, timestamps
113
+
114
def calculate_audio_features(segment, sr):
    """Return basic loudness/periodicity features for one audio segment.

    Expects ``segment`` as a 1-D numpy array. ``sr`` is accepted for interface
    symmetry with the other per-segment helpers but is not used by any of the
    current features.
    """
    squared = segment ** 2
    n = len(segment)
    return {
        'energy': np.sum(squared) / n,                               # mean power
        'rms': np.sqrt(np.mean(squared)),                            # RMS amplitude
        'zcr': np.sum(np.abs(np.diff(np.sign(segment)))) / (2 * n),  # zero-crossing rate
    }
125
+
126
def detect_snores_accurate(y_clean, sr):
    """
    Accurate snore detection using audio features + peak detection.

    Segments the cleaned audio, scores each segment by a weighted mix of
    energy, RMS and (inverted) zero-crossing rate, then finds peaks in the
    smoothed score and turns them into snore events.

    Returns a chronologically sorted list of event dicts with keys
    'start_time', 'end_time', 'duration', 'confidence', 'composite_score'.
    """
    # 1.5 s windows with 67% overlap -> ~0.5 s hop between feature frames.
    segments, timestamps = segment_audio(y_clean, sr, segment_duration=1.5, overlap=0.67)

    # Per-segment features, augmented with the CNN's snore probability.
    all_features = []

    for i, (segment, timestamp) in enumerate(zip(segments, timestamps)):
        features = calculate_audio_features(segment, sr)
        features['timestamp'] = timestamp

        # Model prediction used as a confidence feature, not as the detector.
        is_snore, conf = is_snoring_sound_pytorch(segment, sr)
        features['snore_prob'] = conf

        all_features.append(features)

    # Convert to arrays
    energies = np.array([f['energy'] for f in all_features])
    rms_values = np.array([f['rms'] for f in all_features])
    zcr_values = np.array([f['zcr'] for f in all_features])

    # Min-max normalise each feature. ZCR is inverted: snores are
    # low-frequency, so FEW zero crossings should score HIGH.
    energy_norm = (energies - energies.min()) / (energies.max() - energies.min() + 1e-8)
    rms_norm = (rms_values - rms_values.min()) / (rms_values.max() - rms_values.min() + 1e-8)
    zcr_norm = 1 - (zcr_values - zcr_values.min()) / (zcr_values.max() - zcr_values.min() + 1e-8)

    # Composite score: Energy (40%) + RMS (40%) + Low ZCR (20%)
    composite_score = energy_norm * 0.4 + rms_norm * 0.4 + zcr_norm * 0.2

    # Smooth so each snore shows up as a single broad peak.
    smoothed_score = gaussian_filter1d(composite_score, sigma=1.2)

    # Peak-picking parameters (units are feature frames, ~0.5 s apart).
    peak_height = np.percentile(smoothed_score, 50)  # only above-median frames
    peak_distance = int(0.8 / 0.5)  # min frames between peaks (~0.8 s at the ~0.5 s hop)
    peak_prominence = 0.04
    peak_width = (0.5, 8)

    peaks, properties = find_peaks(
        smoothed_score,
        height=peak_height,
        distance=peak_distance,
        prominence=peak_prominence,
        width=peak_width
    )

    # Create snore events from peaks
    snore_events = []

    for peak_idx in peaks:
        feature = all_features[peak_idx]

        # Expand up to 3 frames either side of the peak, stopping where the
        # score drops below half the peak value.
        start_idx = peak_idx
        end_idx = peak_idx

        threshold = smoothed_score[peak_idx] * 0.5

        # Find start (walking backwards from the peak)
        for i in range(peak_idx, max(0, peak_idx - 3), -1):
            if smoothed_score[i] < threshold:
                start_idx = i + 1
                break
            start_idx = i

        # Find end (walking forwards from the peak)
        for i in range(peak_idx, min(len(smoothed_score), peak_idx + 3)):
            if smoothed_score[i] < threshold:
                end_idx = i
                break
            end_idx = i

        # Frame indices -> seconds; +1.0 s approximates the event tail.
        start_time = all_features[start_idx]['timestamp']
        end_time = all_features[end_idx]['timestamp'] + 1.0

        # Merge into an existing event when they overlap by more than 0.3 s.
        should_add = True
        for existing in snore_events:
            if start_time < existing['end_time'] - 0.3:  # Overlaps by more than 0.3s
                # Update existing event instead of adding new one
                existing['end_time'] = max(existing['end_time'], end_time)
                existing['confidence'] = max(existing['confidence'], feature['snore_prob'])
                should_add = False
                break

        if should_add:
            duration = end_time - start_time
            if duration >= 0.5:  # Minimum duration
                snore_events.append({
                    'start_time': start_time,
                    'end_time': end_time,
                    'duration': duration,
                    'confidence': feature['snore_prob'],
                    'composite_score': smoothed_score[peak_idx]
                })

    # Chronological order for the API response.
    snore_events = sorted(snore_events, key=lambda x: x['start_time'])

    return snore_events
232
+
233
  # --- API ENDPOINTS ---
234
 
235
  @app.post("/analyze")
 
246
  return {"valid_recording": False, "reason": reason}
247
 
248
  y_clean = clean_audio_stream(y_orig, sr)
249
+
250
+ # --- NEW: Use accurate snore detection ---
251
+ snore_events = detect_snores_accurate(y_clean, sr)
252
+ snore_count = len(snore_events)
253
+
254
+ # --- ORIGINAL APNEA DETECTION (unchanged) ---
255
  intervals = librosa.effects.split(y_clean, top_db=25)
256
+
257
  annotations = []
258
  prev_end = 0
 
259
  apnea_count = 0
260
 
261
  for start, end in intervals:
262
+ # --- APNEA LOGIC (unchanged) ---
263
  gap_dur = (start - prev_end) / sr
264
  if 10.0 <= gap_dur <= 120.0:
265
  apnea_count += 1
266
+
267
+ # Risk level per event
268
+ if gap_dur < 15.0:
269
+ current_risk = "LOW"
270
+ elif gap_dur < 20.0:
271
+ current_risk = "MEDIUM"
272
+ else:
273
+ current_risk = "HIGH"
274
+
275
  annotations.append({
276
  "label": "APNEA",
277
  "start_sec": round(prev_end/sr, 2),
278
  "end_sec": round(start/sr, 2),
279
  "duration": round(gap_dur, 2),
280
+ "risk_level": current_risk
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  })
282
+
283
  prev_end = end
284
+
285
+ # --- Add detected snores to annotations ---
286
+ for snore in snore_events:
287
+ annotations.append({
288
+ "label": "SNORING",
289
+ "start_sec": round(snore['start_time'], 2),
290
+ "end_sec": round(snore['end_time'], 2),
291
+ "duration": round(snore['duration'], 2),
292
+ "confidence": round(snore['confidence'], 2)
293
+ })
294
 
295
+ # Calculate AHI Metrics (unchanged)
296
  duration_hours = (len(y_orig) / sr) / 3600
297
  ahi = apnea_count / duration_hours if duration_hours > 0 else 0
298
+
299
+ # --- Risk Level based on frequency (unchanged) ---
300
  overall_risk = ""
301
  if ahi >= 20: overall_risk = "HIGH"
302
  elif ahi >= 15: overall_risk = "MEDIUM"
303
  elif ahi >= 10: overall_risk = "LOW"
304
 
305
+ # --- FINAL RESPONSE ---
306
  return {
307
  "valid_recording": True,
308
  "snore_count": snore_count,