Spaces:
Running
Running
Ashkan Taghipour (The University of Western Australia)
Fix model repository names for LVEF models
2a47411 | """ | |
| DeepECG Inference Module for HeartWatch AI | |
| =========================================== | |
| This module provides CPU-optimized inference for 4 EfficientNetV2 models: | |
| - 77-class ECG diagnosis | |
| - LVEF <= 40% prediction | |
| - LVEF < 50% prediction | |
| - 5-year AFib risk prediction | |
| The preprocessing exactly replicates DeepECG's pipeline: | |
| 1. Load signal as (samples, leads) = (2500, 12) | |
| 2. Transpose to (leads, samples) = (12, 2500) | |
| 3. Apply MHI factor scaling: signal *= (1/0.0048) | |
| 4. Apply sigmoid to model logits | |
| Models are downloaded from HuggingFace Hub using HF_TOKEN from environment. | |
| """ | |
| import os | |
| import json | |
| import time | |
| import logging | |
| from typing import Dict, Optional, Any, Union | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from huggingface_hub import snapshot_download | |
# CPU optimizations for HuggingFace Spaces (no GPU): disable denormal floats
# on CPU and cap torch's intra-op thread pool at two threads.
torch.set_flush_denormal(True)
torch.set_num_threads(2)

# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class DeepECGInference:
    """
    CPU-optimized inference engine for DeepECG EfficientNetV2 models.

    Loads 4 models from HuggingFace Hub (repository ids in ``MODEL_REPOS``):
    - heartwise/EfficientNetV2_77_Classes: 77-class ECG diagnosis
    - heartwise/EfficientNetV2_LVEF_equal_under_40: LVEF <= 40% prediction
    - heartwise/EfficientNetV2_LVEF_under_50: LVEF < 50% prediction
    - heartwise/EfficientNetV2_AFIB_5y: 5-year AFib risk prediction

    Attributes:
        device: Always CPU for HF Spaces
        cache_dir: Root directory for downloaded models; each model lives in a
            sub-directory named after its key in ``MODEL_REPOS``
        models: Dict mapping model key -> loaded TorchScript model
        class_names: Diagnosis class names read from class_names.json
            (empty list if that file is missing)
        MHI_FACTOR: Scaling factor for signal preprocessing (1/0.0048)
    """

    # Model repository mappings (model key -> HuggingFace repo id)
    MODEL_REPOS = {
        "diagnosis_77": "heartwise/EfficientNetV2_77_Classes",
        "lvef_40": "heartwise/EfficientNetV2_LVEF_equal_under_40",
        "lvef_50": "heartwise/EfficientNetV2_LVEF_under_50",
        "afib_5y": "heartwise/EfficientNetV2_AFIB_5y",
    }

    # Models whose sigmoid output is reported as a single probability, in the
    # order their results are inserted into predict()'s output dict.
    _BINARY_MODEL_KEYS = ("lvef_40", "lvef_50", "afib_5y")

    # Expected input specifications
    EXPECTED_LEADS = 12
    EXPECTED_SAMPLES = 2500  # 10 seconds at 250 Hz
    SAMPLING_RATE = 250  # Hz

    # Preprocessing constants from DeepECG
    MHI_FACTOR = 1 / 0.0048  # ~208.33

    def __init__(self, cache_dir: Optional[str] = None):
        """
        Initialize the inference engine (models are NOT loaded here).

        Args:
            cache_dir: Directory to cache downloaded models.
                Defaults to ./weights (relative to the current working directory).
        """
        self.device = torch.device("cpu")
        self.cache_dir = cache_dir or os.path.join(os.getcwd(), "weights")
        self.models: Dict[str, torch.jit.ScriptModule] = {}
        self.class_names: list = []
        self._load_class_names()

    def _load_class_names(self) -> None:
        """Load diagnosis class names from class_names.json next to this file.

        A missing file is tolerated: a warning is logged and ``class_names``
        stays empty.
        """
        class_names_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "class_names.json",
        )
        try:
            with open(class_names_path, "r", encoding="utf-8") as f:
                self.class_names = json.load(f)
        except FileNotFoundError:
            logger.warning(f"class_names.json not found at {class_names_path}")
            self.class_names = []
        else:
            logger.info(f"Loaded {len(self.class_names)} class names")

    def _get_hf_token(self) -> Optional[str]:
        """Return the HF_TOKEN environment variable, warning if it is unset or empty."""
        token = os.environ.get("HF_TOKEN")
        if not token:
            logger.warning("HF_TOKEN environment variable not set")
        return token

    @staticmethod
    def _find_pt_file(model_dir: str) -> Optional[str]:
        """
        Return the file name of the first ``.pt`` file in ``model_dir``.

        Names are scanned in sorted order so the choice is deterministic
        (``os.listdir`` order is arbitrary).

        Returns:
            The ``.pt`` file name, or None if ``model_dir`` does not exist or
            contains no ``.pt`` file.
        """
        if not os.path.isdir(model_dir):
            return None
        return next(
            (name for name in sorted(os.listdir(model_dir)) if name.endswith(".pt")),
            None,
        )

    def _download_model(self, repo_id: str, model_name: str) -> str:
        """
        Download model from HuggingFace Hub, reusing a valid local copy.

        Args:
            repo_id: HuggingFace repository ID
            model_name: Local name for the model (sub-directory of cache_dir)

        Returns:
            Path to the downloaded model directory
        """
        local_dir = os.path.join(self.cache_dir, model_name)
        # Only trust the cache when it actually holds a weights file. The
        # directory is created before snapshot_download runs, so a failed
        # download (missing token, wrong repo id, network error) leaves an
        # empty directory that must not be mistaken for a cached model.
        if self._find_pt_file(local_dir) is not None:
            logger.info(f"Model {model_name} already cached at {local_dir}")
            return local_dir

        logger.info(f"Downloading {repo_id} to {local_dir}")
        os.makedirs(local_dir, exist_ok=True)
        hf_token = self._get_hf_token()
        local_dir = snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            repo_type="model",
            token=hf_token,
        )
        logger.info(f"Downloaded {repo_id} to {local_dir}")
        return local_dir

    def _load_model_from_dir(self, model_dir: str) -> torch.jit.ScriptModule:
        """
        Load a TorchScript model from a directory, in eval mode on ``self.device``.

        Args:
            model_dir: Directory containing the .pt file

        Returns:
            Loaded TorchScript model

        Raises:
            ValueError: If no .pt file is found in the directory
        """
        pt_file = self._find_pt_file(model_dir)
        if pt_file is None:
            raise ValueError(f"No .pt file found in {model_dir}")
        model_path = os.path.join(model_dir, pt_file)
        model = torch.jit.load(model_path, map_location=self.device)
        model.eval()
        return model

    def load_models(self) -> None:
        """
        Download and load all 4 models from HuggingFace Hub.

        Uses HF_TOKEN from os.environ for authentication. Models are loaded in
        eval mode on CPU. The first failure is logged and re-raised.
        """
        logger.info("Loading DeepECG models...")
        for model_key, repo_id in self.MODEL_REPOS.items():
            try:
                model_dir = self._download_model(repo_id, model_key)
                self.models[model_key] = self._load_model_from_dir(model_dir)
                logger.info(f"Loaded model: {model_key} from {repo_id}")
            except Exception as e:
                logger.error(f"Failed to load {model_key}: {e}")
                raise
        logger.info(f"Successfully loaded {len(self.models)} models")

    def preprocess_ecg(
        self,
        ecg_signal: Union[np.ndarray, torch.Tensor],
    ) -> torch.Tensor:
        """
        Preprocess ECG signal to match DeepECG's preprocessing.

        Pipeline:
        1. Convert to a float32 numpy array (copying the input)
        2. Orient as (leads, samples): (samples, leads) inputs are transposed
        3. Add batch dimension: (1, leads, samples)
        4. Apply MHI factor scaling: signal *= (1/0.0048)

        Only the lead count is enforced; the number of samples is not checked.

        Args:
            ecg_signal: Raw ECG signal (ndarray, tensor or nested sequence),
                shape (samples, leads) or (leads, samples).
                Expected: 12 leads, 2500 samples (10s at 250Hz)

        Returns:
            Preprocessed float32 tensor on ``self.device``, shape (1, 12, samples)

        Raises:
            ValueError: If the signal is not 2D or neither axis has 12 leads
        """
        # Detach and move tensors to CPU first so inputs that require grad or
        # live on another device can be converted to numpy.
        if isinstance(ecg_signal, torch.Tensor):
            ecg_signal = ecg_signal.detach().cpu().numpy()
        # np.array always copies, so the caller's data is never shared.
        signal = np.array(ecg_signal, dtype=np.float32)

        if signal.ndim != 2:
            raise ValueError(
                f"Expected 2D signal, got shape {signal.shape}"
            )

        n_rows, n_cols = signal.shape
        if (n_rows, n_cols) == (self.EXPECTED_LEADS, self.EXPECTED_SAMPLES):
            pass  # Already (12, 2500)
        elif n_cols == self.EXPECTED_LEADS:
            # (samples, leads), including the canonical (2500, 12) -> (12, samples)
            signal = signal.T
        elif n_rows != self.EXPECTED_LEADS:
            raise ValueError(
                f"Invalid signal shape {signal.shape}. "
                f"Expected (2500, 12) or (12, 2500)"
            )
        # At this point signal.shape[0] == EXPECTED_LEADS in every branch.

        # Convert to tensor, add batch dimension -> (1, 12, samples), move to device
        signal_tensor = torch.from_numpy(signal).unsqueeze(0).to(self.device)
        # Apply MHI factor scaling (this is done in model __call__ in DeepECG)
        return signal_tensor * self.MHI_FACTOR

    def predict(
        self,
        ecg_signal: Union[np.ndarray, torch.Tensor],
    ) -> Dict[str, Any]:
        """
        Run inference on an ECG signal using every loaded model.

        Args:
            ecg_signal: Raw ECG signal, shape (samples, leads) or (leads, samples)
                Expected: 12 leads, 2500 samples (10s at 250Hz)

        Returns:
            Dictionary containing (only for models that are loaded):
            - diagnosis_77: Dict with 'probabilities' (one float per class) and
              'class_names' (list, or None if no class names were loaded)
            - lvef_40: Probability of LVEF <= 40%
            - lvef_50: Probability of LVEF < 50%
            - afib_5y: Probability of AFib within 5 years
            - inference_time_ms: Total preprocessing + inference time in ms

        Raises:
            RuntimeError: If no models are loaded
        """
        if not self.models:
            raise RuntimeError("Models not loaded. Call load_models() first.")

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump with wall-clock adjustments.
        start_time = time.perf_counter()
        signal_tensor = self.preprocess_ecg(ecg_signal)
        results: Dict[str, Any] = {}

        with torch.no_grad():
            # 77-class diagnosis (multi-label: independent sigmoid per class)
            if "diagnosis_77" in self.models:
                probs = torch.sigmoid(self.models["diagnosis_77"](signal_tensor))
                results["diagnosis_77"] = {
                    "probabilities": probs.squeeze().cpu().numpy().tolist(),
                    "class_names": self.class_names or None,
                }
            # Single-probability models: LVEF <= 40%, LVEF < 50%, 5-year AFib risk
            for model_key in self._BINARY_MODEL_KEYS:
                if model_key in self.models:
                    prob = torch.sigmoid(self.models[model_key](signal_tensor))
                    results[model_key] = float(prob.squeeze().cpu().numpy())

        results["inference_time_ms"] = (time.perf_counter() - start_time) * 1000
        return results

    def predict_diagnosis_top_k(
        self,
        ecg_signal: Union[np.ndarray, torch.Tensor],
        k: int = 5,
    ) -> Dict[str, Any]:
        """
        Get top-k diagnoses from the 77-class model.

        Classes without a loaded name (missing or too-short class_names.json)
        are reported as ``Class_<index>``.

        Args:
            ecg_signal: Raw ECG signal
            k: Number of top predictions to return (non-negative; values larger
                than the number of classes return every class)

        Returns:
            Dictionary with 'top_k_predictions' sorted by descending probability
            and 'inference_time_ms'

        Raises:
            ValueError: If k is negative
            RuntimeError: If the 77-class model is not loaded
        """
        # A negative slice bound would silently return "all but the last |k|".
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")

        results = self.predict(ecg_signal)
        if "diagnosis_77" not in results:
            raise RuntimeError("77-class diagnosis model not loaded")

        probs = results["diagnosis_77"]["probabilities"]
        class_names = results["diagnosis_77"]["class_names"] or []

        # Indices sorted by descending probability, truncated to k
        top_k_indices = [int(idx) for idx in np.argsort(probs)[::-1][:k]]
        top_k_predictions = [
            {
                "class_name": (
                    class_names[idx] if idx < len(class_names) else f"Class_{idx}"
                ),
                "probability": probs[idx],
                "class_index": idx,
            }
            for idx in top_k_indices
        ]

        return {
            "top_k_predictions": top_k_predictions,
            "inference_time_ms": results["inference_time_ms"],
        }
def get_inference_engine(cache_dir: Optional[str] = None) -> DeepECGInference:
    """
    Build a DeepECGInference and load all of its models before returning it.

    Args:
        cache_dir: Optional directory to cache models (forwarded unchanged)

    Returns:
        A DeepECGInference whose load_models() has already completed
    """
    inference_engine = DeepECGInference(cache_dir=cache_dir)
    # Eagerly download/load so callers receive a ready-to-predict engine.
    inference_engine.load_models()
    return inference_engine
if __name__ == "__main__":
    # Example usage / smoke test of the full pipeline.
    print("DeepECG Inference Module")
    print("=" * 50)

    # Create inference engine
    engine = DeepECGInference()

    # Load models (requires HF_TOKEN environment variable)
    try:
        engine.load_models()
        print("Models loaded successfully!")

        # Seeded dummy (samples, leads) signal so repeated runs are comparable
        dummy_signal = (
            np.random.default_rng(0).standard_normal((2500, 12)).astype(np.float32)
        )

        # Run inference
        results = engine.predict(dummy_signal)
        print(f"\nInference time: {results['inference_time_ms']:.2f} ms")
        print(f"LVEF <= 40%: {results['lvef_40']:.4f}")
        print(f"LVEF < 50%: {results['lvef_50']:.4f}")
        print(f"5-year AFib risk: {results['afib_5y']:.4f}")
        print(f"77-class diagnosis: {len(results['diagnosis_77']['probabilities'])} classes")

        # Get top-5 diagnoses
        top_5 = engine.predict_diagnosis_top_k(dummy_signal, k=5)
        print("\nTop 5 diagnoses:")
        for pred in top_5["top_k_predictions"]:
            print(f" {pred['class_name']}: {pred['probability']:.4f}")
    except Exception as e:
        print(f"Error: {e}")
        print("\nMake sure HF_TOKEN environment variable is set:")
        print(" export HF_TOKEN='your_huggingface_token'")
        # Report failure to the shell/CI instead of exiting with status 0.
        raise SystemExit(1)