Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- .gitattributes +5 -0
- app.py +161 -0
- dataset.py +157 -0
- examples/40689238.dat +3 -0
- examples/43522917.dat +3 -0
- examples/45227415.dat +3 -0
- examples/46642833.dat +3 -0
- examples/49036311.dat +3 -0
- labels_refined.py +105 -0
- model.py +93 -0
- requirements.txt +6 -0
- resnet_advanced.pth +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
examples/40689238.dat filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
examples/43522917.dat filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
examples/45227415.dat filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
examples/46642833.dat filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
examples/49036311.dat filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import os
|
| 8 |
+
import glob
|
| 9 |
+
from labels_refined import get_refined_labels, CLASSES
|
| 10 |
+
from model import ResNet1d
|
| 11 |
+
from dataset import MIMICECGDataset
|
| 12 |
+
|
| 13 |
+
# --- Configuration ---
|
| 14 |
+
# HF Space configuration: Data is local
|
| 15 |
+
DATA_DIR = "./examples"
|
| 16 |
+
MODEL_PATH = "resnet_advanced.pth"
|
| 17 |
+
DEVICE = torch.device("cpu") # Spaces usually CPU unless GPU requested
|
| 18 |
+
|
| 19 |
+
# --- Load Resources ---
|
| 20 |
+
print("Loading Model...")
|
| 21 |
+
model = ResNet1d(num_classes=5).to(DEVICE)
|
| 22 |
+
try:
|
| 23 |
+
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
|
| 24 |
+
except:
|
| 25 |
+
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
|
| 26 |
+
model.load_state_dict(state_dict)
|
| 27 |
+
model.eval()
|
| 28 |
+
|
| 29 |
+
# --- Pre-defined Metadata for Examples ---
|
| 30 |
+
# Hardcoded to avoid uploading the sensitive/huge patient CSV
|
| 31 |
+
example_metadata = {
|
| 32 |
+
"40689238": {
|
| 33 |
+
"diagnosis": "Sinus Rhythm (Normal)",
|
| 34 |
+
"text": "Sinus rhythm\nNormal ECG"
|
| 35 |
+
},
|
| 36 |
+
"46642833": {
|
| 37 |
+
"diagnosis": "Atrial Fibrillation",
|
| 38 |
+
"text": "Atrial fibrillation\nRapid ventricular response"
|
| 39 |
+
},
|
| 40 |
+
"49036311": {
|
| 41 |
+
"diagnosis": "Sinus Tachycardia",
|
| 42 |
+
"text": "Sinus tachycardia\nPossible Left Atrial Enlargement"
|
| 43 |
+
},
|
| 44 |
+
"43522917": {
|
| 45 |
+
"diagnosis": "Sinus Bradycardia",
|
| 46 |
+
"text": "Sinus bradycardia\nOtherwise normal"
|
| 47 |
+
},
|
| 48 |
+
"45227415": {
|
| 49 |
+
"diagnosis": "Ventricular Tachycardia (Rare)",
|
| 50 |
+
"text": "Ventricular tachycardia\nUrgent attention required"
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
def load_signal(path):
|
| 55 |
+
# Reusing logic from dataset.py
|
| 56 |
+
if not os.path.exists(path):
|
| 57 |
+
return None
|
| 58 |
+
|
| 59 |
+
gain = 200.0
|
| 60 |
+
with open(path, 'rb') as f:
|
| 61 |
+
# File is raw int16 binary
|
| 62 |
+
raw_data = np.fromfile(f, dtype=np.int16)
|
| 63 |
+
|
| 64 |
+
n_leads = 12
|
| 65 |
+
n_samples = 5000
|
| 66 |
+
expected_size = n_leads * n_samples
|
| 67 |
+
|
| 68 |
+
if raw_data.size < expected_size:
|
| 69 |
+
padded = np.zeros(expected_size, dtype=np.int16)
|
| 70 |
+
padded[:raw_data.size] = raw_data
|
| 71 |
+
raw_data = padded
|
| 72 |
+
else:
|
| 73 |
+
raw_data = raw_data[:expected_size]
|
| 74 |
+
|
| 75 |
+
signal = raw_data.reshape((n_samples, n_leads)).T
|
| 76 |
+
signal = signal.astype(np.float32) / gain
|
| 77 |
+
return signal
|
| 78 |
+
|
| 79 |
+
def plot_ecg(signal, title="12-Lead ECG"):
|
| 80 |
+
"""Generates a matplotlib figure for the 12-lead ECG"""
|
| 81 |
+
leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
|
| 82 |
+
|
| 83 |
+
fig, axes = plt.subplots(12, 1, figsize=(10, 20), sharex=True)
|
| 84 |
+
plt.subplots_adjust(hspace=0.2)
|
| 85 |
+
|
| 86 |
+
for i in range(12):
|
| 87 |
+
axes[i].plot(signal[i], color='k', linewidth=0.8)
|
| 88 |
+
axes[i].set_ylabel(leads[i], rotation=0, labelpad=20, fontsize=10, fontweight='bold')
|
| 89 |
+
axes[i].spines['top'].set_visible(False)
|
| 90 |
+
axes[i].spines['right'].set_visible(False)
|
| 91 |
+
axes[i].spines['bottom'].set_visible(False if i < 11 else True)
|
| 92 |
+
axes[i].spines['left'].set_visible(True)
|
| 93 |
+
axes[i].grid(True, linestyle='--', alpha=0.5)
|
| 94 |
+
|
| 95 |
+
axes[11].set_xlabel("Samples (500Hz)", fontsize=12)
|
| 96 |
+
fig.suptitle(title, fontsize=16, y=0.90)
|
| 97 |
+
|
| 98 |
+
return fig
|
| 99 |
+
|
| 100 |
+
def predict_ecg(study_id):
|
| 101 |
+
# Path is local in examples/
|
| 102 |
+
path = os.path.join(DATA_DIR, f"{study_id}.dat")
|
| 103 |
+
|
| 104 |
+
if not os.path.exists(path):
|
| 105 |
+
return None, f"File not found for study {study_id}", {}
|
| 106 |
+
|
| 107 |
+
# Load Signal
|
| 108 |
+
signal = load_signal(path)
|
| 109 |
+
if signal is None:
|
| 110 |
+
return None, "Error loading signal", {}
|
| 111 |
+
|
| 112 |
+
# Generate Plot
|
| 113 |
+
fig = plot_ecg(signal, title=f"Study {study_id}")
|
| 114 |
+
|
| 115 |
+
# Inference
|
| 116 |
+
tensor_sig = torch.from_numpy(signal).float().unsqueeze(0).to(DEVICE) # (1, 12, 5000)
|
| 117 |
+
with torch.no_grad():
|
| 118 |
+
logits = model(tensor_sig)
|
| 119 |
+
probs = torch.sigmoid(logits).cpu().numpy()[0]
|
| 120 |
+
|
| 121 |
+
# Format Results
|
| 122 |
+
results = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
|
| 123 |
+
|
| 124 |
+
# Get True Text
|
| 125 |
+
full_text = example_metadata.get(study_id, {}).get("text", "Unknown")
|
| 126 |
+
|
| 127 |
+
return fig, results, full_text
|
| 128 |
+
|
| 129 |
+
# --- Gradio UI ---
|
| 130 |
+
examples = [[k, v["diagnosis"]] for k, v in example_metadata.items()]
|
| 131 |
+
example_ids = [k for k in example_metadata.keys()]
|
| 132 |
+
|
| 133 |
+
with gr.Blocks(title="ECG Arrhythmia Classifier") as demo:
|
| 134 |
+
gr.Markdown("# 🫀 AI ECG Arrhythmia Classifier")
|
| 135 |
+
gr.Markdown("Select a study ID from the examples below to analyze the 12-lead ECG.")
|
| 136 |
+
|
| 137 |
+
with gr.Row():
|
| 138 |
+
with gr.Column(scale=1):
|
| 139 |
+
# Input
|
| 140 |
+
study_input = gr.Dropdown(choices=example_ids, label="Select Example Study ID", value=example_ids[0])
|
| 141 |
+
|
| 142 |
+
# Info
|
| 143 |
+
gr.Markdown("### Example Descriptions")
|
| 144 |
+
gr.DataFrame(headers=["Study ID", "Diagnosis"], value=examples, interactive=False)
|
| 145 |
+
|
| 146 |
+
analyze_btn = gr.Button("Analyze ECG", variant="primary")
|
| 147 |
+
|
| 148 |
+
with gr.Column(scale=2):
|
| 149 |
+
# Output
|
| 150 |
+
plot_output = gr.Plot(label="12-Lead ECG Visualization")
|
| 151 |
+
label_output = gr.Label(label="AI Predictions")
|
| 152 |
+
text_output = gr.Textbox(label="Original Clinical Report (Ground Truth context)", lines=5)
|
| 153 |
+
|
| 154 |
+
analyze_btn.click(
|
| 155 |
+
fn=predict_ecg,
|
| 156 |
+
inputs=[study_input],
|
| 157 |
+
outputs=[plot_output, label_output, text_output]
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
if __name__ == "__main__":
|
| 161 |
+
demo.launch()
|
dataset.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
class MIMICECGDataset(Dataset):
|
| 8 |
+
"""
|
| 9 |
+
PyTorch Dataset for MIMIC-IV-ECG.
|
| 10 |
+
"""
|
| 11 |
+
def __init__(self, df, data_dir, transform=False, label_func=None):
|
| 12 |
+
"""
|
| 13 |
+
Args:
|
| 14 |
+
df (pd.DataFrame): Dataframe with subject_id, study_id, and report columns.
|
| 15 |
+
data_dir (str): Root directory of the dataset (containing the 'files' folder).
|
| 16 |
+
transform (callable, optional): Optional transform to be applied on a sample.
|
| 17 |
+
label_func (callable, optional): Custom function to extract labels from a row.
|
| 18 |
+
"""
|
| 19 |
+
self.df = df
|
| 20 |
+
self.data_dir = data_dir
|
| 21 |
+
self.transform = transform
|
| 22 |
+
self.label_func = label_func
|
| 23 |
+
|
| 24 |
+
# MIMIC-ECG Constants
|
| 25 |
+
self.n_leads = 12
|
| 26 |
+
self.n_samples = 5000
|
| 27 |
+
self.fs = 500
|
| 28 |
+
self.gain = 200.0 # Standard gain in ADU/mV
|
| 29 |
+
|
| 30 |
+
# Define the target classes we want to detect
|
| 31 |
+
# These keys will be searched in the report columns
|
| 32 |
+
self.class_mapping = {
|
| 33 |
+
'Normally filtered': 0, # Not a diagnosis, but often present
|
| 34 |
+
'Sinus rhythm': 0,
|
| 35 |
+
'Atrial fibrillation': 1,
|
| 36 |
+
'Sinus tachycardia': 2,
|
| 37 |
+
'Sinus bradycardia': 3,
|
| 38 |
+
'Ventricular tachycardia': 4,
|
| 39 |
+
# Add more as needed
|
| 40 |
+
}
|
| 41 |
+
self.num_classes = 5 # For now
|
| 42 |
+
|
| 43 |
+
def __len__(self):
|
| 44 |
+
return len(self.df)
|
| 45 |
+
|
| 46 |
+
def __getitem__(self, idx):
|
| 47 |
+
if torch.is_tensor(idx):
|
| 48 |
+
idx = idx.tolist()
|
| 49 |
+
|
| 50 |
+
row = self.df.iloc[idx]
|
| 51 |
+
subj_id = str(row['subject_id'])
|
| 52 |
+
study_id = str(row['study_id'])
|
| 53 |
+
|
| 54 |
+
# Construct path: files/p{XXX}/p{subject_id}/s{study_id}/{study_id}.dat
|
| 55 |
+
subdir = f"p{subj_id[:4]}"
|
| 56 |
+
# Ensure we handle the folder structure correctly.
|
| 57 |
+
# Based on exploration: data_dir/files/p100/p10000032/s40689238/40689238.dat
|
| 58 |
+
file_path = os.path.join(self.data_dir, 'files', subdir, f"p{subj_id}", f"s{study_id}", f"{study_id}.dat")
|
| 59 |
+
|
| 60 |
+
# 1. Load Signal
|
| 61 |
+
signal = self.load_signal_numpy(file_path)
|
| 62 |
+
|
| 63 |
+
# 2. Get Labels
|
| 64 |
+
if self.label_func:
|
| 65 |
+
# Need to pass text to label_func, or row?
|
| 66 |
+
# get_refined_labels expects text. Let's extract text here or let func handle row.
|
| 67 |
+
# Best to let func handle text so it's pure.
|
| 68 |
+
cols = [c for c in self.df.columns if 'report_' in c]
|
| 69 |
+
full_text = ' '.join([str(row[c]) for c in cols])
|
| 70 |
+
labels = self.label_func(full_text)
|
| 71 |
+
else:
|
| 72 |
+
labels = self.get_labels(row)
|
| 73 |
+
|
| 74 |
+
# 3. Return sample
|
| 75 |
+
# Signal shape: (12, 5000)
|
| 76 |
+
sample = {
|
| 77 |
+
'signal': signal,
|
| 78 |
+
'labels': labels,
|
| 79 |
+
'study_id': study_id
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
return sample
|
| 83 |
+
|
| 84 |
+
def load_signal_numpy(self, path):
|
| 85 |
+
"""
|
| 86 |
+
Reads the binary .dat file using numpy.
|
| 87 |
+
Returns a torch tensor of shape (12, 5000).
|
| 88 |
+
"""
|
| 89 |
+
# Return zeros if file is missing (to avoid crashing training loop on missing files)
|
| 90 |
+
if not os.path.exists(path):
|
| 91 |
+
return torch.zeros((self.n_leads, self.n_samples), dtype=torch.float32)
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
# Read binary file as 16-bit integers
|
| 95 |
+
raw_data = np.fromfile(path, dtype=np.int16)
|
| 96 |
+
|
| 97 |
+
# Check size
|
| 98 |
+
expected_size = self.n_leads * self.n_samples
|
| 99 |
+
|
| 100 |
+
if raw_data.size != expected_size:
|
| 101 |
+
# Handle truncated or wrong-sized files by padding or cutting
|
| 102 |
+
if raw_data.size < expected_size:
|
| 103 |
+
padded = np.zeros(expected_size, dtype=np.int16)
|
| 104 |
+
padded[:raw_data.size] = raw_data
|
| 105 |
+
raw_data = padded
|
| 106 |
+
else:
|
| 107 |
+
raw_data = raw_data[:expected_size]
|
| 108 |
+
|
| 109 |
+
# Reshape to (Samples, Leads) then Transpose to (Leads, Samples)
|
| 110 |
+
# stored as (samples, leads) interleaved? Usually yes in WFDB format 16
|
| 111 |
+
# Actually, standard WFDB '16' format is often interleaved.
|
| 112 |
+
# Let's assume interleaved (s1L1, s1L2... s1L12, s2L1...)
|
| 113 |
+
signal = raw_data.reshape((self.n_samples, self.n_leads)).T
|
| 114 |
+
|
| 115 |
+
# Normalize to mV
|
| 116 |
+
signal = signal.astype(np.float32) / self.gain
|
| 117 |
+
|
| 118 |
+
return torch.from_numpy(signal)
|
| 119 |
+
|
| 120 |
+
except Exception as e:
|
| 121 |
+
# print(f"Error loading {path}: {e}")
|
| 122 |
+
return torch.zeros((self.n_leads, self.n_samples), dtype=torch.float32)
|
| 123 |
+
|
| 124 |
+
def get_labels(self, row):
|
| 125 |
+
"""
|
| 126 |
+
Extracts labels from report columns.
|
| 127 |
+
Returns a multi-hot tensor of shape (num_classes).
|
| 128 |
+
"""
|
| 129 |
+
# Combine all report text
|
| 130 |
+
cols = [c for c in self.df.columns if 'report_' in c]
|
| 131 |
+
full_text = ' '.join([str(row[c]) for c in cols]).lower()
|
| 132 |
+
|
| 133 |
+
# Create label vector
|
| 134 |
+
label_vec = torch.zeros(self.num_classes, dtype=torch.float32)
|
| 135 |
+
|
| 136 |
+
# Simple string matching
|
| 137 |
+
# 0: Sinus Rhythm (Normal-ish)
|
| 138 |
+
if 'sinus rhythm' in full_text:
|
| 139 |
+
label_vec[0] = 1.0
|
| 140 |
+
|
| 141 |
+
# 1: Atrial Fibrillation
|
| 142 |
+
if 'atrial fibrillation' in full_text:
|
| 143 |
+
label_vec[1] = 1.0
|
| 144 |
+
|
| 145 |
+
# 2: Tachycardia
|
| 146 |
+
if 'sinus tachycardia' in full_text:
|
| 147 |
+
label_vec[2] = 1.0
|
| 148 |
+
|
| 149 |
+
# 3: Bradycardia
|
| 150 |
+
if 'sinus bradycardia' in full_text:
|
| 151 |
+
label_vec[3] = 1.0
|
| 152 |
+
|
| 153 |
+
# 4: VTach
|
| 154 |
+
if 'ventricular tachycardia' in full_text:
|
| 155 |
+
label_vec[4] = 1.0
|
| 156 |
+
|
| 157 |
+
return label_vec
|
examples/40689238.dat
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d18f4de5faa9ab7cbe5f54d5ea4d5dddd4b57b80f5879dcc25f2e8f08d5d1c43
|
| 3 |
+
size 120000
|
examples/43522917.dat
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60bc453a34bbff747774ca629224b4da8ac7dd4f3033ff19ab12345a6ef71cad
|
| 3 |
+
size 120000
|
examples/45227415.dat
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e5d101d965a4fd40eee1507ec288682802579bbd78a14f4c0e4f52f62ef8bbcb
|
| 3 |
+
size 120000
|
examples/46642833.dat
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:adc5fab3f766fc765e739ea660f1cfbd6f3ed39e1aa6218847e06d4cbe0f233b
|
| 3 |
+
size 120000
|
examples/49036311.dat
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fea3620d35a355fce80e7e5ee9b63356f87d031e72ec3ead304f209b1b1698eb
|
| 3 |
+
size 120000
|
labels_refined.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
# Define classes
|
| 6 |
+
CLASSES = [
|
| 7 |
+
'Sinus Rhythm', # 0
|
| 8 |
+
'Atrial Fibrillation', # 1
|
| 9 |
+
'Sinus Tachycardia', # 2
|
| 10 |
+
'Sinus Bradycardia', # 3
|
| 11 |
+
'Ventricular Tachycardia' # 4
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
def get_refined_labels(text):
|
| 15 |
+
"""
|
| 16 |
+
Parses ECG report text to extract diagnostic labels using Regex with negation handling.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
text (str): Combined text from report columns.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
torch.Tensor: Multi-hot encoded label vector.
|
| 23 |
+
"""
|
| 24 |
+
text = text.lower()
|
| 25 |
+
labels = torch.zeros(len(CLASSES), dtype=torch.float32)
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------
|
| 28 |
+
# Helper: Check for positive mention (ignoring negations)
|
| 29 |
+
# ---------------------------------------------------------
|
| 30 |
+
def has_condition(patterns, exclusion_patterns=None):
|
| 31 |
+
if exclusion_patterns:
|
| 32 |
+
for excl in exclusion_patterns:
|
| 33 |
+
if re.search(excl, text):
|
| 34 |
+
return False
|
| 35 |
+
|
| 36 |
+
for pat in patterns:
|
| 37 |
+
# Check for negation preceding the match
|
| 38 |
+
# finding all matches
|
| 39 |
+
matches = re.finditer(pat, text)
|
| 40 |
+
for match in matches:
|
| 41 |
+
start_idx = match.start()
|
| 42 |
+
# Look at the window before the match (e.g., 20 chars)
|
| 43 |
+
context_before = text[max(0, start_idx-25):start_idx]
|
| 44 |
+
|
| 45 |
+
# Negation triggers
|
| 46 |
+
negations = ['no ', 'not ', 'rule out ', 'denies ', 'absence of ', 'free of ']
|
| 47 |
+
if any(neg in context_before for neg in negations):
|
| 48 |
+
continue # This match is negated
|
| 49 |
+
|
| 50 |
+
return True # Found a positive, non-negated match
|
| 51 |
+
return False
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------
|
| 54 |
+
# Class 0: Sinus Rhythm
|
| 55 |
+
# ---------------------------------------------------------
|
| 56 |
+
# "Sinus rhythm" is often the default, but we should check for it explicitly.
|
| 57 |
+
if has_condition([r'sinus rhythm']):
|
| 58 |
+
labels[0] = 1.0
|
| 59 |
+
|
| 60 |
+
# ---------------------------------------------------------
|
| 61 |
+
# Class 1: Atrial Fibrillation
|
| 62 |
+
# ---------------------------------------------------------
|
| 63 |
+
# Synonyms: AFib, A-fib, Atrial Fib
|
| 64 |
+
if has_condition([r'atrial fibrillation', r'afib', r'a-fib', r'atrial fib']):
|
| 65 |
+
labels[1] = 1.0
|
| 66 |
+
|
| 67 |
+
# ---------------------------------------------------------
|
| 68 |
+
# Class 2: Sinus Tachycardia
|
| 69 |
+
# ---------------------------------------------------------
|
| 70 |
+
if has_condition([r'sinus tachycardia']):
|
| 71 |
+
labels[2] = 1.0
|
| 72 |
+
|
| 73 |
+
# ---------------------------------------------------------
|
| 74 |
+
# Class 3: Sinus Bradycardia
|
| 75 |
+
# ---------------------------------------------------------
|
| 76 |
+
if has_condition([r'sinus bradycardia']):
|
| 77 |
+
labels[3] = 1.0
|
| 78 |
+
|
| 79 |
+
# ---------------------------------------------------------
|
| 80 |
+
# Class 4: Ventricular Tachycardia
|
| 81 |
+
# ---------------------------------------------------------
|
| 82 |
+
# Synonyms: VTach, V-Tach, VT
|
| 83 |
+
# Be careful with "VT" matching random text
|
| 84 |
+
if has_condition([r'ventricular tachycardia', r'vtach', r'\bvt\b', r'v-tach']):
|
| 85 |
+
labels[4] = 1.0
|
| 86 |
+
|
| 87 |
+
return labels
|
| 88 |
+
|
| 89 |
+
if __name__ == "__main__":
|
| 90 |
+
# Test cases
|
| 91 |
+
test_sentences = [
|
| 92 |
+
("Normal sinus rhythm", [1, 0, 0, 0, 0]),
|
| 93 |
+
("Atrial fibrillation with rapid ventricular response", [0, 1, 0, 0, 0]),
|
| 94 |
+
("No atrial fibrillation detected", [0, 0, 0, 0, 0]),
|
| 95 |
+
("Sinus tachycardia", [0, 0, 1, 0, 0]),
|
| 96 |
+
("Rule out ventricular tachycardia", [0, 0, 0, 0, 0]),
|
| 97 |
+
("Patient has history of afib", [0, 1, 0, 0, 0]), # History might be ambiguous, but usually valid for label
|
| 98 |
+
("Sinus bradycardia observed", [0, 0, 0, 1, 0])
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
print("Running Regex Label Tests...")
|
| 102 |
+
for txt, expected in test_sentences:
|
| 103 |
+
res = get_refined_labels(txt)
|
| 104 |
+
match = torch.all(res == torch.tensor(expected)).item()
|
| 105 |
+
print(f"'{txt}' -> {res.numpy()} [{'✅' if match else '❌'}]")
|
model.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class ResNetBlock(nn.Module):
|
| 6 |
+
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
|
| 7 |
+
super(ResNetBlock, self).__init__()
|
| 8 |
+
self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=7, stride=stride, padding=3, bias=False)
|
| 9 |
+
self.bn1 = nn.BatchNorm1d(out_channels)
|
| 10 |
+
self.relu = nn.ReLU(inplace=True)
|
| 11 |
+
self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=False)
|
| 12 |
+
self.bn2 = nn.BatchNorm1d(out_channels)
|
| 13 |
+
self.downsample = downsample
|
| 14 |
+
|
| 15 |
+
def forward(self, x):
|
| 16 |
+
identity = x
|
| 17 |
+
if self.downsample is not None:
|
| 18 |
+
identity = self.downsample(x)
|
| 19 |
+
|
| 20 |
+
out = self.conv1(x)
|
| 21 |
+
out = self.bn1(out)
|
| 22 |
+
out = self.relu(out)
|
| 23 |
+
|
| 24 |
+
out = self.conv2(out)
|
| 25 |
+
out = self.bn2(out)
|
| 26 |
+
|
| 27 |
+
out += identity
|
| 28 |
+
out = self.relu(out)
|
| 29 |
+
return out
|
| 30 |
+
|
| 31 |
+
class ResNet1d(nn.Module):
|
| 32 |
+
"""
|
| 33 |
+
ResNet-1D for ECG Classification.
|
| 34 |
+
Adapted from 'Time Series Classification from Scratch with Deep Neural Networks: A Strong Baseline' (Wang et al. 2017)
|
| 35 |
+
"""
|
| 36 |
+
def __init__(self, num_classes=5):
|
| 37 |
+
super(ResNet1d, self).__init__()
|
| 38 |
+
|
| 39 |
+
self.inplanes = 64
|
| 40 |
+
# Initial: 12 leads -> 64 channels
|
| 41 |
+
self.conv1 = nn.Conv1d(12, 64, kernel_size=15, stride=2, padding=7, bias=False)
|
| 42 |
+
self.bn1 = nn.BatchNorm1d(64)
|
| 43 |
+
self.relu = nn.ReLU(inplace=True)
|
| 44 |
+
self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
|
| 45 |
+
|
| 46 |
+
# Layers
|
| 47 |
+
self.layer1 = self._make_layer(64, 2, stride=1)
|
| 48 |
+
self.layer2 = self._make_layer(128, 2, stride=2)
|
| 49 |
+
self.layer3 = self._make_layer(256, 2, stride=2)
|
| 50 |
+
self.layer4 = self._make_layer(512, 2, stride=2)
|
| 51 |
+
|
| 52 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
| 53 |
+
self.fc = nn.Linear(512, num_classes)
|
| 54 |
+
|
| 55 |
+
def _make_layer(self, planes, blocks, stride=1):
|
| 56 |
+
downsample = None
|
| 57 |
+
if stride != 1 or self.inplanes != planes:
|
| 58 |
+
downsample = nn.Sequential(
|
| 59 |
+
nn.Conv1d(self.inplanes, planes, kernel_size=1, stride=stride, bias=False),
|
| 60 |
+
nn.BatchNorm1d(planes),
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
layers = []
|
| 64 |
+
layers.append(ResNetBlock(self.inplanes, planes, stride, downsample))
|
| 65 |
+
self.inplanes = planes
|
| 66 |
+
for _ in range(1, blocks):
|
| 67 |
+
layers.append(ResNetBlock(self.inplanes, planes))
|
| 68 |
+
|
| 69 |
+
return nn.Sequential(*layers)
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
x = self.conv1(x)
|
| 73 |
+
x = self.bn1(x)
|
| 74 |
+
x = self.relu(x)
|
| 75 |
+
x = self.maxpool(x)
|
| 76 |
+
|
| 77 |
+
x = self.layer1(x)
|
| 78 |
+
x = self.layer2(x)
|
| 79 |
+
x = self.layer3(x)
|
| 80 |
+
x = self.layer4(x)
|
| 81 |
+
|
| 82 |
+
x = self.avgpool(x)
|
| 83 |
+
x = x.view(x.size(0), -1)
|
| 84 |
+
x = self.fc(x)
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
if __name__ == "__main__":
|
| 88 |
+
# Test
|
| 89 |
+
model = ResNet1d(num_classes=5)
|
| 90 |
+
dummy = torch.randn(2, 12, 5000)
|
| 91 |
+
out = model(dummy)
|
| 92 |
+
print(f"Input: {dummy.shape}")
|
| 93 |
+
print(f"Output: {out.shape}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
pandas
|
| 3 |
+
numpy
|
| 4 |
+
matplotlib
|
| 5 |
+
gradio
|
| 6 |
+
scipy
|
resnet_advanced.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:22f939e89d1fe5528a24f4e4894e507643f39307e52b3b57420fcb6820db9d50
|
| 3 |
+
size 35039598
|