HeartWatchAI / test_inference.py
Ashkan Taghipour (The University of Western Australia)
Initial HeartWatch AI demo release
cd846d7
#!/usr/bin/env python3
"""
Test script for DeepECG Inference Module
=========================================
Run this script to verify the inference engine works correctly.

Usage:
    # Set HF_TOKEN environment variable first
    export HF_TOKEN='your_huggingface_token'

    # Run the test
    python test_inference.py

Expected output:
    - Models download from HuggingFace Hub
    - Dummy signal inference completes
    - Results for all 4 models are printed
"""
import os
import sys
import numpy as np
# Geometry of the dummy ECG: 10 seconds sampled at 250 Hz across 12 leads.
SAMPLE_COUNT = 2500
LEAD_COUNT = 12
# Shape the script requires preprocess_ecg() to produce for either input layout.
EXPECTED_TENSOR_SHAPE = (1, LEAD_COUNT, SAMPLE_COUNT)
BANNER = "=" * 60


def make_dummy_signal(shape=(SAMPLE_COUNT, LEAD_COUNT)) -> np.ndarray:
    """Return a random float32 array of ``shape`` to stand in for an ECG."""
    return np.random.randn(*shape).astype(np.float32)


def report_failure(stage: str, error: BaseException) -> None:
    """Print the ``[FAIL]`` line for ``stage`` followed by the traceback.

    Call this from inside an ``except`` block so that
    ``traceback.print_exc()`` has an active exception to print.
    """
    import traceback

    print(f"[FAIL] {stage} failed: {error}")
    traceback.print_exc()


def check_inference(engine, signal):
    """Run ``engine.predict(signal)`` and print a summary of the results.

    Returns:
        The mapping returned by ``engine.predict``.

    Raises:
        Any exception raised by ``engine.predict`` or by a missing result key.
    """
    results = engine.predict(signal)
    print(f"[OK] Inference completed in {results['inference_time_ms']:.2f} ms")
    print("\nResults:")
    print(f" - LVEF <= 40%: {results['lvef_40']:.4f}")
    print(f" - LVEF < 50%: {results['lvef_50']:.4f}")
    print(f" - 5-year AFib: {results['afib_5y']:.4f}")
    print(f" - 77-class diagnosis: {len(results['diagnosis_77']['probabilities'])} probabilities")
    return results


def check_top_k(engine, signal, k: int = 5) -> None:
    """Print the ``k`` predictions from ``engine.predict_diagnosis_top_k``."""
    top_k = engine.predict_diagnosis_top_k(signal, k=k)
    for pred in top_k["top_k_predictions"]:
        print(f" {pred['class_name']}: {pred['probability']:.4f}")


def check_preprocessing(engine) -> None:
    """Verify ``engine.preprocess_ecg`` accepts both input layouts.

    Feeds a (samples, leads) and a (leads, samples) signal and requires each
    result to have shape ``EXPECTED_TENSOR_SHAPE``.

    Raises:
        AssertionError: if a preprocessed tensor has an unexpected shape.
            Raised explicitly (not via ``assert``) so that the check still
            runs under ``python -O``.
    """
    for shape in ((SAMPLE_COUNT, LEAD_COUNT), (LEAD_COUNT, SAMPLE_COUNT)):
        tensor = engine.preprocess_ecg(make_dummy_signal(shape))
        if tuple(tensor.shape) != EXPECTED_TENSOR_SHAPE:
            raise AssertionError(f"Expected (1, 12, 2500), got {tensor.shape}")
        print(f"[OK] Shape {shape} -> {tuple(tensor.shape)}")


def main() -> int:
    """Run every check in order and return the process exit code.

    Returns:
        0 when all checks pass, 1 as soon as any stage fails.
    """
    # The script refuses to start without an HF_TOKEN in the environment.
    if not os.environ.get("HF_TOKEN"):
        print("ERROR: HF_TOKEN environment variable not set")
        print("Please run: export HF_TOKEN='your_token'")
        return 1

    print(BANNER)
    print("DeepECG Inference Test")
    print(BANNER)

    # Import lazily so that a missing module is reported as a failed stage
    # rather than crashing before anything is printed.
    try:
        from inference import DeepECGInference
        print("[OK] Import successful")
    except ImportError as error:
        report_failure("Import", error)
        return 1

    try:
        engine = DeepECGInference()
        print(f"[OK] Engine created with {len(engine.class_names)} class names")
    except Exception as error:
        report_failure("Engine creation", error)
        return 1

    print("\nLoading models from HuggingFace Hub...")
    try:
        engine.load_models()
        print(f"[OK] Loaded {len(engine.models)} models")
        for name in engine.models:
            print(f" - {name}")
    except Exception as error:
        report_failure("Model loading", error)
        return 1

    print("\nTesting inference with dummy signal...")
    try:
        dummy_signal = make_dummy_signal()
        check_inference(engine, dummy_signal)
    except Exception as error:
        report_failure("Inference", error)
        return 1

    print("\nTop 5 diagnoses:")
    try:
        check_top_k(engine, dummy_signal, k=5)
    except Exception as error:
        report_failure("Top-k prediction", error)
        return 1

    print("\nTesting preprocessing with different input shapes...")
    try:
        check_preprocessing(engine)
    except Exception as error:
        report_failure("Preprocessing test", error)
        return 1

    print("\n" + BANNER)
    print("ALL TESTS PASSED!")
    print(BANNER)
    return 0


# Guarded so that importing this file (e.g. pytest collecting test_*.py)
# does not run the checks or terminate the interpreter.
if __name__ == "__main__":
    sys.exit(main())