Spaces:
Running
Running
| import os | |
| import sys | |
| import numpy as np | |
| import librosa | |
| import tensorflow as tf | |
| from tensorflow.keras.models import load_model | |
# --- Model artefact locations -------------------------------------------
# All artefacts live in one directory; the separator is written out
# explicitly so the resulting strings do not depend on the host OS.
_MODELS_DIR = r"c:\Users\ASUS\lung_ai_project\models"

HEAR_MODEL_PATH = "\\".join((_MODELS_DIR, "hear_classifier_augmented.h5"))
HEAR_CLASSES_PATH = "\\".join((_MODELS_DIR, "hear_aug_classes.npy"))
CNN_MODEL_PATH = "\\".join((_MODELS_DIR, "cough_model.h5"))
CNN_CLASSES_PATH = "\\".join((_MODELS_DIR, "classes.npy"))

# --- CNN front-end settings ----------------------------------------------
CNN_SR = 22050       # sample rate (Hz) the audio is loaded at for the CNN
CNN_DURATION = 5     # seconds of audio loaded per file
CNN_MFCC = 13        # number of MFCC coefficients requested from librosa
# Fixed clip length in samples; both factors are ints, so the product is too.
CNN_MAX_LEN = CNN_SR * CNN_DURATION

# --- HeAR settings -------------------------------------------------------
# NOTE(review): not referenced in the visible code; presumably the sample
# rate the HeAR extractor expects -- confirm against hear_extractor.
HEAR_SR = 16000
class EnsemblePredictor:
    """Combine a HeAR-embedding classifier and an MFCC CNN into one verdict.

    Both models score the same audio file; their probabilities for the
    "sick" class are blended with fixed weights and thresholded to produce
    a final "sick" / "healthy" label.
    """

    # Blend weights for each model's "sick" probability. HeAR carries the
    # larger weight; per the original authors this choice is based on an
    # earlier comparative test of the two models.
    HEAR_WEIGHT = 0.7
    CNN_WEIGHT = 0.3
    # The blended probability must be strictly above this to yield "sick".
    SICK_THRESHOLD = 0.5
    SICK_LABEL = "sick"
    HEALTHY_LABEL = "healthy"

    def __init__(self):
        """Load the HeAR extractor, the HeAR classifier and the CNN.

        Raises:
            FileNotFoundError: if neither the augmented nor the baseline
                HeAR classifier exists on disk. Failing here avoids a later,
                harder-to-diagnose AttributeError inside ``predict``.
        """
        print("Initializing Ensemble Model...")

        # 1. Load HeAR components
        # The extractor lives in a sibling "utils" directory; register it
        # once so repeated instantiation does not keep growing sys.path.
        utils_dir = os.path.join(os.path.dirname(__file__), "..", "utils")
        if utils_dir not in sys.path:
            sys.path.append(utils_dir)
        from hear_extractor import HeARExtractor

        self.hear_extractor = HeARExtractor()

        if os.path.exists(HEAR_MODEL_PATH):
            self.hear_model = load_model(HEAR_MODEL_PATH)
            self.hear_classes = np.load(HEAR_CLASSES_PATH)
        else:
            print("Warning: Augmented HeAR model not found. Using baseline if available.")
            # Fallback to non-augmented
            alt_path = HEAR_MODEL_PATH.replace("_augmented", "")
            if not os.path.exists(alt_path):
                raise FileNotFoundError(
                    "No HeAR classifier found; looked for "
                    f"{HEAR_MODEL_PATH!r} and {alt_path!r}"
                )
            self.hear_model = load_model(alt_path)
            self.hear_classes = np.load(r"c:\Users\ASUS\lung_ai_project\models\hear_classes.npy")

        # 2. Load CNN components
        self.cnn_model = load_model(CNN_MODEL_PATH)
        self.cnn_classes = np.load(CNN_CLASSES_PATH)

    @staticmethod
    def _class_index(classes, label):
        """Return the position of the first entry of *classes* equal to *label*.

        Raises:
            IndexError: if *label* does not occur in *classes*.
        """
        matches = np.where(np.asarray(classes) == label)[0]
        if len(matches) == 0:
            raise IndexError(
                f"label {label!r} not found in model classes {list(classes)!r}"
            )
        return int(matches[0])

    @classmethod
    def _fuse(cls, hear_sick_prob, cnn_sick_prob):
        """Blend two "sick" probabilities into ``(label, confidence)``.

        The confidence is the blended probability when the label is "sick"
        and its complement when the label is "healthy".
        """
        combined = cls.HEAR_WEIGHT * hear_sick_prob + cls.CNN_WEIGHT * cnn_sick_prob
        if combined > cls.SICK_THRESHOLD:
            return cls.SICK_LABEL, float(combined)
        return cls.HEALTHY_LABEL, float(1 - combined)

    def _extract_cnn_features(self, file_path):
        """Load *file_path* and return its MFCCs with a trailing length-1 axis.

        The audio is loaded at ``CNN_SR`` for at most ``CNN_DURATION``
        seconds, then zero-padded (or trimmed) to exactly ``CNN_MAX_LEN``
        samples so every file produces the same number of samples.
        """
        audio, sr = librosa.load(file_path, sr=CNN_SR, duration=CNN_DURATION)

        missing = CNN_MAX_LEN - len(audio)
        if missing > 0:
            audio = np.pad(audio, (0, missing), 'constant')
        else:
            audio = audio[:CNN_MAX_LEN]

        mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=CNN_MFCC)
        return mfccs[..., np.newaxis]

    def predict(self, file_path):
        """Run both models on *file_path* and return the ensemble verdict.

        Returns:
            dict with keys ``final_result``, ``final_confidence``,
            ``hear_result``, ``hear_conf``, ``cnn_result`` and ``cnn_conf``.
            Confidences are plain Python floats.

        Raises:
            IndexError: if either model's class list has no "sick" entry.
        """
        print(f"\nEnsemble Inference for: {os.path.basename(file_path)}")

        # 1. HeAR Prediction (a leading batch axis is added for the model)
        emb = self.hear_extractor.extract(file_path)
        hear_preds = self.hear_model.predict(emb[np.newaxis, ...], verbose=0)[0]
        hear_label = self.hear_classes[np.argmax(hear_preds)]
        hear_conf = float(np.max(hear_preds))

        # 2. CNN Prediction
        cnn_feat = self._extract_cnn_features(file_path)
        cnn_preds = self.cnn_model.predict(cnn_feat[np.newaxis, ...], verbose=0)[0]
        cnn_label = self.cnn_classes[np.argmax(cnn_preds)]
        cnn_conf = float(np.max(cnn_preds))

        # 3. Ensemble logic: weighted blend of the two "sick" probabilities.
        hear_sick = hear_preds[self._class_index(self.hear_classes, self.SICK_LABEL)]
        cnn_sick = cnn_preds[self._class_index(self.cnn_classes, self.SICK_LABEL)]
        final_label, final_conf = self._fuse(hear_sick, cnn_sick)

        return {
            "final_result": final_label,
            "final_confidence": final_conf,
            "hear_result": hear_label,
            "hear_conf": hear_conf,
            "cnn_result": cnn_label,
            "cnn_conf": cnn_conf,
        }
def main(argv=None):
    """Command-line entry point: classify one audio file and print a report.

    Args:
        argv: argument list without the program name; defaults to
            ``sys.argv[1:]``. Only the first argument (the audio file
            path) is used.

    Returns:
        0 after a prediction has been printed, 1 when no file was given.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        # Previously this case exited silently; tell the user what is expected.
        print(f"Usage: python {sys.argv[0]} <audio_file>", file=sys.stderr)
        return 1

    test_file = args[0]
    predictor = EnsemblePredictor()
    res = predictor.predict(test_file)

    print("\n" + "="*40)
    print(f"FINAL RESULT: {res['final_result'].upper()}")
    print(f"Confidence: {res['final_confidence']*100:.2f}%")
    print("="*40)
    print(f"HeAR says: {res['hear_result']} ({res['hear_conf']*100:.1f}%)")
    print(f"CNN says: {res['cnn_result']} ({res['cnn_conf']*100:.1f}%)")
    return 0


if __name__ == "__main__":
    sys.exit(main())