Solomon17705's picture
updated app.py
05e4a81
# app.py
import io
import json
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from tensorflow.keras.models import load_model
import numpy as np
import librosa
# ---------------- CONFIG ----------------
# Path handed to load_model() at startup.
MODEL_PATH = "resp_model.h5"

# Labels in model-output order: position i names the i-th probability
# returned by the model, so the order must not be changed.
CLASS_NAMES = [
    "Bronchiectasis", "Bronchiolitis", "COPD",
    "Healthy", "Pneumonia", "URTI",
]

SR = 22050            # sample rate every recording is loaded at
N_MFCC = 40           # MFCC coefficients computed per frame
MAX_PAD_LEN = 862     # MFCC frames kept after padding/truncation
CHUNK_DURATION = 4.0  # seconds of audio per inference chunk
MIN_CONFIDENCE = 0.5  # Ignore low-confidence chunks
MAX_FILE_SIZE = 10 * 1024 ** 2  # 10 MB
# -----------------------------------------
# Load model once at startup
# Loading happens at import time so every request reuses the same model
# object; if the file cannot be loaded the module refuses to start.
try:
    model = load_model(MODEL_PATH)
except Exception as e:
    # Chain the original exception so the real cause of the load failure
    # stays attached to the RuntimeError as its explicit __cause__.
    raise RuntimeError(f"Failed to load model: {e}") from e

app = FastAPI(title="Respiratory Disease Prediction API")

# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
def extract_mfcc(audio_bytes):
    """Decode an encoded audio payload and return a fixed-width MFCC matrix.

    The bytes are decoded at ``SR`` Hz in mono (at most the first 20 seconds),
    ``N_MFCC`` coefficients are computed, and the time axis is zero-padded or
    cut so that it is exactly ``MAX_PAD_LEN`` frames wide.

    Args:
        audio_bytes: Raw contents of an audio file.

    Returns:
        Array of shape ``(N_MFCC, MAX_PAD_LEN)``.

    Raises:
        ValueError: If decoding or feature extraction fails for any reason.
    """
    try:
        samples, _ = librosa.load(io.BytesIO(audio_bytes), sr=SR, mono=True, duration=20)
        features = librosa.feature.mfcc(y=samples, sr=SR, n_mfcc=N_MFCC)

        # Positive when the clip is shorter than the target width.
        missing_frames = MAX_PAD_LEN - features.shape[1]
        if missing_frames > 0:
            # Zero-pad on the right of the time axis only.
            return np.pad(features, ((0, 0), (0, missing_frames)), mode='constant')
        return features[:, :MAX_PAD_LEN]
    except Exception as e:
        raise ValueError(f"Audio processing failed: {e}")
def split_audio_chunks(audio, sr, chunk_duration=CHUNK_DURATION):
    """Cut a waveform into consecutive, non-overlapping windows.

    Args:
        audio: Sliceable sequence of samples.
        sr: Sample rate in samples per second.
        chunk_duration: Window length in seconds.

    Returns:
        List of slices of ``audio``, each ``int(chunk_duration * sr)`` samples
        long except possibly the last. Any slice shorter than one second
        (``sr`` samples) is dropped.
    """
    window = int(chunk_duration * sr)
    pieces = (
        audio[start:start + window]
        for start in range(0, len(audio), window)
    )
    # Only a trailing remainder can be short; keep it if it is at least 1 second.
    return [piece for piece in pieces if len(piece) >= sr]
def calculate_risk_assessment(disease, symptoms, duration, smoker, severity):
    """Simple rule-based risk assessment.

    Adds up fixed weights for the predicted disease, each distinct reported
    symptom, the symptom duration, smoking status and self-rated severity,
    then maps the total onto a three-level scale.

    Args:
        disease: Predicted class name; unknown names contribute 0.
        symptoms: Iterable of symptom names. Each distinct symptom is counted
            once; unknown names and non-string entries contribute 0. ``None``
            or a non-iterable is treated as "no symptoms" and a bare string
            as a single symptom.
        duration: "More than a week" adds 2, "3-7 days" adds 1, anything else 0.
        smoker: "Yes" adds 2; anything else adds 0.
        severity: Self-rated severity on a 1-10 scale. Values outside that
            range are clamped to it before half the value is added.

    Returns:
        dict with ``risk_level`` ("Mild" below 5, "Moderate" from 5,
        "Severe" from 10), ``risk_score`` rounded to two decimals, and a
        human-readable ``message``.
    """
    # Disease severity weights
    disease_weights = {
        "COPD": 3, "Pneumonia": 3,
        "Bronchiectasis": 2, "Bronchiolitis": 2,
        "URTI": 1, "Healthy": 0,
    }
    # Symptom weights
    symptom_weights = {
        "Shortness of Breath": 2,
        "Wheezing": 1.5,
        "Chest Pain": 2,
        "Fever": 1,
        "Fatigue": 0.5,
        "Sore Throat": 0.5,
        "Nasal Congestion": 0.5,
    }

    # The symptom list comes from client-supplied JSON, so normalise it:
    # collapse duplicates (a symptom repeated in the payload must not be
    # scored twice) and drop entries that cannot be dictionary keys.
    if isinstance(symptoms, str):
        # A bare string would otherwise be iterated character by character.
        distinct_symptoms = {symptoms}
    else:
        try:
            distinct_symptoms = {s for s in symptoms if isinstance(s, str)}
        except TypeError:  # None or another non-iterable value
            distinct_symptoms = set()

    # Keep severity on its documented 1-10 scale so an out-of-range value
    # cannot dominate the score or make it negative.
    severity = min(max(severity, 1), 10)

    risk_score = disease_weights.get(disease, 0)
    risk_score += sum(symptom_weights.get(s, 0) for s in distinct_symptoms)

    # Duration
    if duration == "More than a week":
        risk_score += 2
    elif duration == "3-7 days":
        risk_score += 1

    # Smoker
    if smoker == "Yes":
        risk_score += 2

    # Severity (1-10 scale) contributes half its value.
    risk_score += severity / 2

    # Determine risk level
    if risk_score >= 10:
        risk_level = "Severe"
    elif risk_score >= 5:
        risk_level = "Moderate"
    else:
        risk_level = "Mild"

    return {
        "risk_level": risk_level,
        "risk_score": round(risk_score, 2),
        "message": f"Based on your cough sound and symptoms, your condition is assessed as {risk_level}."
    }
def _parse_symptoms(raw):
    """Decode the ``symptoms`` form field, which must hold a JSON array.

    Raises:
        HTTPException: 400 if ``raw`` is not valid JSON or not a JSON array,
            so a malformed request is reported as a client error.
    """
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError) as exc:  # JSONDecodeError subclasses ValueError
        raise HTTPException(
            status_code=400, detail="Invalid symptoms: expected a JSON array of strings"
        ) from exc
    if not isinstance(parsed, list):
        raise HTTPException(
            status_code=400, detail="Invalid symptoms: expected a JSON array of strings"
        )
    return parsed


def _chunk_to_model_input(chunk):
    """Convert one mono audio chunk into a ``(1, N_MFCC, MAX_PAD_LEN, 1)`` batch.

    MFCCs are computed straight from the in-memory samples; the time axis is
    zero-padded or truncated to ``MAX_PAD_LEN`` frames, then batch and channel
    axes are added.
    """
    mfcc = librosa.feature.mfcc(y=chunk, sr=SR, n_mfcc=N_MFCC)
    n_frames = mfcc.shape[1]
    if n_frames < MAX_PAD_LEN:
        mfcc = np.pad(mfcc, ((0, 0), (0, MAX_PAD_LEN - n_frames)), mode='constant')
    else:
        mfcc = mfcc[:, :MAX_PAD_LEN]
    return mfcc[np.newaxis, ..., np.newaxis]


@app.post("/predict")
async def predict_respiratory_disease(
    file: UploadFile = File(...),
    symptoms: str = Form("[]"),  # JSON string of symptoms
    duration: str = Form("3-7 days"),  # Duration
    smoker: str = Form("No"),  # "Yes" or "No"
    severity: int = Form(5),  # 1-10 scale
):
    """Classify an uploaded recording and combine it with self-reported symptoms.

    The recording is decoded, split into ``CHUNK_DURATION``-second chunks and
    each chunk is classified. Chunk predictions whose top probability is at
    least ``MIN_CONFIDENCE`` are averaged to pick the final class, which is
    then fed into the rule-based risk assessment together with the form data.

    Returns:
        dict with ``disease``, ``confidence``, ``probabilities``,
        ``user_input`` and ``assessment``. The normal path also reports
        ``chunks_analyzed`` and ``usable_chunks``; the silent-audio and
        all-low-confidence paths carry a ``warning`` instead.

    Raises:
        HTTPException: 400 for an unsupported extension, an oversized file,
            malformed ``symptoms`` or audio shorter than one second; 500 for
            any other failure during decoding or inference.
    """
    # Validate file type (the client may omit the filename entirely)
    filename = (file.filename or "").lower()
    if not filename.endswith(('.wav', '.mp3', '.ogg', '.flac')):
        raise HTTPException(status_code=400, detail="Only audio files allowed (.wav, .mp3, .ogg, .flac)")

    # Validate file size. Reading at most one byte past the limit is enough
    # to detect an oversized upload without buffering all of it.
    audio_bytes = await file.read(MAX_FILE_SIZE + 1)
    if len(audio_bytes) > MAX_FILE_SIZE:
        raise HTTPException(status_code=400, detail="File too large (>10 MB)")

    # Parse user input up front so bad JSON is rejected before any audio work.
    symptoms_list = _parse_symptoms(symptoms)
    user_input = {
        "symptoms": symptoms_list,
        "duration": duration,
        "smoker": smoker,
        "severity": severity
    }

    try:
        # Load audio
        audio, _ = librosa.load(io.BytesIO(audio_bytes), sr=SR, mono=True)

        # np.max() raises on an empty array, so treat a zero-sample decode
        # as "too short" explicitly.
        if audio.size == 0:
            raise HTTPException(status_code=400, detail="Audio too short (<1 second)")

        # Check for silent audio
        if np.max(np.abs(audio)) < 0.01:
            return {
                "disease": "Healthy",
                # Fixed placeholder value; the model is not run on this path.
                "confidence": 0.99,
                "probabilities": {cls: 0.0 for cls in CLASS_NAMES},
                "user_input": user_input,
                "assessment": {
                    "risk_level": "Mild",
                    "risk_score": 0.0,
                    "message": "No significant sound detected. Likely healthy."
                },
                "warning": "Silent or very quiet audio detected"
            }

        # Split into chunks
        chunks = split_audio_chunks(audio, SR, CHUNK_DURATION)
        if not chunks:
            raise HTTPException(status_code=400, detail="Audio too short (<1 second)")

        # Run inference on each chunk
        predictions = []
        for chunk in chunks:
            pred = model.predict(_chunk_to_model_input(chunk), verbose=0)[0]
            # Only keep high-confidence predictions
            if np.max(pred) >= MIN_CONFIDENCE:
                predictions.append(pred)

        if not predictions:
            return {
                "disease": "Uncertain",
                "confidence": 0.0,
                "probabilities": {cls: 0.0 for cls in CLASS_NAMES},
                "user_input": user_input,
                "assessment": {
                    "risk_level": "Mild",
                    "risk_score": 0.0,
                    "message": "No clear respiratory pattern detected. Consider re-recording."
                },
                "warning": "All chunks had low confidence"
            }

        # Average high-confidence predictions
        avg_pred = np.mean(predictions, axis=0)
        predicted_class = CLASS_NAMES[int(np.argmax(avg_pred))]
        confidence = float(np.max(avg_pred))

        # Calculate risk assessment
        assessment = calculate_risk_assessment(predicted_class, symptoms_list, duration, smoker, severity)

        return {
            "disease": predicted_class,
            "confidence": round(confidence, 4),
            "probabilities": {
                cls: float(avg_pred[i]) for i, cls in enumerate(CLASS_NAMES)
            },
            "user_input": user_input,
            "assessment": assessment,
            "chunks_analyzed": len(chunks),
            "usable_chunks": len(predictions)
        }
    except HTTPException:
        # Deliberate client errors raised above must keep their status code
        # instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
# Required for Hugging Face Spaces
# Direct execution starts the ASGI server, listening on all interfaces.
if __name__ == "__main__":
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)