# app.py — Real Snore & Apnea Detector API
# (provenance: Hugging Face Space "App-api", file app.py,
#  commit 817d977 "Update app.py" by trixy194t)
import os
import shutil
import numpy as np
import librosa
import noisereduce as nr
import gc
import tensorflow as tf
from tensorflow.keras import layers, models
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
# --- OPTIMIZE FOR CPU ---
# NOTE(review): TF_CPP_MIN_LOG_LEVEL suppresses TensorFlow's C++ log spam only
# when it is set *before* `import tensorflow` executes; TF was already imported
# above, so this assignment likely has no effect — consider moving it above the
# tensorflow import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# FastAPI application instance; endpoints are registered on it below.
app = FastAPI(title="Real Snore & Apnea Detector")

# Enable CORS for frontend integration
# Wildcard origins/methods/headers: any web page may call this API directly.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# --- MODEL RECONSTRUCTION ---
def build_autoencoder_from_training_spec():
    """Reconstructs the model architecture exactly as used in training.

    A convolutional autoencoder over (128, 626, 1) log-mel spectrograms:
    two Conv2D + MaxPooling encoder stages, a Conv2D(128) bottleneck, two
    Conv2DTranspose upsampling stages, a sigmoid single-channel output, and
    a Resizing layer that restores the exact (128, 626) input grid.
    """
    autoencoder = models.Sequential()
    autoencoder.add(layers.Input(shape=(128, 626, 1)))
    # Encoder: downsample 2x twice while widening the channel dimension.
    autoencoder.add(layers.Conv2D(32, (3, 3), activation='relu', padding='same'))
    autoencoder.add(layers.MaxPooling2D((2, 2), padding='same'))
    autoencoder.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
    autoencoder.add(layers.MaxPooling2D((2, 2), padding='same'))
    # Bottleneck representation.
    autoencoder.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
    # Decoder: mirror the encoder with strided transposed convolutions.
    autoencoder.add(layers.Conv2DTranspose(64, (3, 3), strides=2, activation='relu', padding='same'))
    autoencoder.add(layers.Conv2DTranspose(32, (3, 3), strides=2, activation='relu', padding='same'))
    autoencoder.add(layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same'))
    # Upsampling can overshoot 626 frames; resize back to the input shape.
    autoencoder.add(layers.Resizing(128, 626))
    return autoencoder
# Path to the pre-trained autoencoder weights (expected next to this file).
MODEL_PATH = "snore_detection_model.keras"

# Module-level model handle; stays None if loading fails so request handlers
# can degrade gracefully instead of crashing at import time.
snore_model = None

# Load model on startup
try:
    snore_model = build_autoencoder_from_training_spec()
    snore_model.load_weights(MODEL_PATH)
    print("✅ Real Snore Detector Loaded Successfully!")
except Exception as e:
    # Broad catch is deliberate: the API should still come up (the health
    # check reports model_loaded=False) even if weights are missing/corrupt.
    print(f"❌ Model Load Error: {e}")
# --- AI VALIDATION LOGIC ---
def is_snore_segment(y_seg, sr=16000):
    """
    Uses Reconstruction Error (MSE) to separate real snores from noise.
    Rejects keyboard typing, walking, and talking by comparing them to the
    learned spectral signature of snoring.
    """
    # Guard: without a loaded model nothing can be validated.
    if snore_model is None:
        return False, 0.0

    # 1. Generate a dB-scaled mel spectrogram (128 mel bands).
    mel_power = librosa.feature.melspectrogram(y=y_seg, sr=sr, n_mels=128)
    mel_db = librosa.power_to_db(mel_power, ref=np.max)

    # 2. Pad (with -80 dB silence) or crop the time axis to exactly the
    #    626 frames the model was trained on.
    target_width = 626
    frames = mel_db.shape[1]
    if frames < target_width:
        mel_db = np.pad(mel_db, ((0, 0), (0, target_width - frames)),
                        mode='constant', constant_values=-80)
    else:
        mel_db = mel_db[:, :target_width]

    # 3. Map [-80, 0] dB onto [0, 1], add batch/channel axes, reconstruct,
    #    and measure the reconstruction error.
    batch = ((mel_db + 80.0) / 80.0).reshape(1, 128, 626, 1)
    reconstructed = snore_model.predict(batch, verbose=0)
    mse = float(np.mean(np.square(batch - reconstructed)))

    # 4. Filter: Real snores typically have MSE < 0.0018.
    # Sudden noises like typing result in much higher reconstruction errors.
    is_real = mse < 0.0018
    confidence = max(0, 1 - (mse * 100))
    return is_real, confidence
# --- API ENDPOINTS ---
@app.get("/")
async def root():
    """Health check endpoint to prevent 404 errors."""
    payload = {
        "status": "online",
        "model_loaded": snore_model is not None,
    }
    return payload
@app.post("/analyze")
async def analyze_audio(file: UploadFile = File(...)):
    """
    Analyze an uploaded sleep recording for snoring and apnea events.

    Pipeline: save upload -> load at 16 kHz -> denoise -> split into sound
    bursts -> flag long silent gaps as apnea and autoencoder-validated
    bursts as snoring -> summarise with an AHI-style events-per-hour score.

    Returns a JSON dict with counts, AHI score, overall risk, and a list of
    timestamped events. Raises HTTP 500 on any processing failure.
    """
    # SECURITY FIX: the client-supplied filename may contain path separators
    # (e.g. "../../etc/cron.d/x"); keep only the base name so the temp file
    # always lands in the working directory.
    safe_name = os.path.basename(file.filename or "upload.wav")
    temp_path = f"temp_{os.getpid()}_{safe_name}"
    try:
        # Save the uploaded file temporarily. Done inside the try (the
        # original wrote it before) so `finally` removes it even when the
        # copy itself fails midway.
        with open(temp_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # 1. LOAD: 16k SR saves significant RAM
        y, sr = librosa.load(temp_path, sr=16000, dtype=np.float32)

        # 2. CLEAN: Use stationary noise reduction for speed
        y_clean = nr.reduce_noise(y=y, sr=sr, stationary=True)

        # 3. SEGMENT: Identify sound bursts (regions within 25 dB of peak)
        intervals = librosa.effects.split(y_clean, top_db=25)

        annotations = []
        prev_end = 0
        snore_count = 0
        apnea_count = 0
        # 3-SECOND COOLDOWN: Prevents multiple counts for one breath
        last_snore_time = -3.0

        for start, end in intervals:
            current_time = start / sr

            # --- APNEA LOGIC (Gap Detection) ---
            # A 10-120 s silence between bursts counts as one apnea event.
            gap_dur = (start - prev_end) / sr
            if 10.0 <= gap_dur <= 120.0:
                apnea_count += 1
                risk = "LOW" if gap_dur < 15.0 else ("MEDIUM" if gap_dur < 20.0 else "HIGH")
                annotations.append({
                    "label": "APNEA",
                    "start_sec": round(prev_end / sr, 2),
                    "end_sec": round(start / sr, 2),
                    "duration": round(gap_dur, 2),
                    "risk_level": risk
                })

            # --- REAL SNORING LOGIC (AI Validation) ---
            # Check the 3-second lockout window first
            if (current_time - last_snore_time) >= 3.0:
                # Up to 1 s of the *raw* (pre-denoise) signal feeds the model
                # — presumably it was trained on raw audio; verify.
                seg = y[start: min(start + 16000, end)]
                if len(seg) > 1600:  # Ignore clicks shorter than 0.1s
                    # Use the Autoencoder to verify it's a real snore
                    is_snore, conf = is_snore_segment(seg, sr)
                    if is_snore:
                        snore_count += 1
                        last_snore_time = current_time  # Start the 3-second lockout
                        annotations.append({
                            "label": "SNORING",
                            "start_sec": round(start / sr, 2),
                            "end_sec": round(end / sr, 2),
                            "duration": round((end - start) / sr, 2),
                            "confidence": round(conf, 4)
                        })

            prev_end = end

        # 4. STATS: Calculate Apnea-Hypopnea Index (AHI) = events per hour
        duration_hours = (len(y) / sr) / 3600
        ahi = apnea_count / duration_hours if duration_hours > 0 else 0
        risk_summary = "NORMAL"
        if ahi >= 20:
            risk_summary = "HIGH"
        elif ahi >= 15:
            risk_summary = "MEDIUM"
        elif ahi >= 10:
            risk_summary = "LOW"

        # Explicit RAM cleanup before returning result
        del y, y_clean
        gc.collect()

        return {
            "valid_recording": True,
            "snore_count": snore_count,
            "apnea_count": apnea_count,
            "overall_risk": risk_summary,
            "ahi_score": round(ahi, 1),
            "events": annotations
        }
    except Exception as e:
        # Boundary handler: log the real cause, return a generic 500 to the
        # client so internals are not leaked.
        print(f"🔥 ERROR: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal Processing Error")
    finally:
        # Final cleanup: Remove temp file and free memory
        if os.path.exists(temp_path):
            os.remove(temp_path)
        gc.collect()
if __name__ == "__main__":
    # Run the app directly with uvicorn; port 7860 appears chosen to match
    # the Hugging Face Spaces convention — confirm for other deployments.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)