File size: 6,206 Bytes
e3b4744
 
 
 
 
 
 
 
57892d7
e3b4744
 
 
 
 
 
5a65830
57892d7
e3b4744
 
57892d7
5a65830
 
 
 
 
 
 
57892d7
 
e3b4744
 
57892d7
e3b4744
57892d7
e3b4744
 
 
57892d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3b4744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57892d7
e3b4744
 
 
 
 
 
 
 
 
 
 
 
57892d7
 
 
 
 
e3b4744
 
 
 
 
 
 
57892d7
e3b4744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import gradio as gr
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from huggingface_hub import hf_hub_download 
from labels_refined import get_refined_labels, CLASSES
from model import ResNet1d
from dataset import MIMICECGDataset

# --- Configuration ---
DATA_DIR = "./examples"  # folder scanned for "<study_id>.dat" example records
CSV_PATH = "metadata.csv"  # optional metadata/report CSV (see loader below)
DEVICE = torch.device("cpu")

# --- Load Resources ---
print("Downloading Model from Hub...")
# Get token from Space Secrets (must be set as HF_TOKEN); None means anonymous access.
hf_token = os.environ.get("HF_TOKEN")
model_path = hf_hub_download(
    repo_id="IFMedTech/ECG_Model",
    filename="resnet_advanced.pth",
    token=hf_token,
)

print(f"Loading Model from {model_path}...")
# NOTE(review): the head size is hard-coded to 5 while predictions are later
# labelled with CLASSES; presumably len(CLASSES) == 5 — confirm.
model = ResNet1d(num_classes=5).to(DEVICE)
try:
    # Prefer torch's restricted unpickler (weights_only=True).
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
except Exception as exc:
    # Fallback for torch builds that do not accept `weights_only` or for a
    # checkpoint the restricted unpickler rejects. `Exception` (not a bare
    # except) so Ctrl-C / SystemExit are not swallowed during startup.
    # SECURITY NOTE(review): this path performs full pickle loading of a file
    # downloaded from the Hub, which can execute arbitrary code; it is only
    # acceptable while the model repo is trusted.
    print(f"Restricted checkpoint load failed ({exc!r}); retrying with full unpickling.")
    state_dict = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()  # inference only: switch layers such as dropout/batch-norm to eval behaviour

print("Loading Dataset Index...")
# The metadata CSV is best-effort: when it cannot be read the app still starts,
# using an empty frame so every example simply gets placeholder text.
try:
    df = pd.read_csv(CSV_PATH, low_memory=False)
except Exception as e:
    print(f"Error loading CSV: {e}")
    df = pd.DataFrame()  # Fallback
else:
    print(f"Loaded CSV with {len(df)} records.")

# Every "*.dat" file in DATA_DIR is an example; its file stem is the study id.
example_files = glob.glob(os.path.join(DATA_DIR, "*.dat"))
available_study_ids = [os.path.splitext(name)[0] for name in map(os.path.basename, example_files)]
print(f"Found examples: {available_study_ids}")

# Build Metadata for Gradio
# example_metadata maps each discovered study id to
#   {"diagnosis": short title for the examples table,
#    "text": newline-joined non-empty report_* columns of the first CSV match}.
example_metadata = {}

# Per-CSV work, done once instead of once per example: casting the whole id
# column to str and selecting the report columns. A CSV without a 'study_id'
# column degrades to "Metadata Not Found" instead of aborting startup.
report_cols = [c for c in df.columns if 'report_' in c]
csv_study_ids = df['study_id'].astype(str) if 'study_id' in df.columns else None

for sid in available_study_ids:
    if df.empty:
        example_metadata[sid] = {"diagnosis": "Unknown (CSV Missing)", "text": "N/A"}
        continue

    row = None if csv_study_ids is None else df[csv_study_ids == str(sid)]
    if row is None or row.empty:
        example_metadata[sid] = {"diagnosis": "Metadata Not Found", "text": "N/A"}
        continue

    record = row.iloc[0]  # first matching row wins if the id repeats
    lines = [
        str(record[c]).strip()
        for c in report_cols
        if pd.notna(record[c]) and str(record[c]).strip() != ''
    ]

    # Simple diagnosis estimation from labels for display title; used only
    # for the examples table, not as model input.
    labels_vec = get_refined_labels(' '.join(lines))
    active_classes = [CLASSES[i] for i, val in enumerate(labels_vec) if val == 1.0]
    diagnosis = ", ".join(active_classes) if active_classes else "Normal/Other"

    example_metadata[sid] = {
        "diagnosis": diagnosis,
        "text": '\n'.join(lines),
    }


def load_signal(path, gain=200.0, n_leads=12, n_samples=5000):
    """Read a raw sample-interleaved int16 ECG record as a lead-major array.

    The file is read as a flat stream of little-endian 16-bit integers that
    holds ``n_samples`` frames of ``n_leads`` values each (all leads for the
    first sample, then all leads for the next, ...). The stream is
    transposed to ``(n_leads, n_samples)`` and divided by ``gain``.

    NOTE(review): the defaults (gain 200, 12 leads, 5000 samples) are fixed
    here rather than read from a record header; presumably they match every
    example record — confirm.

    Args:
        path: Path to the ``.dat`` file.
        gain: Divisor converting raw integer counts to output units.
        n_leads: Number of interleaved channels per frame.
        n_samples: Number of frames to return.

    Returns:
        ``float32`` array of shape ``(n_leads, n_samples)``, or ``None`` when
        ``path`` does not exist. Short files are zero-padded at the end and
        longer files are truncated.
    """
    if not os.path.exists(path):
        return None

    # Explicit little-endian ("<i2") so the decoded values do not depend on
    # the host byte order (assumes the records are stored little-endian).
    raw_data = np.fromfile(path, dtype="<i2")

    expected_size = n_leads * n_samples
    if raw_data.size < expected_size:
        # Zero-pad short recordings so the reshape below always succeeds.
        raw_data = np.pad(raw_data, (0, expected_size - raw_data.size))
    else:
        raw_data = raw_data[:expected_size]

    # Frame-major storage -> (time, lead), then transpose to (lead, time).
    signal = raw_data.reshape((n_samples, n_leads)).T
    return signal.astype(np.float32) / gain

def plot_ecg(signal, title="12-Lead ECG"):
    """Draw 12 ECG traces as vertically stacked line plots.

    Args:
        signal: Indexable as ``signal[i]`` for i in 0..11, one 1-D trace per
            row; rows are labelled I, II, III, aVR, aVL, aVF, V1-V6 in order.
        title: Figure title.

    Returns:
        A ``matplotlib.figure.Figure`` with 12 axes sharing the x-axis.

    The figure is a standalone ``matplotlib.figure.Figure`` rather than one
    from ``plt.subplots``. pyplot keeps every figure it creates registered
    until ``plt.close`` is called, and this app never calls it, so each
    request used to leak a full 10x20-inch figure. A standalone Figure is
    released once the caller drops it and does not touch pyplot global state.
    """
    from matplotlib.figure import Figure  # local import: the file only imports pyplot

    leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    fig = Figure(figsize=(10, 20))
    axes = fig.subplots(12, 1, sharex=True)
    fig.subplots_adjust(hspace=0.2)
    last = len(leads) - 1
    for i, (ax, lead) in enumerate(zip(axes, leads)):
        ax.plot(signal[i], color='k', linewidth=0.8)
        ax.set_ylabel(lead, rotation=0, labelpad=20, fontsize=10, fontweight='bold')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # Only the bottom trace keeps its x-axis spine; the others share it.
        ax.spines['bottom'].set_visible(i == last)
        ax.spines['left'].set_visible(True)
        ax.grid(True, linestyle='--', alpha=0.5)
    axes[-1].set_xlabel("Samples (500Hz)", fontsize=12)
    fig.suptitle(title, fontsize=16, y=0.90)
    return fig

def predict_ecg(study_id):
    """Load, plot and classify one example study for the Gradio UI.

    Args:
        study_id: Id selected in the dropdown. Only ids discovered in
            DATA_DIR at startup (keys of ``example_metadata``) are served.

    Returns:
        ``(figure, predictions, report_text)``, in the order of the outputs
        ``[plot_output, label_output, text_output]``:
        - ``figure`` is the ECG plot.
        - ``predictions`` maps each class name to its sigmoid probability.
        - ``report_text`` is the stored report text.

        On failure it returns ``(None, None, message)``. The error text then
        appears in the report textbox, and the plot and label are cleared.
    """
    # The value comes from the client, so never let an arbitrary string
    # (e.g. "../../x") reach the path join; only startup-discovered ids pass.
    if study_id not in example_metadata:
        return None, None, f"File not found for study {study_id}"

    path = os.path.join(DATA_DIR, f"{study_id}.dat")
    if not os.path.exists(path):
        return None, None, f"File not found for study {study_id}"

    signal = load_signal(path)
    if signal is None:
        return None, None, "Error loading signal"

    fig = plot_ecg(signal, title=f"Study {study_id}")

    # Batch of one: (leads, samples) -> (1, leads, samples).
    tensor_sig = torch.from_numpy(signal).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(tensor_sig)
        # Independent sigmoid per class (multi-label), not a softmax.
        probs = torch.sigmoid(logits).cpu().numpy()[0]

    results = {name: float(probs[i]) for i, name in enumerate(CLASSES)}

    full_text = example_metadata.get(study_id, {}).get("text", "Unknown")

    return fig, results, full_text

# --- Gradio UI ---
# Rows of the read-only examples table: [study id, display diagnosis], by id.
examples = sorted(
    ([study_id, meta["diagnosis"]] for study_id, meta in example_metadata.items()),
    key=lambda pair: pair[0],
)
# Dropdown choices; a single placeholder entry when no examples were found,
# so the list is never empty below.
example_ids = [pair[0] for pair in examples] or ["No Examples Found"]

with gr.Blocks(title="ECG Arrhythmia Classifier") as demo:
    gr.Markdown("# 🫀 AI ECG Arrhythmia Classifier")
    gr.Markdown("Select a study ID from the examples below to analyze the 12-lead ECG.")

    with gr.Row():
        # Left column: study picker, examples table and the trigger button.
        with gr.Column(scale=1):
            study_input = gr.Dropdown(
                choices=example_ids,
                label="Select Example Study ID",
                value=example_ids[0],
            )
            gr.Markdown("### Example Descriptions")
            gr.DataFrame(
                headers=["Study ID", "Diagnosis"],
                value=examples,
                interactive=False,
            )
            analyze_btn = gr.Button("Analyze ECG", variant="primary")

        # Right column: outputs, in the order predict_ecg returns its values.
        with gr.Column(scale=2):
            plot_output = gr.Plot(label="12-Lead ECG Visualization")
            label_output = gr.Label(label="AI Predictions")
            text_output = gr.Textbox(
                label="Original Clinical Report (Ground Truth context)",
                lines=5,
            )

    analyze_btn.click(
        fn=predict_ecg,
        inputs=[study_input],
        outputs=[plot_output, label_output, text_output],
    )

if __name__ == "__main__":
    demo.launch()