Spaces:
Running
Running
File size: 6,206 Bytes
e3b4744 57892d7 e3b4744 5a65830 57892d7 e3b4744 57892d7 5a65830 57892d7 e3b4744 57892d7 e3b4744 57892d7 e3b4744 57892d7 e3b4744 57892d7 e3b4744 57892d7 e3b4744 57892d7 e3b4744 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import gradio as gr
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from huggingface_hub import hf_hub_download
from labels_refined import get_refined_labels, CLASSES
from model import ResNet1d
from dataset import MIMICECGDataset
# --- Configuration ---
# Folder scanned for example ``*.dat`` records, and the metadata CSV that
# holds their report text.
DATA_DIR, CSV_PATH = "./examples", "metadata.csv"
# All tensors and the model live on the CPU.
DEVICE = torch.device("cpu")
# --- Load Resources ---
print("Downloading Model from Hub...")
# Get token from Space Secrets (must be set as HF_TOKEN); os.environ.get yields
# None when the secret is absent.
hf_token = os.environ.get("HF_TOKEN")
model_path = hf_hub_download(
    repo_id="IFMedTech/ECG_Model",
    filename="resnet_advanced.pth",
    token=hf_token,
)
print(f"Loading Model from {model_path}...")
model = ResNet1d(num_classes=5).to(DEVICE)
try:
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
except Exception as err:
    # Older torch releases reject the weights_only keyword, and checkpoints that
    # contain arbitrary Python objects fail the restricted unpickler.
    # SECURITY: the fallback below uses the full pickle loader, which can execute
    # code embedded in the file. NOTE(review): only acceptable if the
    # IFMedTech/ECG_Model repository is trusted.
    # Catching Exception (not a bare except) lets KeyboardInterrupt/SystemExit
    # propagate, and the failure is logged instead of silently swallowed.
    print(f"weights_only load failed ({err!r}); retrying with full unpickling.")
    state_dict = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()
print("Loading Dataset Index...")
# The CSV supplies report text for the example studies. It is optional: any
# read failure degrades to an empty DataFrame instead of aborting startup.
try:
    df = pd.read_csv(CSV_PATH, low_memory=False)
except Exception as load_err:
    print(f"Error loading CSV: {load_err}")
    df = pd.DataFrame()  # Fallback
else:
    print(f"Loaded CSV with {len(df)} records.")
# Scan examples folder for .dat files; each file name without its extension
# is used as the study ID.
example_files = glob.glob(os.path.join(DATA_DIR, "*.dat"))
available_study_ids = list(
    map(lambda dat_path: os.path.splitext(os.path.basename(dat_path))[0], example_files)
)
print(f"Found examples: {available_study_ids}")
# Build Metadata for Gradio
# example_metadata maps each example study ID to
# {"diagnosis": <display title>, "text": <non-empty report_* fields joined by newlines>}.
example_metadata = {}
# Loop-invariant work is hoisted out of the per-example loop. Previously the whole
# study_id column was converted to str once per example (O(rows) each time).
report_cols = [c for c in df.columns if 'report_' in c]
if 'study_id' in df.columns:
    study_keys = df['study_id'].astype(str)
else:
    # NOTE(review): a CSV without a study_id column used to raise KeyError at
    # import time; every example is now reported as "Metadata Not Found".
    study_keys = None
for sid in available_study_ids:
    if df.empty:
        example_metadata[sid] = {"diagnosis": "Unknown (CSV Missing)", "text": "N/A"}
        continue
    row = df[study_keys == str(sid)] if study_keys is not None else df.iloc[0:0]
    if row.empty:
        example_metadata[sid] = {"diagnosis": "Metadata Not Found", "text": "N/A"}
        continue
    record = row.iloc[0]  # first matching row if the study ID appears more than once
    lines = [
        str(record[c]).strip()
        for c in report_cols
        if pd.notna(record[c]) and str(record[c]).strip() != ''
    ]
    full_text = '\n'.join(lines)
    # Simple diagnosis estimation from labels for display title
    labels_vec = get_refined_labels(' '.join(lines))
    active_classes = [CLASSES[i] for i, val in enumerate(labels_vec) if val == 1.0]
    diagnosis = ", ".join(active_classes) if active_classes else "Normal/Other"
    example_metadata[sid] = {
        "diagnosis": diagnosis,
        "text": full_text,
    }
def load_signal(path, *, n_leads=12, n_samples=5000, gain=200.0):
    """Read a raw interleaved int16 ECG record and return it scaled by ``gain``.

    The file is read as a flat stream of little-endian int16 values laid out
    sample-major: all leads for sample 0, then all leads for sample 1, and so
    on. The stream is zero-padded or truncated to exactly
    ``n_leads * n_samples`` values, then divided by ``gain``.

    Args:
        path: Path to the ``.dat`` file.
        n_leads: Number of interleaved channels (default 12).
        n_samples: Samples per channel to return (default 5000).
        gain: Divisor applied to the raw integer counts (default 200.0).
            NOTE(review): presumably ADC units per mV, as in WFDB headers;
            confirm against the record's header.

    Returns:
        A ``float32`` array of shape ``(n_leads, n_samples)``, or ``None`` if
        ``path`` does not exist.
    """
    if not os.path.exists(path):
        return None
    # '<i2' pins little-endian byte order; the previous np.int16 silently used
    # the host's native order. NOTE(review): assumes WFDB format 16 storage.
    raw = np.fromfile(path, dtype='<i2')
    expected = n_leads * n_samples
    # Short records are zero-padded at the end; long ones are truncated.
    samples = np.zeros(expected, dtype='<i2')
    count = min(raw.size, expected)
    samples[:count] = raw[:count]
    # Interleaved stream -> (samples, leads) -> transpose to (leads, samples).
    signal = samples.reshape((n_samples, n_leads)).T
    return signal.astype(np.float32) / gain
def plot_ecg(signal, title="12-Lead ECG"):
    """Draw each lead of ``signal`` on its own subplot, stacked vertically.

    Args:
        signal: Array indexed as ``signal[lead]`` -> 1-D samples. Rows are
            labelled, in order, I, II, III, aVR, aVL, aVF, V1-V6.
        title: Figure title.

    Returns:
        The matplotlib ``Figure``. It has already been released from pyplot's
        figure registry, so repeated calls do not accumulate open figures, and
        the caller can still render or save it.
    """
    leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    fig, axes = plt.subplots(len(leads), 1, figsize=(10, 20), sharex=True)
    # Adjust this figure explicitly rather than pyplot's "current" figure.
    fig.subplots_adjust(hspace=0.2)
    last = len(leads) - 1
    for i, (ax, lead_name) in enumerate(zip(axes, leads)):
        ax.plot(signal[i], color='k', linewidth=0.8)
        ax.set_ylabel(lead_name, rotation=0, labelpad=20, fontsize=10, fontweight='bold')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # Only the bottom subplot keeps its x-axis spine (x is shared).
        ax.spines['bottom'].set_visible(i == last)
        ax.spines['left'].set_visible(True)
        ax.grid(True, linestyle='--', alpha=0.5)
    axes[-1].set_xlabel("Samples (500Hz)", fontsize=12)
    fig.suptitle(title, fontsize=16, y=0.90)
    # pyplot keeps a reference to every figure it creates until it is closed.
    # In a long-running server, each request would otherwise leak one figure.
    plt.close(fig)
    return fig
def predict_ecg(study_id):
    """Run the classifier on one example study and build the three UI outputs.

    Args:
        study_id: Study ID from the dropdown; expected to be the stem of a
            ``.dat`` file in ``DATA_DIR`` that was discovered at startup.

    Returns:
        A ``(figure, predictions, report_text)`` triple matching
        ``[plot_output, label_output, text_output]``:
        - a figure from ``plot_ecg``;
        - a ``{class_name: probability}`` dict built from the model's sigmoid
          outputs;
        - the report text stored in ``example_metadata``.
        On failure, the figure and predictions are ``None`` and the error
        message is returned in the text slot.
    """
    path = os.path.join(DATA_DIR, f"{study_id}.dat")
    # Only IDs discovered at startup are accepted. The value can arrive via the
    # public API, so a crafted ID like "../x" must not be used to build a path.
    if study_id not in example_metadata or not os.path.exists(path):
        return None, None, f"File not found for study {study_id}"
    signal = load_signal(path)
    if signal is None:
        return None, None, "Error loading signal"
    fig = plot_ecg(signal, title=f"Study {study_id}")
    # Add a leading batch dimension of size 1 before running the model.
    tensor_sig = torch.from_numpy(signal).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(tensor_sig)
        # Element-wise sigmoid: each class gets an independent probability, and
        # the values need not sum to 1.
        probs = torch.sigmoid(logits).cpu().numpy()[0]
    results = {name: float(probs[i]) for i, name in enumerate(CLASSES)}
    full_text = example_metadata.get(study_id, {}).get("text", "Unknown")
    return fig, results, full_text
# --- Gradio UI ---
# Rows for the "Example Descriptions" table: [study ID, diagnosis], ordered by ID.
examples = sorted(
    ([sid, meta["diagnosis"]] for sid, meta in example_metadata.items()),
    key=lambda pair: pair[0],
)
# Dropdown choices. A placeholder entry keeps the list non-empty when no .dat
# files were found.
example_ids = [pair[0] for pair in examples] or ["No Examples Found"]
default_study = example_ids[0] if example_ids else None

with gr.Blocks(title="ECG Arrhythmia Classifier") as demo:
    # Component creation order determines on-page layout; keep it as is.
    gr.Markdown("# 🫀 AI ECG Arrhythmia Classifier")
    gr.Markdown("Select a study ID from the examples below to analyze the 12-lead ECG.")
    with gr.Row():
        # Left column: study picker, reference table, and the trigger button.
        with gr.Column(scale=1):
            study_input = gr.Dropdown(
                label="Select Example Study ID",
                choices=example_ids,
                value=default_study,
            )
            gr.Markdown("### Example Descriptions")
            gr.DataFrame(
                value=examples,
                headers=["Study ID", "Diagnosis"],
                interactive=False,
            )
            analyze_btn = gr.Button("Analyze ECG", variant="primary")
        # Right column: outputs, in the same order predict_ecg returns them.
        with gr.Column(scale=2):
            plot_output = gr.Plot(label="12-Lead ECG Visualization")
            label_output = gr.Label(label="AI Predictions")
            text_output = gr.Textbox(
                label="Original Clinical Report (Ground Truth context)",
                lines=5,
            )
    analyze_btn.click(
        fn=predict_ecg,
        inputs=[study_input],
        outputs=[plot_output, label_output, text_output],
    )

# Start the web server only when the file is executed directly.
if __name__ == "__main__":
    demo.launch()
|