# KasaHealth / app/main.py
# Author: 78anand
# Commit ef12225 (verified): "Upload folder using huggingface_hub"
import os
import sys
# --- Force Writable Paths for Hugging Face ---
# Redirect HOME and every library cache (Hugging Face, XDG, Matplotlib, Numba)
# into /tmp. Presumably the rest of the container filesystem is not writable —
# TODO confirm. This must run before numpy/librosa/matplotlib are imported.
os.environ['HOME'] = '/tmp'
_WRITABLE_CACHE_DIRS = {
    'HF_HOME': '/tmp/huggingface',
    'XDG_CACHE_HOME': '/tmp/cache',
    'MPLCONFIGDIR': '/tmp/matplotlib',
    'NUMBA_CACHE_DIR': '/tmp/numba',
}
os.environ.update(_WRITABLE_CACHE_DIRS)
# Ensure directories exist
for d in _WRITABLE_CACHE_DIRS.values():
    os.makedirs(d, exist_ok=True)
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import io
import base64
import tensorflow as tf
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from tensorflow.keras.models import load_model
from werkzeug.utils import secure_filename
from scipy.signal import butter, lfilter
# --- Absolute Path Resolution ---
# This module sits one directory below the project root; put that root on
# sys.path so the sibling `utils` package can be imported.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# The HeAR extractor is mandatory: abort start-up if it cannot be imported.
try:
    from utils.hear_extractor import HeARExtractor
except ImportError as exc:
    print(f"❌ Critical Import Error: {exc}")
    sys.exit(1)
else:
    print("✅ Successfully imported utils package.")
app = Flask(__name__)
CORS(app)  # enable CORS on every route (flask_cors defaults)
app.config.update(
    # Uploaded recordings are staged here and deleted after each request.
    UPLOAD_FOLDER=os.path.join('/tmp', 'uploads'),
    # Cap request bodies at 16 MiB.
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,
)
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
# --- Direct CSV Download Endpoint ---
@app.route('/download_data', methods=['GET'])
def download_data():
    """Serve ``pilot_data.csv`` (stored next to this module) as a download.

    Returns a plain-text 404 response when the file does not exist.
    """
    # NOTE(review): there is no authentication on this route, so anyone who can
    # reach the service can download the pilot data. Also, nothing visible in
    # this module writes pilot_data.csv (results are logged to Supabase) —
    # confirm the file is still produced somewhere.
    csv_path = os.path.join(os.path.dirname(__file__), 'pilot_data.csv')
    if not os.path.exists(csv_path):
        return "No pilot data has been gathered yet! Run a test first.", 404
    return send_file(csv_path, as_attachment=True)
# Configuration: DUAL-BRAIN MODELS
#   Brain 1: The Shield (V9)  - protects healthy users (strict threshold)
#   Brain 2: The Sentry (V10) - high sensitivity for sick patients
MODEL_V9_PATH = os.path.join(project_root, "models", "hear_classifier_v9_ultimate.h5")
MODEL_V10_PATH = os.path.join(project_root, "models", "hear_classifier_v10_sentry.h5")

# Populated on first use by load_resources(); None means "not loaded yet".
extractor = shield_model = sentry_model = None
def load_resources():
    """Lazily build the HeAR extractor and both classifier models.

    Each module-level global is only constructed while it is still ``None``,
    so calling this on every request is cheap after the first call.
    """
    # NOTE(review): no lock guards this; two concurrent first requests may both
    # load the models.
    global extractor, shield_model, sentry_model

    if extractor is None:
        extractor = HeARExtractor(token=os.environ.get('HF_TOKEN'))

    if shield_model is None:
        print("Loading Shield Model (V9)...")
        shield_model = load_model(MODEL_V9_PATH, compile=False)

    if sentry_model is None:
        print("Loading Sentry Model (V10)...")
        sentry_model = load_model(MODEL_V10_PATH, compile=False)
def highpass_filter(data, cutoff, fs, order=5):
    """Apply a causal Butterworth high-pass filter to ``data``.

    Args:
        data: 1-D array of samples.
        cutoff: -3 dB corner frequency in Hz; must be below ``fs / 2``.
        fs: Sample rate of ``data`` in Hz.
        order: Filter order (default 5).

    Returns:
        The filtered signal, same length as ``data``.
    """
    from scipy.signal import sosfilt

    nyquist = 0.5 * fs
    # Use second-order sections rather than a (b, a) transfer function. SciPy
    # documents that the 'ba' form is numerically fragile from order ~4 up,
    # worst for corners near DC (predict() uses 100 Hz at 16 kHz).
    sos = butter(order, cutoff / nyquist, btype='high', analog=False, output='sos')
    return sosfilt(sos, data)
def is_cough_present(y, sr):
    """Screen a recording for an impulsive, cough-like sound before classification.

    The signal is reduced to RMS energy per frame (100 ms hop, 200 ms frame).
    Three lockouts are then applied in order: near-silence, low SNR, and a
    low peak-to-average ratio.

    Args:
        y: Mono audio samples (predict() loads them at 16 kHz).
        sr: Sample rate of ``y`` in Hz.

    Returns:
        Tuple ``(is_valid, message, is_low_quality)``. ``message`` is shown to
        the user. ``is_low_quality`` is True when SNR < 12 dB: it is always True
        for the noise lockout and always False for empty or silent input.
    """
    if len(y) == 0:
        return False, "No audio signal.", False

    hop = int(0.1 * sr)  # 100 ms
    frame_rms = librosa.feature.rms(y=y, frame_length=2 * hop, hop_length=hop)[0]

    # 1. Absolute peak check.
    peak = np.max(frame_rms)
    if peak < 0.01:
        return False, "Silence detected. Please cough closer to the microphone.", False

    # 2. SNR: loudest frame vs. the 10th-percentile frame as the noise floor.
    floor = np.percentile(frame_rms, 10) + 1e-8
    snr_db = 20 * np.log10(peak / floor)
    noisy = bool(snr_db < 12)  # below 12 dB the result is flagged as low quality

    # 3. Strict background-noise lockout.
    if snr_db < 8.0:
        return False, "Background noise is too loud. Please move to a quieter area and try again.", True

    # 4. Impulsiveness lockout (rejects steady noise such as traffic or wind).
    # The threshold is deliberately low (1.5): children and continuous coughing
    # fits raise mean energy, which naturally lowers the ratio.
    peak_to_avg = peak / (np.mean(frame_rms) + 1e-8)
    if peak_to_avg < 1.5:
        return False, "No clear cough sound detected. Please cough forcefully.", noisy

    return True, "Cough presence detected.", noisy
def generate_spectrogram_b64(y, sr):
    """Render a Mel-spectrogram of ``y`` as a transparent PNG, Base64-encoded.

    Args:
        y: Audio samples.
        sr: Sample rate of ``y`` in Hz.

    Returns:
        The PNG bytes as an ASCII Base64 string, or None if rendering fails.
        On failure the error is printed.
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(6, 2.5), dpi=100)
        fig.patch.set_alpha(0.0)  # transparent background
        ax.set_axis_off()

        S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
        S_dB = librosa.power_to_db(S, ref=np.max)
        librosa.display.specshow(S_dB, sr=sr, fmax=8000, ax=ax, cmap='magma')

        # Act on this figure explicitly instead of pyplot's "current figure".
        # That is process-global state, which another request's thread could
        # change between calls.
        fig.tight_layout(pad=0)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, transparent=True)
        return base64.b64encode(buf.getvalue()).decode('utf-8')
    except Exception as e:
        print(f"Spectrogram Error: {e}")
        return None
    finally:
        # pyplot keeps every figure alive until it is closed. Close on every
        # path, including errors, so a long-running server does not leak.
        if fig is not None:
            plt.close(fig)
@app.route('/')
def index():
    """Liveness endpoint: report service identity and version as JSON."""
    service_info = {
        "status": "online",
        "service": "KasaHealth Dual-Brain Engine",
        "version": "2.1.0 (Smart Sensing)",
        "message": "Dual-Brain analysis ready with Cough-Detector.",
    }
    return jsonify(service_info)
def _parse_symptoms(raw):
    """Decode the JSON ``symptoms`` form field into a list of symptom tags.

    Invalid JSON, or JSON that is not a list, means "no symptoms reported" and
    returns ``[]``. Without the list check, a bare JSON string would be iterated
    character by character, and a JSON number would raise ``TypeError`` in the
    membership tests.
    """
    import json

    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
        return []
    return parsed if isinstance(parsed, list) else []


def _assess_risk(p9, p10, user_symptoms, is_low_quality):
    """Combine both model scores with self-reported symptoms into a triage verdict.

    Args:
        p9: Shield (V9) model score.
        p10: Sentry (V10) model score.
        user_symptoms: List of symptom tags. 'none' means the user confirmed
            they have no symptoms.
        is_low_quality: True for noisy recordings; raises both thresholds.

    Returns:
        Tuple ``(final_label, confidence, is_inconclusive)``, where
        ``final_label`` is "sick" or "healthy".
    """
    has_critical_symptoms = any(s in user_symptoms for s in ('breathless', 'chest_pain', 'chest_sound'))
    has_any_symptoms = any(s != 'none' for s in user_symptoms)

    # CLINICAL THRESHOLDS. The Sentry threshold is lower (sensitive); the Shield
    # threshold is higher (must be very sure). Noisy recordings (e.g. taken in
    # a shop) are prone to false positives, so both are raised.
    if is_low_quality:
        sentry_threshold, shield_threshold = 0.75, 0.85
    else:
        sentry_threshold, shield_threshold = 0.55, 0.70

    # High risk requires BOTH models above the stricter Shield threshold.
    is_high_risk = bool(p10 > shield_threshold and p9 > shield_threshold)
    is_med_risk = bool(p10 > sentry_threshold and not is_high_risk)

    # Clinical symptom escalation: a critical symptom raises the verdict one level.
    if is_med_risk and has_critical_symptoms:
        is_high_risk, is_med_risk = True, False
    elif not is_high_risk and not is_med_risk and has_critical_symptoms:
        is_med_risk = True

    if is_high_risk:
        return "sick", (0.95 if has_critical_symptoms else 0.90), False
    if is_med_risk:
        # Sentry hears something but Shield is skeptical.
        return "sick", (0.70 if has_any_symptoms else 0.60), True

    confidence = float(1.0 - p10)
    if 'none' in user_symptoms:
        confidence = max(0.90, confidence)  # user confirmed no symptoms
    elif has_any_symptoms:
        confidence = min(0.70, confidence)  # mild symptoms lower confidence
    # NOTE(review): this 0.75 floor runs after the 0.70 cap above. A user with
    # mild symptoms therefore gets 0.70 when the raw score is high but 0.75 when
    # it is low (non-monotonic). Confirm this is intended.
    if confidence < 0.70:
        confidence = 0.75  # minimum UI baseline
    return "healthy", confidence, False


def _log_to_supabase(final_label, p9, p10, confidence):
    """Best-effort: record one prediction row in the Supabase ``pilot_logs`` table.

    The API key is read from the ``SUPABASE_KEY`` environment variable (the
    same pattern load_resources uses for ``HF_TOKEN``); if it is unset, logging
    is skipped. Errors are printed and never propagate to the caller.

    NOTE(review): a service_role key was committed in earlier revisions of this
    file and must be rotated.
    """
    api_key = os.environ.get('SUPABASE_KEY')
    if not api_key:
        print('Supabase Log Skipped: SUPABASE_KEY is not set.')
        return
    base_url = os.environ.get('SUPABASE_URL', 'https://ysaoyyhzdtaisjadscsm.supabase.co')
    try:
        import requests

        response = requests.post(
            base_url.rstrip('/') + '/rest/v1/pilot_logs',
            json={
                'result': final_label,
                'shield_score': round(float(p9), 4),
                'sentry_score': round(float(p10), 4),
                'confidence': round(confidence, 4),
            },
            headers={
                'apikey': api_key,
                'Authorization': f'Bearer {api_key}',
                'Content-Type': 'application/json',
                'Prefer': 'return=minimal',
            },
            timeout=3,
        )
        # requests does not raise on HTTP error statuses by itself; surface them.
        response.raise_for_status()
    except Exception as e:
        print(f'Supabase Log Error: {e}')


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded cough recording with the dual-model pipeline.

    Expects multipart form data with two fields:
        audio: The recording file (required).
        symptoms: Optional JSON list of symptom tags (defaults to '[]').

    Returns a JSON body with the label, confidence, inconclusive and
    low-quality flags, a Base64 spectrogram, both raw model scores, and a
    recommendation. Error status codes:
        400: No file was sent, or no cough was detected.
        500: Feature extraction failed, or any other exception was raised.
    """
    import uuid

    if 'audio' not in request.files:
        return jsonify({"error": "No audio file"}), 400

    file = request.files['audio']
    filename = secure_filename(file.filename or '')
    # Prefix a random token so concurrent uploads with the same client name
    # (browsers often send "blob") do not overwrite or delete each other's
    # audio. The token also keeps the path from being the bare upload
    # directory when the sanitized name is empty. The original name stays as
    # a suffix so its extension is preserved for the decoder.
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}_{filename}")

    try:
        file.save(filepath)
        load_resources()
        y, sr = librosa.load(filepath, sr=16000)

        # --- Cough Presence & Quality Detection ---
        valid, msg, is_low_quality = is_cough_present(y, sr)
        if not valid:
            return jsonify({"error": msg, "type": "NO_COUGH_DETECTED"}), 400

        # Pre-process: 100 Hz high-pass, then peak-normalise to [-1, 1].
        y_clean = highpass_filter(y, 100, sr)
        y_clean = y_clean / (np.max(np.abs(y_clean)) + 1e-8)

        # Visual signature (None if rendering failed).
        spectrogram_b64 = generate_spectrogram_b64(y_clean, sr)

        emb = extractor.extract(y_clean)
        if emb is None:
            return jsonify({"error": "Failed to extract acoustic features."}), 500
        X = emb[np.newaxis, ...]  # add a leading batch axis

        # --- Dual-Brain Analysis ---
        # training=False disables any GaussianNoise jitter during inference.
        p9 = shield_model(X, training=False).numpy()[0][0]
        p10 = sentry_model(X, training=False).numpy()[0][0]
        # Logged for debugging inconsistency / false positives.
        print(f"DEBUG | Sample: {filename} | Shield: {p9:.4f} | Sentry: {p10:.4f}")

        # --- Symptom Integration & Triage ---
        user_symptoms = _parse_symptoms(request.form.get('symptoms', '[]'))
        final_label, confidence, is_inconclusive = _assess_risk(p9, p10, user_symptoms, is_low_quality)

        _log_to_supabase(final_label, p9, p10, confidence)

        return jsonify({
            "status": "success",
            "result": final_label,
            "confidence": confidence,
            "is_inconclusive": is_inconclusive,
            "low_quality": is_low_quality,
            "spectrogram_b64": spectrogram_b64,
            "scores": {"shield": float(p9), "sentry": float(p10)},
            "recommendation": get_recommendation(final_label, is_inconclusive),
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        # Single cleanup point for every exit path above.
        if os.path.exists(filepath):
            os.remove(filepath)
def get_recommendation(label, is_inconclusive):
    """Map a verdict to a short English recommendation string.

    Largely vestigial: the user-facing wording is produced by the front end's
    translations (app.js).
    """
    if label != "sick":
        return "Low match. No significant markers."
    return (
        "Moderate match. Consult a doctor."
        if is_inconclusive
        else "High match. Urgent: Please visit nearest PHC."
    )
if __name__ == '__main__':
    # Listen on all interfaces on port 7860 (presumably the port the Hugging
    # Face Space routes to — TODO confirm).
    app.run(port=7860, host='0.0.0.0')