IFMedTechdemo commited on
Commit
e3b4744
·
verified ·
1 Parent(s): 9c45578

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/40689238.dat filter=lfs diff=lfs merge=lfs -text
37
+ examples/43522917.dat filter=lfs diff=lfs merge=lfs -text
38
+ examples/45227415.dat filter=lfs diff=lfs merge=lfs -text
39
+ examples/46642833.dat filter=lfs diff=lfs merge=lfs -text
40
+ examples/49036311.dat filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import pandas as pd
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import os
8
+ import glob
9
+ from labels_refined import get_refined_labels, CLASSES
10
+ from model import ResNet1d
11
+ from dataset import MIMICECGDataset
12
+
13
+ # --- Configuration ---
14
+ # HF Space configuration: Data is local
15
+ DATA_DIR = "./examples"
16
+ MODEL_PATH = "resnet_advanced.pth"
17
+ DEVICE = torch.device("cpu") # Spaces usually CPU unless GPU requested
18
+
19
+ # --- Load Resources ---
20
+ print("Loading Model...")
21
+ model = ResNet1d(num_classes=5).to(DEVICE)
22
+ try:
23
+ state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
24
+ except:
25
+ state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
26
+ model.load_state_dict(state_dict)
27
+ model.eval()
28
+
29
+ # --- Pre-defined Metadata for Examples ---
30
+ # Hardcoded to avoid uploading the sensitive/huge patient CSV
31
+ example_metadata = {
32
+ "40689238": {
33
+ "diagnosis": "Sinus Rhythm (Normal)",
34
+ "text": "Sinus rhythm\nNormal ECG"
35
+ },
36
+ "46642833": {
37
+ "diagnosis": "Atrial Fibrillation",
38
+ "text": "Atrial fibrillation\nRapid ventricular response"
39
+ },
40
+ "49036311": {
41
+ "diagnosis": "Sinus Tachycardia",
42
+ "text": "Sinus tachycardia\nPossible Left Atrial Enlargement"
43
+ },
44
+ "43522917": {
45
+ "diagnosis": "Sinus Bradycardia",
46
+ "text": "Sinus bradycardia\nOtherwise normal"
47
+ },
48
+ "45227415": {
49
+ "diagnosis": "Ventricular Tachycardia (Rare)",
50
+ "text": "Ventricular tachycardia\nUrgent attention required"
51
+ }
52
+ }
53
+
54
+ def load_signal(path):
55
+ # Reusing logic from dataset.py
56
+ if not os.path.exists(path):
57
+ return None
58
+
59
+ gain = 200.0
60
+ with open(path, 'rb') as f:
61
+ # File is raw int16 binary
62
+ raw_data = np.fromfile(f, dtype=np.int16)
63
+
64
+ n_leads = 12
65
+ n_samples = 5000
66
+ expected_size = n_leads * n_samples
67
+
68
+ if raw_data.size < expected_size:
69
+ padded = np.zeros(expected_size, dtype=np.int16)
70
+ padded[:raw_data.size] = raw_data
71
+ raw_data = padded
72
+ else:
73
+ raw_data = raw_data[:expected_size]
74
+
75
+ signal = raw_data.reshape((n_samples, n_leads)).T
76
+ signal = signal.astype(np.float32) / gain
77
+ return signal
78
+
79
+ def plot_ecg(signal, title="12-Lead ECG"):
80
+ """Generates a matplotlib figure for the 12-lead ECG"""
81
+ leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
82
+
83
+ fig, axes = plt.subplots(12, 1, figsize=(10, 20), sharex=True)
84
+ plt.subplots_adjust(hspace=0.2)
85
+
86
+ for i in range(12):
87
+ axes[i].plot(signal[i], color='k', linewidth=0.8)
88
+ axes[i].set_ylabel(leads[i], rotation=0, labelpad=20, fontsize=10, fontweight='bold')
89
+ axes[i].spines['top'].set_visible(False)
90
+ axes[i].spines['right'].set_visible(False)
91
+ axes[i].spines['bottom'].set_visible(False if i < 11 else True)
92
+ axes[i].spines['left'].set_visible(True)
93
+ axes[i].grid(True, linestyle='--', alpha=0.5)
94
+
95
+ axes[11].set_xlabel("Samples (500Hz)", fontsize=12)
96
+ fig.suptitle(title, fontsize=16, y=0.90)
97
+
98
+ return fig
99
+
100
+ def predict_ecg(study_id):
101
+ # Path is local in examples/
102
+ path = os.path.join(DATA_DIR, f"{study_id}.dat")
103
+
104
+ if not os.path.exists(path):
105
+ return None, f"File not found for study {study_id}", {}
106
+
107
+ # Load Signal
108
+ signal = load_signal(path)
109
+ if signal is None:
110
+ return None, "Error loading signal", {}
111
+
112
+ # Generate Plot
113
+ fig = plot_ecg(signal, title=f"Study {study_id}")
114
+
115
+ # Inference
116
+ tensor_sig = torch.from_numpy(signal).float().unsqueeze(0).to(DEVICE) # (1, 12, 5000)
117
+ with torch.no_grad():
118
+ logits = model(tensor_sig)
119
+ probs = torch.sigmoid(logits).cpu().numpy()[0]
120
+
121
+ # Format Results
122
+ results = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
123
+
124
+ # Get True Text
125
+ full_text = example_metadata.get(study_id, {}).get("text", "Unknown")
126
+
127
+ return fig, results, full_text
128
+
129
+ # --- Gradio UI ---
130
+ examples = [[k, v["diagnosis"]] for k, v in example_metadata.items()]
131
+ example_ids = [k for k in example_metadata.keys()]
132
+
133
+ with gr.Blocks(title="ECG Arrhythmia Classifier") as demo:
134
+ gr.Markdown("# 🫀 AI ECG Arrhythmia Classifier")
135
+ gr.Markdown("Select a study ID from the examples below to analyze the 12-lead ECG.")
136
+
137
+ with gr.Row():
138
+ with gr.Column(scale=1):
139
+ # Input
140
+ study_input = gr.Dropdown(choices=example_ids, label="Select Example Study ID", value=example_ids[0])
141
+
142
+ # Info
143
+ gr.Markdown("### Example Descriptions")
144
+ gr.DataFrame(headers=["Study ID", "Diagnosis"], value=examples, interactive=False)
145
+
146
+ analyze_btn = gr.Button("Analyze ECG", variant="primary")
147
+
148
+ with gr.Column(scale=2):
149
+ # Output
150
+ plot_output = gr.Plot(label="12-Lead ECG Visualization")
151
+ label_output = gr.Label(label="AI Predictions")
152
+ text_output = gr.Textbox(label="Original Clinical Report (Ground Truth context)", lines=5)
153
+
154
+ analyze_btn.click(
155
+ fn=predict_ecg,
156
+ inputs=[study_input],
157
+ outputs=[plot_output, label_output, text_output]
158
+ )
159
+
160
+ if __name__ == "__main__":
161
+ demo.launch()
dataset.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import pandas as pd
4
+ import numpy as np
5
+ import os
6
+
7
+ class MIMICECGDataset(Dataset):
8
+ """
9
+ PyTorch Dataset for MIMIC-IV-ECG.
10
+ """
11
+ def __init__(self, df, data_dir, transform=False, label_func=None):
12
+ """
13
+ Args:
14
+ df (pd.DataFrame): Dataframe with subject_id, study_id, and report columns.
15
+ data_dir (str): Root directory of the dataset (containing the 'files' folder).
16
+ transform (callable, optional): Optional transform to be applied on a sample.
17
+ label_func (callable, optional): Custom function to extract labels from a row.
18
+ """
19
+ self.df = df
20
+ self.data_dir = data_dir
21
+ self.transform = transform
22
+ self.label_func = label_func
23
+
24
+ # MIMIC-ECG Constants
25
+ self.n_leads = 12
26
+ self.n_samples = 5000
27
+ self.fs = 500
28
+ self.gain = 200.0 # Standard gain in ADU/mV
29
+
30
+ # Define the target classes we want to detect
31
+ # These keys will be searched in the report columns
32
+ self.class_mapping = {
33
+ 'Normally filtered': 0, # Not a diagnosis, but often present
34
+ 'Sinus rhythm': 0,
35
+ 'Atrial fibrillation': 1,
36
+ 'Sinus tachycardia': 2,
37
+ 'Sinus bradycardia': 3,
38
+ 'Ventricular tachycardia': 4,
39
+ # Add more as needed
40
+ }
41
+ self.num_classes = 5 # For now
42
+
43
+ def __len__(self):
44
+ return len(self.df)
45
+
46
+ def __getitem__(self, idx):
47
+ if torch.is_tensor(idx):
48
+ idx = idx.tolist()
49
+
50
+ row = self.df.iloc[idx]
51
+ subj_id = str(row['subject_id'])
52
+ study_id = str(row['study_id'])
53
+
54
+ # Construct path: files/p{XXX}/p{subject_id}/s{study_id}/{study_id}.dat
55
+ subdir = f"p{subj_id[:4]}"
56
+ # Ensure we handle the folder structure correctly.
57
+ # Based on exploration: data_dir/files/p100/p10000032/s40689238/40689238.dat
58
+ file_path = os.path.join(self.data_dir, 'files', subdir, f"p{subj_id}", f"s{study_id}", f"{study_id}.dat")
59
+
60
+ # 1. Load Signal
61
+ signal = self.load_signal_numpy(file_path)
62
+
63
+ # 2. Get Labels
64
+ if self.label_func:
65
+ # Need to pass text to label_func, or row?
66
+ # get_refined_labels expects text. Let's extract text here or let func handle row.
67
+ # Best to let func handle text so it's pure.
68
+ cols = [c for c in self.df.columns if 'report_' in c]
69
+ full_text = ' '.join([str(row[c]) for c in cols])
70
+ labels = self.label_func(full_text)
71
+ else:
72
+ labels = self.get_labels(row)
73
+
74
+ # 3. Return sample
75
+ # Signal shape: (12, 5000)
76
+ sample = {
77
+ 'signal': signal,
78
+ 'labels': labels,
79
+ 'study_id': study_id
80
+ }
81
+
82
+ return sample
83
+
84
+ def load_signal_numpy(self, path):
85
+ """
86
+ Reads the binary .dat file using numpy.
87
+ Returns a torch tensor of shape (12, 5000).
88
+ """
89
+ # Return zeros if file is missing (to avoid crashing training loop on missing files)
90
+ if not os.path.exists(path):
91
+ return torch.zeros((self.n_leads, self.n_samples), dtype=torch.float32)
92
+
93
+ try:
94
+ # Read binary file as 16-bit integers
95
+ raw_data = np.fromfile(path, dtype=np.int16)
96
+
97
+ # Check size
98
+ expected_size = self.n_leads * self.n_samples
99
+
100
+ if raw_data.size != expected_size:
101
+ # Handle truncated or wrong-sized files by padding or cutting
102
+ if raw_data.size < expected_size:
103
+ padded = np.zeros(expected_size, dtype=np.int16)
104
+ padded[:raw_data.size] = raw_data
105
+ raw_data = padded
106
+ else:
107
+ raw_data = raw_data[:expected_size]
108
+
109
+ # Reshape to (Samples, Leads) then Transpose to (Leads, Samples)
110
+ # stored as (samples, leads) interleaved? Usually yes in WFDB format 16
111
+ # Actually, standard WFDB '16' format is often interleaved.
112
+ # Let's assume interleaved (s1L1, s1L2... s1L12, s2L1...)
113
+ signal = raw_data.reshape((self.n_samples, self.n_leads)).T
114
+
115
+ # Normalize to mV
116
+ signal = signal.astype(np.float32) / self.gain
117
+
118
+ return torch.from_numpy(signal)
119
+
120
+ except Exception as e:
121
+ # print(f"Error loading {path}: {e}")
122
+ return torch.zeros((self.n_leads, self.n_samples), dtype=torch.float32)
123
+
124
+ def get_labels(self, row):
125
+ """
126
+ Extracts labels from report columns.
127
+ Returns a multi-hot tensor of shape (num_classes).
128
+ """
129
+ # Combine all report text
130
+ cols = [c for c in self.df.columns if 'report_' in c]
131
+ full_text = ' '.join([str(row[c]) for c in cols]).lower()
132
+
133
+ # Create label vector
134
+ label_vec = torch.zeros(self.num_classes, dtype=torch.float32)
135
+
136
+ # Simple string matching
137
+ # 0: Sinus Rhythm (Normal-ish)
138
+ if 'sinus rhythm' in full_text:
139
+ label_vec[0] = 1.0
140
+
141
+ # 1: Atrial Fibrillation
142
+ if 'atrial fibrillation' in full_text:
143
+ label_vec[1] = 1.0
144
+
145
+ # 2: Tachycardia
146
+ if 'sinus tachycardia' in full_text:
147
+ label_vec[2] = 1.0
148
+
149
+ # 3: Bradycardia
150
+ if 'sinus bradycardia' in full_text:
151
+ label_vec[3] = 1.0
152
+
153
+ # 4: VTach
154
+ if 'ventricular tachycardia' in full_text:
155
+ label_vec[4] = 1.0
156
+
157
+ return label_vec
examples/40689238.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d18f4de5faa9ab7cbe5f54d5ea4d5dddd4b57b80f5879dcc25f2e8f08d5d1c43
3
+ size 120000
examples/43522917.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60bc453a34bbff747774ca629224b4da8ac7dd4f3033ff19ab12345a6ef71cad
3
+ size 120000
examples/45227415.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5d101d965a4fd40eee1507ec288682802579bbd78a14f4c0e4f52f62ef8bbcb
3
+ size 120000
examples/46642833.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:adc5fab3f766fc765e739ea660f1cfbd6f3ed39e1aa6218847e06d4cbe0f233b
3
+ size 120000
examples/49036311.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fea3620d35a355fce80e7e5ee9b63356f87d031e72ec3ead304f209b1b1698eb
3
+ size 120000
labels_refined.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ import numpy as np
4
+
5
+ # Define classes
6
+ CLASSES = [
7
+ 'Sinus Rhythm', # 0
8
+ 'Atrial Fibrillation', # 1
9
+ 'Sinus Tachycardia', # 2
10
+ 'Sinus Bradycardia', # 3
11
+ 'Ventricular Tachycardia' # 4
12
+ ]
13
+
14
+ def get_refined_labels(text):
15
+ """
16
+ Parses ECG report text to extract diagnostic labels using Regex with negation handling.
17
+
18
+ Args:
19
+ text (str): Combined text from report columns.
20
+
21
+ Returns:
22
+ torch.Tensor: Multi-hot encoded label vector.
23
+ """
24
+ text = text.lower()
25
+ labels = torch.zeros(len(CLASSES), dtype=torch.float32)
26
+
27
+ # ---------------------------------------------------------
28
+ # Helper: Check for positive mention (ignoring negations)
29
+ # ---------------------------------------------------------
30
+ def has_condition(patterns, exclusion_patterns=None):
31
+ if exclusion_patterns:
32
+ for excl in exclusion_patterns:
33
+ if re.search(excl, text):
34
+ return False
35
+
36
+ for pat in patterns:
37
+ # Check for negation preceding the match
38
+ # finding all matches
39
+ matches = re.finditer(pat, text)
40
+ for match in matches:
41
+ start_idx = match.start()
42
+ # Look at the window before the match (e.g., 20 chars)
43
+ context_before = text[max(0, start_idx-25):start_idx]
44
+
45
+ # Negation triggers
46
+ negations = ['no ', 'not ', 'rule out ', 'denies ', 'absence of ', 'free of ']
47
+ if any(neg in context_before for neg in negations):
48
+ continue # This match is negated
49
+
50
+ return True # Found a positive, non-negated match
51
+ return False
52
+
53
+ # ---------------------------------------------------------
54
+ # Class 0: Sinus Rhythm
55
+ # ---------------------------------------------------------
56
+ # "Sinus rhythm" is often the default, but we should check for it explicitly.
57
+ if has_condition([r'sinus rhythm']):
58
+ labels[0] = 1.0
59
+
60
+ # ---------------------------------------------------------
61
+ # Class 1: Atrial Fibrillation
62
+ # ---------------------------------------------------------
63
+ # Synonyms: AFib, A-fib, Atrial Fib
64
+ if has_condition([r'atrial fibrillation', r'afib', r'a-fib', r'atrial fib']):
65
+ labels[1] = 1.0
66
+
67
+ # ---------------------------------------------------------
68
+ # Class 2: Sinus Tachycardia
69
+ # ---------------------------------------------------------
70
+ if has_condition([r'sinus tachycardia']):
71
+ labels[2] = 1.0
72
+
73
+ # ---------------------------------------------------------
74
+ # Class 3: Sinus Bradycardia
75
+ # ---------------------------------------------------------
76
+ if has_condition([r'sinus bradycardia']):
77
+ labels[3] = 1.0
78
+
79
+ # ---------------------------------------------------------
80
+ # Class 4: Ventricular Tachycardia
81
+ # ---------------------------------------------------------
82
+ # Synonyms: VTach, V-Tach, VT
83
+ # Be careful with "VT" matching random text
84
+ if has_condition([r'ventricular tachycardia', r'vtach', r'\bvt\b', r'v-tach']):
85
+ labels[4] = 1.0
86
+
87
+ return labels
88
+
89
+ if __name__ == "__main__":
90
+ # Test cases
91
+ test_sentences = [
92
+ ("Normal sinus rhythm", [1, 0, 0, 0, 0]),
93
+ ("Atrial fibrillation with rapid ventricular response", [0, 1, 0, 0, 0]),
94
+ ("No atrial fibrillation detected", [0, 0, 0, 0, 0]),
95
+ ("Sinus tachycardia", [0, 0, 1, 0, 0]),
96
+ ("Rule out ventricular tachycardia", [0, 0, 0, 0, 0]),
97
+ ("Patient has history of afib", [0, 1, 0, 0, 0]), # History might be ambiguous, but usually valid for label
98
+ ("Sinus bradycardia observed", [0, 0, 0, 1, 0])
99
+ ]
100
+
101
+ print("Running Regex Label Tests...")
102
+ for txt, expected in test_sentences:
103
+ res = get_refined_labels(txt)
104
+ match = torch.all(res == torch.tensor(expected)).item()
105
+ print(f"'{txt}' -> {res.numpy()} [{'✅' if match else '❌'}]")
model.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class ResNetBlock(nn.Module):
6
+ def __init__(self, in_channels, out_channels, stride=1, downsample=None):
7
+ super(ResNetBlock, self).__init__()
8
+ self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=7, stride=stride, padding=3, bias=False)
9
+ self.bn1 = nn.BatchNorm1d(out_channels)
10
+ self.relu = nn.ReLU(inplace=True)
11
+ self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=False)
12
+ self.bn2 = nn.BatchNorm1d(out_channels)
13
+ self.downsample = downsample
14
+
15
+ def forward(self, x):
16
+ identity = x
17
+ if self.downsample is not None:
18
+ identity = self.downsample(x)
19
+
20
+ out = self.conv1(x)
21
+ out = self.bn1(out)
22
+ out = self.relu(out)
23
+
24
+ out = self.conv2(out)
25
+ out = self.bn2(out)
26
+
27
+ out += identity
28
+ out = self.relu(out)
29
+ return out
30
+
31
+ class ResNet1d(nn.Module):
32
+ """
33
+ ResNet-1D for ECG Classification.
34
+ Adapted from 'Time Series Classification from Scratch with Deep Neural Networks: A Strong Baseline' (Wang et al. 2017)
35
+ """
36
+ def __init__(self, num_classes=5):
37
+ super(ResNet1d, self).__init__()
38
+
39
+ self.inplanes = 64
40
+ # Initial: 12 leads -> 64 channels
41
+ self.conv1 = nn.Conv1d(12, 64, kernel_size=15, stride=2, padding=7, bias=False)
42
+ self.bn1 = nn.BatchNorm1d(64)
43
+ self.relu = nn.ReLU(inplace=True)
44
+ self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
45
+
46
+ # Layers
47
+ self.layer1 = self._make_layer(64, 2, stride=1)
48
+ self.layer2 = self._make_layer(128, 2, stride=2)
49
+ self.layer3 = self._make_layer(256, 2, stride=2)
50
+ self.layer4 = self._make_layer(512, 2, stride=2)
51
+
52
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
53
+ self.fc = nn.Linear(512, num_classes)
54
+
55
+ def _make_layer(self, planes, blocks, stride=1):
56
+ downsample = None
57
+ if stride != 1 or self.inplanes != planes:
58
+ downsample = nn.Sequential(
59
+ nn.Conv1d(self.inplanes, planes, kernel_size=1, stride=stride, bias=False),
60
+ nn.BatchNorm1d(planes),
61
+ )
62
+
63
+ layers = []
64
+ layers.append(ResNetBlock(self.inplanes, planes, stride, downsample))
65
+ self.inplanes = planes
66
+ for _ in range(1, blocks):
67
+ layers.append(ResNetBlock(self.inplanes, planes))
68
+
69
+ return nn.Sequential(*layers)
70
+
71
+ def forward(self, x):
72
+ x = self.conv1(x)
73
+ x = self.bn1(x)
74
+ x = self.relu(x)
75
+ x = self.maxpool(x)
76
+
77
+ x = self.layer1(x)
78
+ x = self.layer2(x)
79
+ x = self.layer3(x)
80
+ x = self.layer4(x)
81
+
82
+ x = self.avgpool(x)
83
+ x = x.view(x.size(0), -1)
84
+ x = self.fc(x)
85
+ return x
86
+
87
+ if __name__ == "__main__":
88
+ # Test
89
+ model = ResNet1d(num_classes=5)
90
+ dummy = torch.randn(2, 12, 5000)
91
+ out = model(dummy)
92
+ print(f"Input: {dummy.shape}")
93
+ print(f"Output: {out.shape}")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ pandas
3
+ numpy
4
+ matplotlib
5
+ gradio
6
+ scipy
resnet_advanced.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22f939e89d1fe5528a24f4e4894e507643f39307e52b3b57420fcb6820db9d50
3
+ size 35039598