# KasaHealth / models /ensemble_predict.py
# 78anand's picture
# Upload folder using huggingface_hub
# f317798 verified
import os
import sys
import numpy as np
import librosa
import tensorflow as tf
from tensorflow.keras.models import load_model
# Paths
# Directory holding the trained models and label arrays. Defaults to the
# original development location; set the LUNG_AI_MODELS_DIR environment
# variable to run on another machine without editing this file.
MODELS_DIR = (
    os.environ.get("LUNG_AI_MODELS_DIR")
    or r"c:\Users\ASUS\lung_ai_project\models"
)
HEAR_MODEL_PATH = os.path.join(MODELS_DIR, "hear_classifier_augmented.h5")
HEAR_CLASSES_PATH = os.path.join(MODELS_DIR, "hear_aug_classes.npy")
CNN_MODEL_PATH = os.path.join(MODELS_DIR, "cough_model.h5")
CNN_CLASSES_PATH = os.path.join(MODELS_DIR, "classes.npy")

# Configuration for CNN
CNN_SR = 22050        # sample rate (Hz) the audio is resampled to on load
CNN_DURATION = 5      # seconds of audio read from each file
CNN_MFCC = 13         # number of MFCC coefficients extracted
CNN_MAX_LEN = int(CNN_SR * CNN_DURATION)  # fixed clip length in samples

# Configuration for HeAR
# NOTE(review): not referenced anywhere in the visible code; presumably the
# sample rate HeARExtractor expects -- confirm against utils/hear_extractor.
HEAR_SR = 16000
class EnsemblePredictor:
    """Blend a HeAR-embedding classifier and an MFCC CNN into one verdict.

    Both models score the same audio file. Their "sick" probabilities are
    combined with fixed weights (HeAR 0.7, CNN 0.3) and the blended value is
    thresholded at 0.5 to produce a final "sick" / "healthy" label.
    """

    # Weights applied to each model's "sick" probability; they sum to 1 so
    # the blended value can be read as a probability.
    _HEAR_WEIGHT = 0.7
    _CNN_WEIGHT = 0.3

    def __init__(self):
        """Load the HeAR extractor, the HeAR classifier and the CNN.

        Raises:
            FileNotFoundError: if neither the augmented nor the baseline
                HeAR classifier exists on disk.
        """
        print("Initializing Ensemble Model...")

        # 1. Load HeAR components
        # HeARExtractor lives in the sibling "utils" directory; register it
        # once so repeated instantiation does not keep growing sys.path.
        utils_dir = os.path.join(os.path.dirname(__file__), "..", "utils")
        if utils_dir not in sys.path:
            sys.path.append(utils_dir)
        from hear_extractor import HeARExtractor

        self.hear_extractor = HeARExtractor()

        if os.path.exists(HEAR_MODEL_PATH):
            self.hear_model = load_model(HEAR_MODEL_PATH)
            self.hear_classes = np.load(HEAR_CLASSES_PATH)
        else:
            print("Warning: Augmented HeAR model not found. Using baseline if available.")
            # Fallback to non-augmented
            alt_path = HEAR_MODEL_PATH.replace("_augmented", "")
            if not os.path.exists(alt_path):
                # Fail here rather than leaving hear_model unset, which would
                # only surface later as an AttributeError inside predict().
                raise FileNotFoundError(
                    "No HeAR classifier found; looked for "
                    f"{HEAR_MODEL_PATH!r} and {alt_path!r}"
                )
            self.hear_model = load_model(alt_path)
            # Baseline label file sits next to the augmented one.
            self.hear_classes = np.load(
                HEAR_CLASSES_PATH.replace("hear_aug_classes.npy", "hear_classes.npy")
            )

        # 2. Load CNN components
        self.cnn_model = load_model(CNN_MODEL_PATH)
        self.cnn_classes = np.load(CNN_CLASSES_PATH)

    @staticmethod
    def _class_index(classes, label):
        """Return the position of the first entry of *classes* equal to *label*.

        Raises:
            IndexError: if *label* is absent. The type matches what the bare
                ``np.where(...)[0][0]`` lookup raised, but with a usable message.
        """
        matches = np.where(classes == label)[0]
        if len(matches) == 0:
            raise IndexError(
                f"class label {label!r} not found in {list(np.ravel(classes))}"
            )
        return matches[0]

    def _extract_cnn_features(self, file_path):
        """Load *file_path* and return its MFCCs with a trailing channel axis.

        The audio is resampled to CNN_SR, limited to CNN_DURATION seconds and
        zero-padded at the end to exactly CNN_MAX_LEN samples.
        """
        audio, sr = librosa.load(file_path, sr=CNN_SR, duration=CNN_DURATION)
        if len(audio) < CNN_MAX_LEN:
            padding = CNN_MAX_LEN - len(audio)
            audio = np.pad(audio, (0, padding), 'constant')
        else:
            # Guard against a clip that is still longer than expected.
            audio = audio[:CNN_MAX_LEN]
        mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=CNN_MFCC)
        return mfccs[..., np.newaxis]

    def predict(self, file_path):
        """Classify one audio file with both models and blend the results.

        Args:
            file_path: path to the audio file to classify.

        Returns:
            dict with keys ``final_result`` ("sick" or "healthy"),
            ``final_confidence``, ``hear_result``, ``hear_conf``,
            ``cnn_result`` and ``cnn_conf``.

        Raises:
            IndexError: if either model's class list has no "sick" label.
        """
        print(f"\nEnsemble Inference for: {os.path.basename(file_path)}")

        # 1. HeAR Prediction
        emb = self.hear_extractor.extract(file_path)
        hear_preds = self.hear_model.predict(emb[np.newaxis, ...], verbose=0)[0]
        hear_label = self.hear_classes[np.argmax(hear_preds)]
        hear_conf = np.max(hear_preds)

        # 2. CNN Prediction
        cnn_feat = self._extract_cnn_features(file_path)
        cnn_preds = self.cnn_model.predict(cnn_feat[np.newaxis, ...], verbose=0)[0]
        cnn_label = self.cnn_classes[np.argmax(cnn_preds)]
        cnn_conf = np.max(cnn_preds)

        # 3. Ensemble Logic (Weighted Voting)
        # HeAR's "sick" probability is weighted more heavily than the CNN's,
        # based on our previous comprehensive test analysis. The same weights
        # apply whichever label wins: "healthy" confidence is simply
        # 1 - combined_sick_prob.
        hear_sick = hear_preds[self._class_index(self.hear_classes, 'sick')]
        cnn_sick = cnn_preds[self._class_index(self.cnn_classes, 'sick')]
        combined_sick_prob = (self._HEAR_WEIGHT * hear_sick +
                              self._CNN_WEIGHT * cnn_sick)

        final_label = "sick" if combined_sick_prob > 0.5 else "healthy"
        final_conf = combined_sick_prob if final_label == "sick" else (1 - combined_sick_prob)

        return {
            "final_result": final_label,
            "final_confidence": final_conf,
            "hear_result": hear_label,
            "hear_conf": hear_conf,
            "cnn_result": cnn_label,
            "cnn_conf": cnn_conf
        }
def main(argv=None):
    """Run ensemble inference on the audio file named on the command line.

    Args:
        argv: argument list including the program name; defaults to
            ``sys.argv``.

    Returns:
        Process exit status: 0 after printing a prediction, 2 when no audio
        file was supplied (the conventional status for a usage error).
    """
    args = sys.argv if argv is None else list(argv)
    if len(args) < 2:
        # Previously the script exited silently here; tell the user what
        # it expects instead.
        prog = (os.path.basename(args[0]) if args else "") or "ensemble_predict.py"
        print(f"Usage: python {prog} <audio_file>", file=sys.stderr)
        return 2

    test_file = args[1]
    predictor = EnsemblePredictor()
    res = predictor.predict(test_file)

    print("\n" + "="*40)
    print(f"FINAL RESULT: {res['final_result'].upper()}")
    print(f"Confidence: {res['final_confidence']*100:.2f}%")
    print("="*40)
    # Per-model breakdown so disagreements between the two models are visible.
    print(f"HeAR says: {res['hear_result']} ({res['hear_conf']*100:.1f}%)")
    print(f"CNN says: {res['cnn_result']} ({res['cnn_conf']*100:.1f}%)")
    return 0


if __name__ == "__main__":
    sys.exit(main())