Spaces:
Sleeping
Sleeping
import glob
import os
import pickle

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from matplotlib.figure import Figure

from dataset import MIMICECGDataset
from labels_refined import CLASSES, get_refined_labels
from model import ResNet1d
# --- Configuration ---
# Inference is pinned to the CPU.
DEVICE = torch.device("cpu")
# Folder holding the bundled example recordings, one <study_id>.dat each.
DATA_DIR = "./examples"
# Metadata table whose report_* columns carry each study's report text.
CSV_PATH = "metadata.csv"
# --- Load Resources ---
print("Downloading Model from Hub...")
# Get token from Space Secrets (must be set as HF_TOKEN)
hf_token = os.environ.get("HF_TOKEN")
model_path = hf_hub_download(
    repo_id="IFMedTech/ECG_Model",
    filename="resnet_advanced.pth",
    token=hf_token,
)
print(f"Loading Model from {model_path}...")
model = ResNet1d(num_classes=5).to(DEVICE)
try:
    # Preferred path: the weights-only unpickler refuses arbitrary objects.
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
except TypeError:
    # torch predates the `weights_only` keyword (the unknown kwarg is
    # rejected); such versions always unpickle fully.
    state_dict = torch.load(model_path, map_location=DEVICE)
except pickle.UnpicklingError as err:
    # The checkpoint contains objects the weights-only unpickler rejects.
    # SECURITY: full unpickling can execute code embedded in the file. It is
    # acceptable only because the checkpoint comes from a fixed, trusted repo.
    # `weights_only=False` must be passed explicitly: since torch 2.6 the
    # default is True, so a bare retry would fail the same way.
    print(f"Weights-only load failed ({err}); retrying with full unpickling.")
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=False)
model.load_state_dict(state_dict)
# Inference mode (disables dropout / uses running batch-norm statistics).
model.eval()
print("Loading Dataset Index...")
# Use the metadata CSV to look up report text for each bundled example.
try:
    df = pd.read_csv(CSV_PATH, low_memory=False)
    print(f"Loaded CSV with {len(df)} records.")
except Exception as e:
    # Deliberate best effort at startup: without the CSV the demo still runs,
    # it just shows placeholder descriptions.
    print(f"Error loading CSV: {e}")
    df = pd.DataFrame()  # Fallback

# Scan the examples folder for .dat files; each file stem is a study ID.
# Sorted because glob order depends on the filesystem.
example_files = sorted(glob.glob(os.path.join(DATA_DIR, "*.dat")))
available_study_ids = [os.path.splitext(os.path.basename(f))[0] for f in example_files]
print(f"Found examples: {available_study_ids}")

# Build metadata for Gradio: study_id -> {"diagnosis": ..., "text": ...}.
example_metadata = {}
# A CSV that loaded but has no study_id column cannot be matched against, so
# treat it like a missing CSV instead of crashing startup with a KeyError.
csv_usable = not df.empty and 'study_id' in df.columns
if csv_usable:
    # Loop invariants, computed once rather than once per example: the
    # stringified ID column used for matching, and the report text columns.
    study_id_str = df['study_id'].astype(str)
    report_cols = [c for c in df.columns if 'report_' in c]
for sid in available_study_ids:
    if not csv_usable:
        example_metadata[sid] = {"diagnosis": "Unknown (CSV Missing)", "text": "N/A"}
        continue
    row = df[study_id_str == str(sid)]
    if row.empty:
        example_metadata[sid] = {"diagnosis": "Metadata Not Found", "text": "N/A"}
        continue
    # If a study ID appears more than once, the first matching row wins.
    record = row.iloc[0]
    # Keep non-null, non-blank report fragments (stripped), in column order.
    lines = [
        str(record[c]).strip()
        for c in report_cols
        if pd.notna(record[c]) and str(record[c]).strip() != ''
    ]
    full_text = '\n'.join(lines)
    # Simple diagnosis estimation from labels for display title
    labels_vec = get_refined_labels(' '.join(lines))
    active_classes = [CLASSES[i] for i, val in enumerate(labels_vec) if val == 1.0]
    diagnosis = ", ".join(active_classes) if active_classes else "Normal/Other"
    example_metadata[sid] = {
        "diagnosis": diagnosis,
        "text": full_text,
    }
def load_signal(path, *, gain=200.0, n_leads=12, n_samples=5000):
    """Read a raw multi-lead ECG recording from a binary ``.dat`` file.

    The file is decoded as a flat stream of 16-bit signed integers with the
    leads interleaved (sample 0 of every lead, then sample 1 of every lead,
    ...). The counts are then rescaled to float by dividing by ``gain``.

    Args:
        path: Path to the ``.dat`` file.
        gain: Divisor applied to the raw integer counts. This is presumably
            ADC units per mV, as in WFDB headers; confirm for your data.
        n_leads: Number of interleaved channels per sample.
        n_samples: Samples per lead to return. Shorter recordings are
            zero-padded at the end; longer ones are truncated.

    Returns:
        A ``float32`` array of shape ``(n_leads, n_samples)``, or ``None`` if
        ``path`` does not exist.
    """
    if not os.path.exists(path):
        return None
    # Explicit little-endian int16 ('<i2') so decoding does not depend on the
    # host's byte order. On little-endian machines this is identical to the
    # native np.int16 used before.
    # NOTE(review): assumes WFDB format-16 (little-endian) files; confirm.
    raw_data = np.fromfile(path, dtype='<i2')
    expected_size = n_leads * n_samples
    if raw_data.size < expected_size:
        # Zero-pad short (or empty) recordings up to the fixed length.
        raw_data = np.pad(raw_data, (0, expected_size - raw_data.size))
    else:
        raw_data = raw_data[:expected_size]
    # Interleaved stream -> (samples, leads), transposed to (leads, samples).
    signal = raw_data.reshape((n_samples, n_leads)).T
    return signal.astype(np.float32) / gain
def plot_ecg(signal, title="12-Lead ECG"):
    """Draw a 12-lead ECG as twelve vertically stacked line plots.

    Args:
        signal: Per-lead samples, indexable as ``signal[i]`` for ``i`` in
            0..11 (e.g. the ``(12, n_samples)`` array from ``load_signal``).
            Rows are labelled in the order I, II, III, aVR, aVL, aVF, V1-V6.
            This assumes the data source stores leads in that order; confirm.
        title: Figure title.

    Returns:
        A ``matplotlib.figure.Figure`` with one axis per lead and a shared
        x-axis.
    """
    leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    # Build the Figure directly instead of via pyplot. pyplot registers every
    # figure globally and keeps it alive until plt.close(), so a server that
    # plots on every request would leak one figure per call. A bare Figure is
    # garbage-collected once the caller drops it, and fig.savefig still works
    # for rendering.
    fig = Figure(figsize=(10, 20))
    axes = fig.subplots(len(leads), 1, sharex=True)
    fig.subplots_adjust(hspace=0.2)
    last = len(leads) - 1
    for i, (ax, lead_name) in enumerate(zip(axes, leads)):
        ax.plot(signal[i], color='k', linewidth=0.8)
        ax.set_ylabel(lead_name, rotation=0, labelpad=20, fontsize=10, fontweight='bold')
        # Strip the box: keep the left spine on every row and the bottom spine
        # only on the last row, where the shared x-axis is labelled.
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(i == last)
        ax.spines['left'].set_visible(True)
        ax.grid(True, linestyle='--', alpha=0.5)
    axes[last].set_xlabel("Samples (500Hz)", fontsize=12)
    fig.suptitle(title, fontsize=16, y=0.90)
    return fig
def predict_ecg(study_id):
    """Plot and classify one example ECG for the Gradio UI.

    Args:
        study_id: Study ID chosen in the dropdown. The recording is read from
            ``<DATA_DIR>/<study_id>.dat``.

    Returns:
        A 3-tuple in the order of the UI outputs
        ``[plot_output, label_output, text_output]``:

        * on success: ``(figure, {class_name: probability}, report_text)``,
          where ``report_text`` falls back to ``"Unknown"`` for study IDs
          missing from ``example_metadata``;
        * on failure: ``(None, {}, error_message)``. The empty dict clears
          the label and the message appears in the report textbox.
    """
    # The dropdown value ultimately comes from the client, so accept only a
    # bare file stem. An ID containing a path separator could otherwise
    # point outside DATA_DIR.
    if os.path.basename(str(study_id)) != str(study_id):
        return None, {}, f"File not found for study {study_id}"
    path = os.path.join(DATA_DIR, f"{study_id}.dat")
    if not os.path.exists(path):
        return None, {}, f"File not found for study {study_id}"
    signal = load_signal(path)
    if signal is None:
        return None, {}, "Error loading signal"
    fig = plot_ecg(signal, title=f"Study {study_id}")
    # Add a batch dimension: (leads, samples) -> (1, leads, samples).
    tensor_sig = torch.from_numpy(signal).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(tensor_sig)
    # Independent per-class sigmoid (multi-label scores), not a softmax.
    probs = torch.sigmoid(logits).cpu().numpy()[0]
    results = {name: float(probs[i]) for i, name in enumerate(CLASSES)}
    full_text = example_metadata.get(study_id, {}).get("text", "Unknown")
    return fig, results, full_text
# --- Gradio UI ---
# [study_id, diagnosis] rows for the description table, ordered by study ID.
examples = sorted(
    ([sid, meta["diagnosis"]] for sid, meta in example_metadata.items()),
    key=lambda pair: pair[0],
)
# Dropdown choices follow the table order. When there are no examples, use a
# placeholder so the dropdown is never empty.
example_ids = [pair[0] for pair in examples] or ["No Examples Found"]
# Layout: study picker and example table on the left (scale=1), outputs on
# the right (scale=2). Component creation order inside each context manager
# determines the on-screen order.
with gr.Blocks(title="ECG Arrhythmia Classifier") as demo:
    gr.Markdown("# 🫀 AI ECG Arrhythmia Classifier")
    gr.Markdown("Select a study ID from the examples below to analyze the 12-lead ECG.")
    with gr.Row():
        with gr.Column(scale=1):
            # NOTE(review): example_ids falls back to ["No Examples Found"]
            # when empty, so the `else None` branch below appears unreachable.
            study_input = gr.Dropdown(choices=example_ids, label="Select Example Study ID", value=example_ids[0] if example_ids else None)
            gr.Markdown("### Example Descriptions")
            # Read-only table of [study_id, diagnosis] rows; display only, not
            # wired to any event.
            gr.DataFrame(headers=["Study ID", "Diagnosis"], value=examples, interactive=False)
            analyze_btn = gr.Button("Analyze ECG", variant="primary")
        with gr.Column(scale=2):
            plot_output = gr.Plot(label="12-Lead ECG Visualization")
            label_output = gr.Label(label="AI Predictions")
            text_output = gr.Textbox(label="Original Clinical Report (Ground Truth context)", lines=5)
    # On success predict_ecg returns (figure, {class: probability}, report
    # text), matching the order of these outputs.
    analyze_btn.click(
        fn=predict_ecg,
        inputs=[study_input],
        outputs=[plot_output, label_output, text_output]
    )

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()