# ECG / app.py
# Author: IFMedTechdemo
# Upload folder using huggingface_hub (commit 5a65830, verified)
import gradio as gr
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from huggingface_hub import hf_hub_download
from labels_refined import get_refined_labels, CLASSES
from model import ResNet1d
from dataset import MIMICECGDataset
# --- Configuration ---
# Both paths are relative, so they resolve against the working directory the
# app is launched from: the folder of <study_id>.dat example recordings and the
# metadata CSV holding the per-study report text.
DATA_DIR, CSV_PATH = "./examples", "metadata.csv"
DEVICE = torch.device("cpu")  # model and inputs are placed on the CPU
# --- Load Resources ---
print("Downloading Model from Hub...")
# Get token from Space Secrets (must be set as HF_TOKEN); may be None if unset.
hf_token = os.environ.get("HF_TOKEN")
model_path = hf_hub_download(
    repo_id="IFMedTech/ECG_Model",
    filename="resnet_advanced.pth",
    token=hf_token,
)
print(f"Loading Model from {model_path}...")
model = ResNet1d(num_classes=5).to(DEVICE)
try:
    # Preferred path: the restricted loader that only rebuilds tensors and
    # plain containers.
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
except TypeError:
    # Older torch releases do not accept the `weights_only` keyword at all.
    state_dict = torch.load(model_path, map_location=DEVICE)
except Exception as err:
    # The checkpoint holds objects the restricted loader rejects. Retry with
    # full unpickling. This can execute arbitrary code embedded in the file,
    # which is only acceptable because the checkpoint comes from the project's
    # own Hub repo. `weights_only=False` must be explicit: newer torch defaults
    # to True, so a bare retry would fail exactly like the first attempt.
    print(f"weights_only load failed ({err}); retrying with full unpickling...")
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=False)
model.load_state_dict(state_dict)
# Inference mode for layers such as dropout/batch-norm.
model.eval()
print("Loading Dataset Index...")
# The metadata CSV supplies report text for the bundled examples. Any failure
# while reading it (missing file, parse error, ...) is reported and replaced by
# an empty DataFrame; downstream code checks `df.empty` to detect this.
try:
    df = pd.read_csv(CSV_PATH, low_memory=False)
    print(f"Loaded CSV with {len(df)} records.")
except Exception as csv_error:
    print(f"Error loading CSV: {csv_error}")
    df = pd.DataFrame()  # Fallback: no metadata available
# Scan examples folder for .dat files; each file's name without the extension
# is used as that example's study ID.
example_files = glob.glob(os.path.join(DATA_DIR, "*.dat"))
available_study_ids = [
    os.path.splitext(file_name)[0]
    for file_name in map(os.path.basename, example_files)
]
print(f"Found examples: {available_study_ids}")
# Build Metadata for Gradio
def _report_lines(record, report_cols):
    """Return the non-missing, non-blank report fields of one CSV row, stripped, in column order."""
    lines = []
    for col in report_cols:
        value = record[col]
        if pd.notna(value) and str(value).strip() != '':
            lines.append(str(value).strip())
    return lines


def _build_example_metadata(frame, study_ids):
    """Map each study ID to ``{"diagnosis": str, "text": str}`` for display.

    ``diagnosis`` is a comma-joined list of the CLASSES that
    ``get_refined_labels`` flags (value 1.0) for the study's report text, or
    "Normal/Other" when none are flagged. ``text`` is the report fields joined
    with newlines.

    Placeholder entries are used when no report can be found:
    - "Unknown (CSV Missing)" if ``frame`` is empty;
    - "Metadata Not Found" if there is no ``study_id`` column or no matching row.
    """
    if frame.empty:
        return {sid: {"diagnosis": "Unknown (CSV Missing)", "text": "N/A"} for sid in study_ids}
    if 'study_id' not in frame.columns:
        # Previously a KeyError at startup; degrade to "not found" instead.
        return {sid: {"diagnosis": "Metadata Not Found", "text": "N/A"} for sid in study_ids}

    # Hoisted out of the per-example loop: the CSV may be large, so stringify
    # the ID column once and locate the free-text report columns once.
    id_strings = frame['study_id'].astype(str)
    report_cols = [c for c in frame.columns if 'report_' in c]

    metadata = {}
    for sid in study_ids:
        matches = frame[id_strings == str(sid)]
        if matches.empty:
            metadata[sid] = {"diagnosis": "Metadata Not Found", "text": "N/A"}
            continue
        # Only the first matching row is used if a study ID appears more than once.
        lines = _report_lines(matches.iloc[0], report_cols)
        # Simple diagnosis estimation from labels for display title
        labels_vec = get_refined_labels(' '.join(lines))
        active_classes = [CLASSES[i] for i, val in enumerate(labels_vec) if val == 1.0]
        metadata[sid] = {
            "diagnosis": ", ".join(active_classes) if active_classes else "Normal/Other",
            "text": '\n'.join(lines),
        }
    return metadata


example_metadata = _build_example_metadata(df, available_study_ids)
def load_signal(path, gain=200.0, n_leads=12, n_samples=5000):
    """Load a raw interleaved multi-lead ECG record as a float32 array.

    The file is read as a flat stream of 16-bit signed integers stored
    sample-major: all leads of sample 0, then all leads of sample 1, and so on.
    The reshape below relies on this layout. Byte order is pinned to
    little-endian. This assumes the .dat files are WFDB "format 16" recordings;
    confirm against the dataset's header files.

    Args:
        path: Path to the binary ``.dat`` file.
        gain: Divisor applied to the raw integer values (default 200.0,
            presumably counts per mV; verify against the record headers).
        n_leads: Number of interleaved leads per sample.
        n_samples: Number of samples per lead to return.

    Returns:
        A float32 array of shape ``(n_leads, n_samples)``, or ``None`` if
        ``path`` does not exist. The shape is always fixed: short files are
        zero-padded at the end and longer files are truncated.
    """
    if not os.path.exists(path):
        return None
    # Explicit '<i2' (little-endian int16) rather than native int16, so the
    # decoded values do not depend on the host's byte order.
    raw = np.fromfile(path, dtype="<i2")
    expected_size = n_leads * n_samples
    if raw.size < expected_size:
        # Zero-pad short recordings up to the fixed length.
        raw = np.concatenate([raw, np.zeros(expected_size - raw.size, dtype=raw.dtype)])
    else:
        raw = raw[:expected_size]
    # Reshape rows are samples; transpose so that each row is one lead.
    signal = raw.reshape((n_samples, n_leads)).T
    return signal.astype(np.float32) / gain
def plot_ecg(signal, title="12-Lead ECG", lead_names=None, sample_rate=500):
    """Draw an ECG as vertically stacked per-lead traces sharing one x-axis.

    Args:
        signal: Array-like indexed ``signal[lead][sample]``; a 1-D input is
            treated as a single lead.
        title: Figure super-title.
        lead_names: Optional y-axis labels, one per lead. Defaults to the
            standard 12-lead order (I, II, III, aVR, aVL, aVF, V1-V6). Leads
            beyond the supplied names are labelled "Lead N".
        sample_rate: Sampling rate in Hz; used only in the x-axis label.

    Returns:
        The matplotlib ``Figure``. It has already been closed in pyplot's
        figure manager, so repeated calls do not accumulate open figures. The
        returned object can still be saved or rendered by the caller.

    Raises:
        ValueError: if ``signal`` contains no leads.
    """
    traces = np.atleast_2d(np.asarray(signal))
    n_leads = traces.shape[0]
    if n_leads == 0:
        raise ValueError("signal must contain at least one lead")
    if lead_names is None:
        lead_names = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    labels = [lead_names[i] if i < len(lead_names) else f"Lead {i + 1}" for i in range(n_leads)]

    # squeeze=False keeps `axes` 2-D even for a single lead, so indexing is uniform.
    fig, axes = plt.subplots(n_leads, 1, figsize=(10, 20), sharex=True, squeeze=False)
    axes = axes[:, 0]
    # Adjust this figure explicitly. plt.subplots_adjust targets pyplot's
    # global "current figure", which may be a different figure if plots are
    # built concurrently.
    fig.subplots_adjust(hspace=0.2)
    for i, (ax, trace, label) in enumerate(zip(axes, traces, labels)):
        ax.plot(trace, color='k', linewidth=0.8)
        ax.set_ylabel(label, rotation=0, labelpad=20, fontsize=10, fontweight='bold')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # Only the bottom subplot keeps its x-axis spine.
        ax.spines['bottom'].set_visible(i == n_leads - 1)
        ax.spines['left'].set_visible(True)
        ax.grid(True, linestyle='--', alpha=0.5)
    axes[-1].set_xlabel(f"Samples ({sample_rate}Hz)", fontsize=12)
    fig.suptitle(title, fontsize=16, y=0.90)
    # Release pyplot's reference. Otherwise every request leaves a figure open
    # for the lifetime of the server process.
    plt.close(fig)
    return fig
def predict_ecg(study_id):
    """Plot and classify one bundled example study for the Gradio outputs.

    Args:
        study_id: Study ID chosen in the UI. Only IDs discovered in DATA_DIR
            at startup (the keys of ``example_metadata``) are served.

    Returns:
        ``(figure, predictions, report_text)``, matching the
        ``[plot, label, textbox]`` outputs. ``predictions`` maps each class
        name to its sigmoid probability. These are independent per class, so
        they need not sum to 1. On failure the tuple is
        ``(None, {}, error_message)``.
    """
    # study_id arrives from the client. Interpolating arbitrary text into a
    # path (e.g. "../x") could reach other .dat files, so only accept IDs
    # discovered at startup.
    sid = str(study_id)
    path = os.path.join(DATA_DIR, f"{sid}.dat")
    if sid not in example_metadata or not os.path.exists(path):
        # Same slot order as the success path: the message goes to the
        # textbox, and an empty dict to the label.
        return None, {}, f"File not found for study {study_id}"
    signal = load_signal(path)
    if signal is None:
        return None, {}, "Error loading signal"
    fig = plot_ecg(signal, title=f"Study {study_id}")
    # Add a leading batch dimension of size 1 for the model.
    batch = torch.from_numpy(signal).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = torch.sigmoid(model(batch)).cpu().numpy()[0]
    results = {name: float(probs[i]) for i, name in enumerate(CLASSES)}
    report_text = example_metadata.get(sid, {}).get("text", "Unknown")
    return fig, results, report_text
# --- Gradio UI ---
# One [study_id, diagnosis] row per discovered example, ordered by study ID,
# for the summary table.
examples = sorted(
    ([sid, meta["diagnosis"]] for sid, meta in example_metadata.items()),
    key=lambda pair: pair[0],
)
# Dropdown choices follow the table order. A placeholder entry keeps the list
# non-empty when no examples were found.
example_ids = [sid for sid, _diagnosis in examples] or ["No Examples Found"]
# Build the Gradio interface. The layout follows the order in which components
# are created inside the nested Blocks/Row/Column context managers, so the
# statements here are order-sensitive.
with gr.Blocks(title="ECG Arrhythmia Classifier") as demo:
    gr.Markdown("# 🫀 AI ECG Arrhythmia Classifier")
    gr.Markdown("Select a study ID from the examples below to analyze the 12-lead ECG.")
    with gr.Row():
        # Left column: study selection plus a reference table of the examples.
        with gr.Column(scale=1):
            # Defaults to the first (sorted) ID. The `if example_ids` guard is
            # defensive: example_ids gets a placeholder entry when it would
            # otherwise be empty.
            study_input = gr.Dropdown(choices=example_ids, label="Select Example Study ID", value=example_ids[0] if example_ids else None)
            gr.Markdown("### Example Descriptions")
            # Read-only table of [study ID, report-derived diagnosis] rows.
            gr.DataFrame(headers=["Study ID", "Diagnosis"], value=examples, interactive=False)
            analyze_btn = gr.Button("Analyze ECG", variant="primary")
        # Right column: outputs, listed in the same order as predict_ecg's return tuple.
        with gr.Column(scale=2):
            plot_output = gr.Plot(label="12-Lead ECG Visualization")
            label_output = gr.Label(label="AI Predictions")
            text_output = gr.Textbox(label="Original Clinical Report (Ground Truth context)", lines=5)
    # Inference runs only when the button is clicked, not on dropdown change.
    analyze_btn.click(
        fn=predict_ecg,
        inputs=[study_input],
        outputs=[plot_output, label_output, text_output]
    )

# Start the Gradio server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()