import os
import shutil
import gc

import numpy as np
import librosa
import noisereduce as nr
import tensorflow as tf
from tensorflow.keras import layers, models
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware

# --- OPTIMIZE FOR CPU ---
# Suppress TensorFlow INFO/WARNING logs; must be set before TF initializes.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

app = FastAPI(title="Real Snore & Apnea Detector")

# Enable CORS for frontend integration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- MODEL RECONSTRUCTION ---
def build_autoencoder_from_training_spec():
    """Reconstruct the convolutional autoencoder exactly as used in training.

    Returns:
        An uncompiled ``tf.keras`` Sequential model expecting input of shape
        (128, 626, 1) — a 128-mel log-spectrogram padded/cropped to 626 frames.
    """
    model = models.Sequential([
        layers.Input(shape=(128, 626, 1)),
        # Encoder: 128x626 -> 64x313 -> 32x157 (with 'same' padding)
        layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2), padding='same'),
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2), padding='same'),
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        # Decoder: transpose convs upsample back toward the input resolution.
        layers.Conv2DTranspose(64, (3, 3), strides=2, activation='relu', padding='same'),
        layers.Conv2DTranspose(32, (3, 3), strides=2, activation='relu', padding='same'),
        layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same'),
        # Upsampling of 157 frames yields 628; Resizing snaps output back
        # to the exact input shape so reconstruction MSE is well-defined.
        layers.Resizing(128, 626),
    ])
    return model


MODEL_PATH = "snore_detection_model.keras"
snore_model = None

# Load model weights on startup. A load failure is logged but does not crash
# the server; endpoints degrade gracefully (is_snore_segment returns False).
try:
    snore_model = build_autoencoder_from_training_spec()
    snore_model.load_weights(MODEL_PATH)
    print("✅ Real Snore Detector Loaded Successfully!")
except Exception as e:
    print(f"❌ Model Load Error: {e}")


# --- AI VALIDATION LOGIC ---
def is_snore_segment(y_seg, sr=16000):
    """Classify an audio segment as snore/not-snore via reconstruction error.

    Uses the autoencoder's MSE to separate real snores from noise: sounds
    matching the learned spectral signature of snoring reconstruct well
    (low MSE), while keyboard typing, walking, and talking reconstruct poorly.

    Args:
        y_seg: 1-D float waveform of the candidate segment.
        sr: Sample rate of ``y_seg`` (training assumed 16 kHz).

    Returns:
        Tuple ``(is_real, confidence)`` where ``is_real`` is True when the
        reconstruction MSE falls below the snore threshold, and ``confidence``
        is a heuristic score in [0, 1] derived from the MSE.
    """
    if snore_model is None:
        return False, 0.0

    # 1. Generate Mel spectrogram in dB (range roughly [-80, 0] after ref=max).
    S = librosa.feature.melspectrogram(y=y_seg, sr=sr, n_mels=128)
    log_S = librosa.power_to_db(S, ref=np.max)

    # 2. Pad (with silence floor -80 dB) or crop time axis to 626 frames
    #    to match the model's fixed input width.
    target_width = 626
    if log_S.shape[1] < target_width:
        log_S = np.pad(
            log_S,
            ((0, 0), (0, target_width - log_S.shape[1])),
            mode='constant',
            constant_values=-80,
        )
    else:
        log_S = log_S[:, :target_width]

    # 3. Normalize dB range [-80, 0] into [0, 1] and run inference.
    input_data = (log_S + 80.0) / 80.0
    input_data = input_data.reshape(1, 128, 626, 1)
    reconstructed = snore_model.predict(input_data, verbose=0)
    mse = float(np.mean(np.square(input_data - reconstructed)))

    # 4. Filter: real snores typically reconstruct with MSE < 0.0018;
    #    sudden noises like typing produce much higher reconstruction errors.
    is_real = mse < 0.0018
    # Heuristic confidence; clamped to float 0.0 so the JSON field is
    # consistently a float (the original max(0, ...) could yield int 0).
    confidence = max(0.0, 1.0 - (mse * 100.0))
    return is_real, confidence


# --- API ENDPOINTS ---
@app.get("/")
async def root():
    """Health check endpoint to prevent 404 errors."""
    return {"status": "online", "model_loaded": snore_model is not None}


@app.post("/analyze")
async def analyze_audio(file: UploadFile = File(...)):
    """Analyze an uploaded sleep recording for snoring and apnea events.

    Pipeline: load at 16 kHz -> stationary noise reduction -> split into
    sound bursts -> gap analysis (apnea) + autoencoder validation (snoring)
    -> AHI summary statistics.

    Returns:
        JSON with snore/apnea counts, AHI score, overall risk, and a list
        of timestamped events.

    Raises:
        HTTPException(500) on any processing failure.
    """
    # Sanitize the client-supplied filename: strip any directory components
    # (including Windows-style separators) to prevent path traversal.
    safe_name = os.path.basename((file.filename or "upload").replace("\\", "/"))
    temp_path = f"temp_{os.getpid()}_{safe_name}"

    # Save the uploaded file temporarily
    with open(temp_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    try:
        # 1. LOAD: 16k SR saves significant RAM
        y, sr = librosa.load(temp_path, sr=16000, dtype=np.float32)

        # 2. CLEAN: Use stationary noise reduction for speed
        y_clean = nr.reduce_noise(y=y, sr=sr, stationary=True)

        # 3. SEGMENT: Identify sound bursts (non-silent intervals)
        intervals = librosa.effects.split(y_clean, top_db=25)

        annotations = []
        prev_end = 0
        snore_count = 0
        apnea_count = 0

        # 3-SECOND COOLDOWN: Prevents multiple counts for one breath.
        # Initialized to -3.0 so the very first burst is never locked out.
        last_snore_time = -3.0

        for start, end in intervals:
            current_time = start / sr

            # --- APNEA LOGIC (Gap Detection) ---
            # A silent gap of 10-120 s between bursts is flagged as apnea.
            gap_dur = (start - prev_end) / sr
            if 10.0 <= gap_dur <= 120.0:
                apnea_count += 1
                risk = "LOW" if gap_dur < 15.0 else ("MEDIUM" if gap_dur < 20.0 else "HIGH")
                annotations.append({
                    "label": "APNEA",
                    "start_sec": round(prev_end / sr, 2),
                    "end_sec": round(start / sr, 2),
                    "duration": round(gap_dur, 2),
                    "risk_level": risk,
                })

            # --- REAL SNORING LOGIC (AI Validation) ---
            # Check the 3-second lockout window first
            if (current_time - last_snore_time) >= 3.0:
                # NOTE(review): segments are sliced from the RAW waveform `y`
                # although interval boundaries came from `y_clean` — presumably
                # intentional (model trained on un-denoised audio); confirm.
                seg = y[start: min(start + 16000, end)]
                if len(seg) > 1600:  # Ignore clicks shorter than 0.1s
                    # Use the autoencoder to verify it's a real snore
                    is_snore, conf = is_snore_segment(seg, sr)
                    if is_snore:
                        snore_count += 1
                        last_snore_time = current_time  # Start the 3-second lockout
                        annotations.append({
                            "label": "SNORING",
                            "start_sec": round(start / sr, 2),
                            "end_sec": round(end / sr, 2),
                            "duration": round((end - start) / sr, 2),
                            "confidence": round(conf, 4),
                        })

            prev_end = end

        # 4. STATS: Calculate Apnea-Hypopnea Index (AHI = events per hour)
        duration_hours = (len(y) / sr) / 3600
        ahi = apnea_count / duration_hours if duration_hours > 0 else 0

        risk_summary = "NORMAL"
        if ahi >= 20:
            risk_summary = "HIGH"
        elif ahi >= 15:
            risk_summary = "MEDIUM"
        elif ahi >= 10:
            risk_summary = "LOW"

        # Explicit RAM cleanup before returning result
        del y, y_clean
        gc.collect()

        return {
            "valid_recording": True,
            "snore_count": snore_count,
            "apnea_count": apnea_count,
            "overall_risk": risk_summary,
            "ahi_score": round(ahi, 1),
            "events": annotations,
        }

    except Exception as e:
        print(f"🔥 ERROR: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal Processing Error")
    finally:
        # Final cleanup: Remove temp file and free memory
        if os.path.exists(temp_path):
            os.remove(temp_path)
        gc.collect()


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)