# KasaHealth / models/predict_hear.py
# Author: 78anand
# Commit f317798 (verified): "Upload folder using huggingface_hub"
import os
import sys
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
# Add project root to sys.path to allow importing utils
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.hear_extractor import HeARExtractor
# --- Configuration ---
# Model artifacts are resolved relative to this script's own directory, so the
# script works from any checkout. Previously these were absolute paths that
# only existed on the original author's machine.
_MODELS_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(_MODELS_DIR, "hear_classifier.h5")  # Keras classifier loaded by load_model
CLASSES_PATH = os.path.join(_MODELS_DIR, "hear_classes.npy")  # class-name array read with np.load
def _decode_prediction(probs, classes):
    """Map one row of classifier output to ``(label, confidence)``.

    Two output layouts are handled:

    * one score per class (e.g. a softmax head): the arg-max class wins and
      its score is the confidence;
    * a single score with exactly two class names (a sigmoid head): the arg-max
      of a one-element vector would always pick ``classes[0]``. Instead, the
      score is thresholded at 0.5.

    NOTE(review): the sigmoid branch assumes the model was trained with
    ``classes[1]`` as the positive (1) label, which is the usual 0/1 encoding
    convention. Confirm against the training script.

    Args:
        probs: Model scores for a single example (any array-like; flattened).
        classes: Class names indexed in the same order as the model outputs.

    Returns:
        ``(label, confidence)``, where the label is a ``str`` and the
        confidence is a Python ``float``.
    """
    probs = np.asarray(probs, dtype=float).ravel()
    if probs.size == 1 and len(classes) == 2:
        p_positive = float(probs[0])
        idx = 1 if p_positive >= 0.5 else 0
        confidence = p_positive if idx == 1 else 1.0 - p_positive
    else:
        idx = int(np.argmax(probs))
        confidence = float(probs[idx])

    label = classes[idx]
    # np.load may yield a bytes-typed array (dtype 'S'); decode so that
    # .upper() and the == "sick" comparison behave as they do for str labels.
    if isinstance(label, bytes):
        label = label.decode("utf-8")
    return str(label), confidence


def _save_last_prediction(label, confidence):
    """Best-effort write of the latest result to ``last_prediction.txt``.

    The file is placed in the same directory as MODEL_PATH. A write failure
    only prints a warning: the result has already been shown to the user, so
    it should not abort the run.
    """
    out_path = os.path.join(os.path.dirname(MODEL_PATH), "last_prediction.txt")
    try:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(f"RESULT: {label.upper()}\n")
            f.write(f"CONFIDENCE: {confidence*100:.2f}%\n")
    except OSError as e:
        print(f"Warning: could not save prediction to {out_path}: {e}")


def predict_audio(file_path):
    """Classify one audio file with the HeAR-embedding classifier and report it.

    Runs four steps, printing progress as it goes:

    1. Build a ``HeARExtractor``.
    2. Extract an embedding from *file_path*.
    3. Load the Keras classifier from MODEL_PATH and its class names from
       CLASSES_PATH.
    4. Predict.

    The result is printed, saved (best effort) to ``last_prediction.txt`` next
    to MODEL_PATH, and followed by a simple recommendation. An error in any
    step is printed and stops the run instead of propagating, because this
    function is the command-line entry point.

    Args:
        file_path: Path to the audio file to analyse.

    Returns:
        ``(label, confidence)`` on success, or ``None`` if the file is missing
        or any step failed.
    """
    print(f"\nAnalyzing: {os.path.basename(file_path)}")
    print("-" * 50)
    if not os.path.exists(file_path):
        print(f"Error: File not found at {file_path}")
        return None

    # 1. Initialize Extractor
    print("Step 1: Initializing HeAR Extractor...")
    try:
        extractor = HeARExtractor()
    except Exception as e:  # CLI boundary: report and stop
        print(f"Failed to load HeAR model: {e}")
        return None

    # 2. Extract Features (guarded like every other step; previously an
    # exception here escaped and crashed the script).
    print("Step 2: Extracting HeAR embeddings...")
    try:
        embedding = extractor.extract(file_path)
    except Exception as e:
        print(f"Extraction failed: {e}")
        return None
    if embedding is None:
        print("Extraction failed. Check audio format.")
        return None

    # 3. Load Classifier
    print("Step 3: Loading Classifier...")
    try:
        model = load_model(MODEL_PATH)
        classes = np.load(CLASSES_PATH)
        print(f"Model loaded. Classes: {classes}")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    # 4. Predict
    print("Step 4: Running Inference...")
    try:
        # NOTE(review): assumes `embedding` is a single example without a batch
        # axis, so prepending one gives the input shape the classifier expects.
        # Confirm against HeARExtractor.extract.
        X = embedding[np.newaxis, ...]
        preds = model.predict(X, verbose=0)
        pred_label, confidence = _decode_prediction(preds[0], classes)
    except Exception as e:
        print(f"Error during inference: {e}")
        return None

    print("-" * 50)
    print(f"RESULT: {pred_label.upper()}")
    print(f"CONFIDENCE: {confidence*100:.2f}%")
    print("-" * 50)

    _save_last_prediction(pred_label, confidence)

    # Simple interpretation
    if pred_label == "sick":
        print("Recommendation: Potential respiratory symptoms detected. Consider medical consultation.")
    else:
        print("Recommendation: Acoustic pattern appears healthy. Continue monitoring if symptoms persist.")

    return pred_label, confidence
if __name__ == "__main__":
    # Command-line usage: python predict_hear.py <audio_file>
    # Previously a missing argument silently fell back to a personal,
    # machine-specific file path. Now it prints usage and exits with status 1
    # (sys.exit with a string argument writes it to stderr).
    if len(sys.argv) < 2:
        sys.exit(f"Usage: python {os.path.basename(sys.argv[0])} <audio_file>")
    predict_audio(sys.argv[1])