vlbthambawita committed on
Commit
d8c3769
·
1 Parent(s): 51369e5
Files changed (5)
  1. README.md +58 -0
  2. app.py +433 -0
  3. categorical_imn_core.py +307 -0
  4. requirements.txt +8 -0
  5. single_linear_imn_core.py +307 -0
README.md CHANGED
@@ -12,3 +12,61 @@ short_description: Interpretable Mesomorphic Neural Networks for 12-Lead ECG
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+ ## MesomorphicECG XAI Space
17
+
18
+ This Space hosts an interactive Gradio app for the **mesomorphicECG** models in
19
+ `SEARCH-IHI/mesomorphicECG` (`https://huggingface.co/SEARCH-IHI/mesomorphicECG`).
20
+
21
+ The app:
22
+ - Loads **IMN** checkpoints (categorical and single-linear) from the model repo.
23
+ - Lets you choose sampling rate (100 / 500 Hz) and task:
24
+ `norm_vs_cd`, `norm_vs_hyp`, `norm_vs_mi`, `norm_vs_sttc`.
25
+ - Uses pre-packaged PTB-XL examples stored as binary `.npz` files in this Space.
26
+ - Visualizes intrinsic IMN feature attributions (Impact = w·x) as a lead × segment heatmap
27
+ together with per-lead ECG traces.
28
+
29
+ ### Files
30
+
31
+ - `app.py` – main Gradio application.
32
+ - `single_linear_imn_core.py` – core single-linear IMN model for inference.
33
+ - `categorical_imn_core.py` – core categorical IMN model for inference.
34
+ - `requirements.txt` – Python dependencies for this Space.
35
+
36
+ ### Required data binaries
37
+
38
+ For each combination of **sampling rate** and **task**, the app expects a `.npz` file:
39
+
40
+ - 100 Hz:
41
+ - `data/ptbxl_100hz_norm_vs_cd_test.npz`
42
+ - `data/ptbxl_100hz_norm_vs_hyp_test.npz`
43
+ - `data/ptbxl_100hz_norm_vs_mi_test.npz`
44
+ - `data/ptbxl_100hz_norm_vs_sttc_test.npz`
45
+ - 500 Hz:
46
+ - `data/ptbxl_500hz_norm_vs_cd_test.npz`
47
+ - `data/ptbxl_500hz_norm_vs_hyp_test.npz`
48
+ - `data/ptbxl_500hz_norm_vs_mi_test.npz`
49
+ - `data/ptbxl_500hz_norm_vs_sttc_test.npz`
50
+
51
+ Each `.npz` should contain:
52
+
53
+ - `signals` – float32 array `[N, 12, L]` (z-scoring is done in the app).
54
+ - `labels` – array `[N]` with 0 (NORM) / 1 (POS_CLASS) for the chosen task.
55
+ - `reports` – object array `[N]` with clinical notes (strings).
56
+ - `age` – array `[N]` (e.g. int or float).
57
+ - `sex` – object array `[N]` (e.g. `'M'`, `'F'`, or empty).
58
+ - `ecg_id` – array `[N]` with integer ECG identifiers.
59
+
60
+ You can prepare these from PTB-XL using the same task definition and
61
+ window length / sampling rate as in the training scripts, then upload
62
+ them into this Space under the `data/` directory.
63
+
64
+ ### Run locally
65
+
66
+ ```bash
67
+ pip install -r requirements.txt
68
+ python app.py
69
+ ```
70
+
71
+ On Hugging Face Spaces, `app.py` is loaded automatically.
72
+
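Reviewer note: the `.npz` layout described in the README above can be sanity-checked before uploading. The following is a minimal sketch, not part of this commit. `make_dummy_npz` and `check_npz` are hypothetical helper names, and the synthetic values only illustrate the expected keys, dtypes, and shapes; real files would be built from PTB-XL.

```python
import numpy as np

# Keys the README lists for every data binary.
REQUIRED_KEYS = ("signals", "labels", "reports", "age", "sex", "ecg_id")


def make_dummy_npz(path, n=4, length=1000, seed=0):
    """Write a synthetic file with the documented keys (illustration only)."""
    rng = np.random.default_rng(seed)
    np.savez(
        path,
        signals=rng.standard_normal((n, 12, length)).astype(np.float32),
        labels=(np.arange(n) % 2).astype(np.int64),  # 0 = NORM, 1 = positive class
        reports=np.array([f"synthetic note {i}" for i in range(n)], dtype=object),
        age=rng.integers(20, 90, size=n),
        sex=np.array([("M", "F")[i % 2] for i in range(n)], dtype=object),
        ecg_id=np.arange(1, n + 1),
    )


def check_npz(path):
    """Return the signals shape, or raise if the layout is not as documented."""
    # Object arrays (reports, sex) are pickled, so allow_pickle=True is required.
    with np.load(path, allow_pickle=True) as npz:
        missing = [k for k in REQUIRED_KEYS if k not in npz.files]
        if missing:
            raise KeyError(f"missing keys: {missing}")
        signals = npz["signals"]
        if signals.ndim != 3 or signals.shape[1] != 12:
            raise ValueError(f"signals must be [N, 12, L], got {signals.shape}")
        n = signals.shape[0]
        for key in REQUIRED_KEYS[1:]:
            if len(npz[key]) != n:
                raise ValueError(f"{key} has length {len(npz[key])}, expected {n}")
        return signals.shape
```

Running `check_npz` on each file under `data/` before pushing would catch a missing key or a lead-count mismatch earlier than the app's own load-time errors.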
app.py ADDED
@@ -0,0 +1,433 @@
1
+ """
2
+ MesomorphicECG XAI Gradio app for Hugging Face Spaces.
3
+
4
+ This version focuses on:
5
+ - Selecting sampling rate (100 / 500 Hz), model type (categorical vs single-linear),
6
+ and task (norm_vs_cd / norm_vs_hyp / norm_vs_mi / norm_vs_sttc).
7
+ - Loading pre-packaged ECG examples from local binary .npz files in this Space.
8
+ - Downloading the corresponding IMN checkpoint from
9
+ `SEARCH-IHI/mesomorphicECG` on the Hugging Face Hub.
10
+ - Running inference and visualizing intrinsic feature attributions
11
+ (Impact = w * x) as a lead × segment heatmap plus per-lead ECG traces.
12
+
13
+ Data binaries
14
+ -------------
15
+ For each (sampling_rate, task) pair you should provide a `.npz` file as
16
+ configured in DATA_FILES below, with keys:
17
+
18
+ signals : float32 array [N, 12, L]
19
+ labels : float32/int array [N] with 0 (NORM) / 1 (POS_CLASS)
20
+ reports : object array [N] of clinical notes
21
+ age : array [N]
22
+ sex : object array [N]
23
+ ecg_id : array [N]
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import os
29
+ from functools import lru_cache
30
+ from typing import Any, Dict, List, Optional, Tuple
31
+
32
+ import numpy as np
33
+ import torch
34
+ import matplotlib
35
+
36
+ matplotlib.use("Agg")
37
+ import matplotlib.pyplot as plt # noqa: E402
38
+ import gradio as gr # noqa: E402
39
+ from huggingface_hub import hf_hub_download, list_repo_files # noqa: E402
40
+
41
+ import single_linear_imn_core as sl_core # noqa: E402
42
+ import categorical_imn_core as cat_core # noqa: E402
43
+
44
+
45
+ HF_MODEL_REPO = "SEARCH-IHI/mesomorphicECG"
46
+
47
+ TASK_TO_POS = {
48
+ "norm_vs_mi": "MI",
49
+ "norm_vs_sttc": "STTC",
50
+ "norm_vs_cd": "CD",
51
+ "norm_vs_hyp": "HYP",
52
+ }
53
+
54
+ LEAD_NAMES = sl_core.DEFAULT_LEAD_NAMES
55
+
56
+
57
+ # Mapping from (sampling_rate, task) -> local data binary.
58
+ DATA_FILES: Dict[Tuple[int, str], str] = {
59
+ # 100 Hz
60
+ (100, "norm_vs_cd"): "data/ptbxl_100hz_norm_vs_cd_test.npz",
61
+ (100, "norm_vs_hyp"): "data/ptbxl_100hz_norm_vs_hyp_test.npz",
62
+ (100, "norm_vs_mi"): "data/ptbxl_100hz_norm_vs_mi_test.npz",
63
+ (100, "norm_vs_sttc"): "data/ptbxl_100hz_norm_vs_sttc_test.npz",
64
+ # 500 Hz
65
+ (500, "norm_vs_cd"): "data/ptbxl_500hz_norm_vs_cd_test.npz",
66
+ (500, "norm_vs_hyp"): "data/ptbxl_500hz_norm_vs_hyp_test.npz",
67
+ (500, "norm_vs_mi"): "data/ptbxl_500hz_norm_vs_mi_test.npz",
68
+ (500, "norm_vs_sttc"): "data/ptbxl_500hz_norm_vs_sttc_test.npz",
69
+ }
70
+
71
+
72
+ DATA_CACHE: Dict[Tuple[int, str], Dict[str, Any]] = {}
73
+ MODEL_CACHE: Dict[Tuple[str, int, str], Dict[str, Any]] = {}
74
+
75
+
76
+ def zscore_per_lead(x: np.ndarray) -> np.ndarray:
77
+ """Per-lead z-score normalization."""
78
+ mean = x.mean(axis=1, keepdims=True)
79
+ std = x.std(axis=1, keepdims=True).clip(min=1e-6)
80
+ return ((x - mean) / std).astype(np.float32)
81
+
82
+
83
+ @lru_cache(maxsize=None)
84
+ def _list_model_repo_files() -> List[str]:
85
+ return list_repo_files(repo_id=HF_MODEL_REPO, repo_type="model")
86
+
87
+
88
+ def _resolve_ckpt_filename(model_type: str, sampling_rate: int, task: str) -> str:
89
+ if model_type == "single_linear":
90
+ category = f"single_linear_imn_{sampling_rate}hz"
91
+ else:
92
+ category = f"categorical_imn_{sampling_rate}hz"
93
+
94
+ prefix = f"{category}/{task}/"
95
+ files = _list_model_repo_files()
96
+ candidates = [f for f in files if f.startswith(prefix) and f.endswith(".ckpt")]
97
+ if not candidates:
98
+ raise FileNotFoundError(
99
+ f"No checkpoint (.ckpt) found in repo {HF_MODEL_REPO} under {prefix}. "
100
+ "Ensure upload_best_checkpoints_to_hf.py has populated this path."
101
+ )
102
+ best_style = [f for f in candidates if "best-imn-epoch=" in f]
103
+ chosen = sorted(best_style or candidates)[-1]
104
+ return chosen
105
+
106
+
107
+ def load_imn_model(
108
+ model_type: str,
109
+ sampling_rate: int,
110
+ task: str,
111
+ ) -> Tuple[torch.nn.Module, str]:
112
+ key = (model_type, sampling_rate, task)
113
+ cached = MODEL_CACHE.get(key)
114
+ if cached and cached["model"] is not None:
115
+ return cached["model"], cached["device"]
116
+
117
+ device = "cuda" if torch.cuda.is_available() else "cpu"
118
+ filename = _resolve_ckpt_filename(model_type, sampling_rate, task)
119
+ ckpt_local = hf_hub_download(repo_id=HF_MODEL_REPO, filename=filename)
120
+
121
+ if model_type == "single_linear":
122
+ model = sl_core.IMNLightning.load_from_checkpoint(ckpt_local, map_location=device)
123
+ else:
124
+ model = cat_core.IMNLightning.load_from_checkpoint(ckpt_local, map_location=device)
125
+
126
+ model.eval()
127
+ model.to(device)
128
+ MODEL_CACHE[key] = {"path": ckpt_local, "model": model, "device": device}
129
+ return model, device
130
+
131
+
132
+ def load_data_binary(sampling_rate: int, task: str) -> Dict[str, Any]:
133
+ key = (sampling_rate, task)
134
+ if key in DATA_CACHE:
135
+ return DATA_CACHE[key]
136
+
137
+ path = DATA_FILES.get(key)
138
+ if path is None:
139
+ raise FileNotFoundError(f"No data file configured for (fs={sampling_rate}, task={task}).")
140
+ if not os.path.isfile(path):
141
+ raise FileNotFoundError(
142
+ f"Data file not found at '{path}'. "
143
+ "Upload a .npz with signals, labels, reports, age, sex, ecg_id."
144
+ )
145
+
146
+ with np.load(path, allow_pickle=True) as npz:
147
+ required = ["signals", "labels", "reports", "age", "sex", "ecg_id"]
148
+ missing = [k for k in required if k not in npz]
149
+ if missing:
150
+ raise KeyError(f"Data file '{path}' missing keys: {missing}")
151
+ data = {k: npz[k] for k in required}
152
+
153
+ DATA_CACHE[key] = data
154
+ return data
155
+
156
+
157
+ def on_load_records(
158
+ sampling_rate: int,
159
+ task: str,
160
+ state: Optional[dict],
161
+ ):
162
+ try:
163
+ data = load_data_binary(int(sampling_rate), task)
164
+ except Exception as e:
165
+ return (
166
+ f"Load error: {e}",
167
+ gr.update(choices=[], value=None),
168
+ state or {},
169
+ "—",
170
+ "—",
171
+ )
172
+
173
+ signals = data["signals"]
174
+ labels = data["labels"]
175
+ reports = data["reports"]
176
+ age = data["age"]
177
+ sex = data["sex"]
178
+ ecg_id = data["ecg_id"]
179
+
180
+ N, C, L = signals.shape
181
+ pos_class = TASK_TO_POS.get(task, "MI")
182
+
183
+ records: List[Dict[str, Any]] = []
184
+ for i in range(N):
185
+ gt = pos_class if float(labels[i]) >= 0.5 else "NORM"
186
+ records.append(
187
+ {
188
+ "index": int(i),
189
+ "ecg_id": int(ecg_id[i]),
190
+ "gt": gt,
191
+ "report": str(reports[i]) if reports is not None else "",
192
+ "age": age[i] if age is not None else "",
193
+ "sex": str(sex[i]) if sex is not None else "",
194
+ }
195
+ )
196
+
197
+ choices = [f"{r['index']} | {r['ecg_id']} | {r['gt']} | age {r['age']} {r['sex']}" for r in records]
198
+ value = choices[0] if choices else None
199
+ state = {
200
+ "records": records,
201
+ "fs": int(sampling_rate),
202
+ "task": task,
203
+ "pos_class": pos_class,
204
+ }
205
+ report = (records[0]["report"] or "(no clinical notes)") if records else "—"
206
+ gt = records[0]["gt"] if records else "—"
207
+ status = (
208
+ f"Loaded {N} examples (fs={sampling_rate}Hz, {pos_class} vs NORM, L={L})."
209
+ if N > 0
210
+ else "No examples found in data file."
211
+ )
212
+ return status, gr.update(choices=choices, value=value), state, report, gt
213
+
214
+
215
+ def on_select_record(choice: str, state: Optional[dict]):
216
+ if not state or not state.get("records") or not choice:
217
+ return "—", "—"
218
+ try:
219
+ idx = int(choice.split("|")[0].strip())
220
+ except Exception:
221
+ return "—", "—"
222
+ for r in state["records"]:
223
+ if r["index"] == idx:
224
+ return r["report"] or "(no clinical notes)", r["gt"]
225
+ return "—", "—"
226
+
227
+
228
+ def explain_record(
229
+ model_type: str,
230
+ sampling_rate: int,
231
+ task: str,
232
+ record_choice: str,
233
+ state: Optional[dict],
234
+ ):
235
+ err = "Click 'Load records' and select a record first.", None, "—", "—", "—"
236
+ if not state or not state.get("records") or not record_choice:
237
+ return err
238
+ try:
239
+ rec_idx = int(record_choice.split("|")[0].strip())
240
+ except Exception:
241
+ return err
242
+ rec = next((r for r in state["records"] if r["index"] == rec_idx), None)
243
+ if not rec:
244
+ return err
245
+
246
+ fs = state["fs"]
247
+ pos_class_name = state.get("pos_class", "MI")
248
+ report = rec["report"] or "(no clinical notes)"
249
+ gt = rec["gt"]
250
+
251
+ try:
252
+ data = load_data_binary(int(sampling_rate), task)
253
+ except Exception as e:
254
+ return f"Data error: {e}", None, report, gt, "—"
255
+ try:
256
+ model, device = load_imn_model(model_type, int(sampling_rate), task)
257
+ except Exception as e:
258
+ return f"Checkpoint error: {e}", None, report, gt, "—"
259
+
260
+ signals = data["signals"]
261
+ if rec_idx < 0 or rec_idx >= signals.shape[0]:
262
+ return f"Invalid record index {rec_idx}.", None, report, gt, "—"
263
+
264
+ x = signals[rec_idx] # [12, L]
265
+ if x.shape[0] != 12:
266
+ return f"Expected 12 leads, got {x.shape[0]}.", None, report, gt, "—"
267
+
268
+ signal_len_model = int(model.hparams["signal_len"])
269
+ if x.shape[1] != signal_len_model:
270
+ return (
271
+ f"ECG length {x.shape[1]} != model {signal_len_model}. "
272
+ "Ensure data binaries match the training window length.",
273
+ None,
274
+ report,
275
+ gt,
276
+ "—",
277
+ )
278
+
279
+ x = zscore_per_lead(x)
280
+ x_t = torch.from_numpy(x).float().unsqueeze(0).to(device)
281
+
282
+ with torch.no_grad():
283
+ logits, gen_w, gen_b = model.model(x_t)
284
+ if model_type == "single_linear":
285
+ logit = logits.squeeze()
286
+ prob_pos = float(torch.sigmoid(logit).item())
287
+ w_used = gen_w[0, 0, :, :].cpu().numpy()
288
+ else:
289
+ probs = torch.softmax(logits, dim=1)
290
+ prob_pos = float(probs[0, 1].item())
291
+ w_used = gen_w[0, 1, :, :].cpu().numpy()
292
+
293
+ x_np = x.astype(np.float64)
294
+ impact = w_used * x_np # [12, L]
295
+
296
+ # Window/stride heuristic by sampling rate
297
+ window = 50 if int(sampling_rate) == 100 else 250
298
+ stride = window // 2
299
+ seg_hm = sl_core.imn_weights_to_segments(impact, window=window, stride=stride) # [12, T]
300
+
301
+ # Build simple figure: heatmap + 12 ECG traces
302
+ L = x_np.shape[1]
303
+ T = seg_hm.shape[1]
304
+
305
+ fig = plt.figure(figsize=(11, 10))
306
+ gs = fig.add_gridspec(14, 1, height_ratios=[2] + [1] * 12 + [0.5])
307
+
308
+ ax0 = fig.add_subplot(gs[0, 0])
309
+ im = ax0.imshow(seg_hm, aspect="auto", vmin=0.0, vmax=1.0, cmap="Reds")
310
+ ax0.set_yticks(range(12))
311
+ ax0.set_yticklabels(LEAD_NAMES)
312
+ ax0.set_xlabel(f"Segments (window={window}, stride={stride}, fs={fs}Hz)")
313
+ prob_str = f"P({pos_class_name})={prob_pos:.3f}"
314
+ pred = pos_class_name if prob_pos >= 0.5 else "NORM"
315
+ ax0.set_title(f"IMN Intrinsic Explanation | {pred} | {prob_str}")
316
+ fig.colorbar(im, ax=ax0, fraction=0.02, pad=0.01)
317
+
318
+ for lead in range(12):
319
+ ax = fig.add_subplot(gs[lead + 1, 0])
320
+ ax.plot(x_np[lead], linewidth=0.8, color="black", alpha=0.7)
321
+ ax.set_xlim(0, L - 1)
322
+ ax.set_ylabel(LEAD_NAMES[lead], rotation=0, labelpad=15, va="center")
323
+ ax.set_xticks([])
324
+
325
+ axf = fig.add_subplot(gs[13, 0])
326
+ axf.axis("off")
327
+ axf.text(
328
+ 0.5,
329
+ 0.5,
330
+ "Heatmap: |w(x)·x| averaged over segments and max-normalized (higher = larger contribution magnitude; the sign is discarded).",
331
+ fontsize=9,
332
+ ha="center",
333
+ va="center",
334
+ wrap=True,
335
+ transform=axf.transAxes,
336
+ )
337
+
338
+ summary = (
339
+ f"**{pred}** | P({pos_class_name}) = {prob_pos:.3f} | "
340
+ f"Ground truth: **{gt}** | fs={fs}Hz, window={window}, stride={stride}"
341
+ )
342
+ return summary, fig, report, gt, f"{rec['ecg_id']}"
343
+
344
+
345
+ def main():
346
+ demo = gr.Blocks(
347
+ title="MesomorphicECG XAI (IMN categorical + single-linear)",
348
+ theme=gr.themes.Soft(),
349
+ )
350
+ with demo:
351
+ gr.Markdown(
352
+ "# MesomorphicECG XAI\n"
353
+ "Interactive XAI viewer for Interpretable Mesomorphic Networks (IMN) on PTB-XL ECGs.\n\n"
354
+ "- Models and checkpoints from "
355
+ "[SEARCH-IHI/mesomorphicECG](https://huggingface.co/SEARCH-IHI/mesomorphicECG).\n"
356
+ "- Data samples loaded from binary `.npz` files stored in this Space.\n"
357
+ "- Heatmaps show segment-wise IMN contribution per lead."
358
+ )
359
+
360
+ with gr.Row():
361
+ sampling_rate = gr.Radio(
362
+ label="Sampling rate",
363
+ choices=[100, 500],
364
+ value=500,
365
+ )
366
+ model_type = gr.Radio(
367
+ label="Model type",
368
+ choices=["single_linear", "categorical"],
369
+ value="single_linear",
370
+ info="single_linear: single linear head; categorical: 2-class head.",
371
+ )
372
+ task = gr.Radio(
373
+ label="Task (positive class vs NORM)",
374
+ choices=list(TASK_TO_POS.keys()),
375
+ value="norm_vs_mi",
376
+ )
377
+ load_btn = gr.Button("Load records", variant="secondary")
378
+
379
+ load_status = gr.Markdown()
380
+ records_state = gr.State(value=None)
381
+
382
+ with gr.Row():
383
+ record_dd = gr.Dropdown(
384
+ label="Record (index | ecg_id | GT | age sex)",
385
+ choices=[],
386
+ value=None,
387
+ )
388
+
389
+ with gr.Row():
390
+ clinical_notes = gr.Textbox(
391
+ label="Clinical notes (report)",
392
+ value="",
393
+ lines=4,
394
+ max_lines=8,
395
+ interactive=False,
396
+ )
397
+ ground_truth = gr.Textbox(
398
+ label="Ground truth",
399
+ value="—",
400
+ interactive=False,
401
+ )
402
+
403
+ load_btn.click(
404
+ fn=on_load_records,
405
+ inputs=[sampling_rate, task, records_state],
406
+ outputs=[load_status, record_dd, records_state, clinical_notes, ground_truth],
407
+ )
408
+ record_dd.change(
409
+ fn=on_select_record,
410
+ inputs=[record_dd, records_state],
411
+ outputs=[clinical_notes, ground_truth],
412
+ )
413
+
414
+ run_btn = gr.Button("Run IMN explanation", variant="primary")
415
+
416
+ out_summary = gr.Markdown()
417
+ out_plot = gr.Plot()
418
+ out_notes = gr.Textbox(label="Clinical notes", lines=3, interactive=False)
419
+ out_gt = gr.Textbox(label="Ground truth", interactive=False)
420
+ out_meta = gr.Textbox(label="ECG ID", interactive=False)
421
+
422
+ run_btn.click(
423
+ fn=explain_record,
424
+ inputs=[model_type, sampling_rate, task, record_dd, records_state],
425
+ outputs=[out_summary, out_plot, out_notes, out_gt, out_meta],
426
+ )
427
+
428
+ demo.launch()
429
+
430
+
431
+ if __name__ == "__main__":
432
+ main()
433
+
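Reviewer note on `_resolve_ckpt_filename` above: `sorted(...)[-1]` compares file names as strings, so if epoch numbers are not zero-padded, `epoch=9` sorts after `epoch=10`. Whether the uploaded checkpoint names are padded is not visible in this diff. The sketch below assumes the intent is to pick the highest epoch and parses the number instead. `pick_checkpoint` is a hypothetical name, and it is a suggestion, not the code in this commit.

```python
import re
from typing import List


def pick_checkpoint(files: List[str], prefix: str) -> str:
    """Pick a .ckpt under `prefix`, preferring 'best-imn-epoch=' names,
    ordered by the numeric epoch rather than by string comparison."""
    candidates = [f for f in files if f.startswith(prefix) and f.endswith(".ckpt")]
    if not candidates:
        raise FileNotFoundError(f"No checkpoint (.ckpt) found under {prefix}")
    best_style = [f for f in candidates if "best-imn-epoch=" in f]
    pool = best_style or candidates

    def sort_key(name: str):
        # Names without an epoch number sort first; ties fall back to the name.
        match = re.search(r"epoch=(\d+)", name)
        return (int(match.group(1)) if match else -1, name)

    return max(pool, key=sort_key)
```

With two hypothetical names such as `.../best-imn-epoch=9.ckpt` and `.../best-imn-epoch=10.ckpt`, plain string sorting would choose epoch 9, while the numeric key chooses epoch 10. If the training script already zero-pads epochs, both approaches agree and the existing code is fine.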
categorical_imn_core.py ADDED
@@ -0,0 +1,307 @@
1
+ """
2
+ Core categorical IMN model definition for mesomorphicECG.
3
+
4
+ This is a trimmed-down subset of
5
+ `script_02022026_v7_IMN_GM_2_with_transition_net.py` containing only
6
+ the pieces needed for inference:
7
+
8
+ - ECG_IMN (categorical / 2-class hypernetwork)
9
+ - IMNLightning (PyTorch Lightning wrapper)
10
+ - imn_weights_to_segments (segment-wise aggregation helper)
11
+
12
+ These definitions are compatible with checkpoints uploaded to
13
+ `SEARCH-IHI/mesomorphicECG` under:
14
+ categorical_imn_100hz/<task>/
15
+ categorical_imn_500hz/<task>/
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import pytorch_lightning as pl
25
+ from sklearn.metrics import roc_auc_score
26
+
27
+
28
+ DEFAULT_LEAD_NAMES = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
29
+
30
+
31
+ def imn_weights_to_segments(impact_12L: np.ndarray, window: int, stride: int) -> np.ndarray:
32
+ """
33
+ Aggregates point-wise feature attribution (Impact) into segments for cleaner visualization.
34
+
35
+ impact_12L: [12, L]
36
+ Returns: [12, T] heatmap normalized to [0, 1] per-record.
37
+ """
38
+ assert impact_12L.ndim == 2
39
+ L = impact_12L.shape[1]
40
+ T = (L - window) // stride + 1
41
+ seg = np.zeros((12, T), dtype=np.float32)
42
+
43
+ for t in range(T):
44
+ s = t * stride
45
+ e = min(s + window, L)
46
+ seg[:, t] = np.abs(impact_12L[:, s:e]).mean(axis=1)
47
+
48
+ mx = seg.max() + 1e-9
49
+ seg = seg / mx
50
+ return seg
51
+
52
+
53
+ class ECG_IMN(nn.Module):
54
+ """
55
+ Interpretable Mesomorphic Network for ECG with Transition Network (categorical).
56
+
57
+ Generates weights W [B, num_classes, 12, L] and biases b [B, num_classes].
58
+ Final logits for each class k: logits_k = sum(W_k * x) + b_k.
59
+ """
60
+
61
+ def __init__(self, input_channels: int = 12, signal_len: int = 1000, num_classes: int = 2, dropout: float = 0.2):
62
+ super().__init__()
63
+ self.num_classes = num_classes
64
+ self.C = input_channels
65
+ self.L = signal_len
66
+
67
+ # Hypernetwork backbone (encoder), input: [B, 1, 12, L]
68
+ self.conv1 = nn.Sequential(
69
+ nn.Conv2d(1, 16, kernel_size=(3, 15), padding=(1, 7), bias=False),
70
+ nn.BatchNorm2d(16),
71
+ nn.GELU(),
72
+ ) # -> [B, 16, 12, L]
73
+
74
+ self.conv2 = nn.Sequential(
75
+ nn.Conv2d(16, 32, kernel_size=(3, 15), padding=(1, 7), bias=False),
76
+ nn.BatchNorm2d(32),
77
+ nn.GELU(),
78
+ nn.MaxPool2d(kernel_size=(1, 2)),
79
+ ) # -> [B, 32, 12, L/2]
80
+
81
+ self.conv3 = nn.Sequential(
82
+ nn.Conv2d(32, 64, kernel_size=(3, 15), padding=(1, 7), bias=False),
83
+ nn.BatchNorm2d(64),
84
+ nn.GELU(),
85
+ nn.MaxPool2d(kernel_size=(1, 2)),
86
+ ) # -> [B, 64, 12, L/4]
87
+
88
+ self.dropout = nn.Dropout(dropout)
89
+
90
+ # Transition network: upsample to generate W [B, num_classes, 12, L]
91
+ self.transition = nn.Sequential(
92
+ # L/4 -> L/2, 64 -> 32
93
+ nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False),
94
+ nn.BatchNorm2d(32),
95
+ nn.GELU(),
96
+ nn.Upsample(scale_factor=(1, 2), mode="nearest"),
97
+ # L/2 -> L, 32 -> 16
98
+ nn.Conv2d(32, 16, kernel_size=3, padding=1, bias=False),
99
+ nn.BatchNorm2d(16),
100
+ nn.GELU(),
101
+ nn.Upsample(scale_factor=(1, 2), mode="nearest"),
102
+ # Final projection to num_classes channels (weights)
103
+ nn.Conv2d(16, num_classes, kernel_size=3, padding=1, bias=True),
104
+ )
105
+
106
+ # Bias generator: class-wise bias from global pooled features
107
+ self.bias_pool = nn.AdaptiveAvgPool2d((1, 1))
108
+ self.bias_head = nn.Linear(64, num_classes)
109
+
110
+ def forward(self, x: torch.Tensor):
111
+ """
112
+ x: [B, 12, L]
113
+ Returns:
114
+ logits: [B, num_classes]
115
+ generated_w: [B, num_classes, 12, L]
116
+ generated_b: [B, num_classes, 1]
117
+ """
118
+ B, C, L = x.shape
119
+
120
+ feat = x.unsqueeze(1) # [B, 1, 12, L]
121
+ feat = self.conv1(feat)
122
+ feat = self.conv2(feat)
123
+ feat = self.conv3(feat) # [B, 64, 12, L/4]
124
+ feat = self.dropout(feat)
125
+
126
+ # Weights W: [B, num_classes, 12, L]
127
+ generated_w = self.transition(feat)
128
+
129
+ # Bias b: [B, num_classes]
130
+ b_feat = self.bias_pool(feat).view(B, -1)
131
+ generated_b = self.bias_head(b_feat)
132
+
133
+ x_expanded = x.unsqueeze(1) # [B, 1, 12, L]
134
+ weighted_input = generated_w * x_expanded # [B, num_classes, 12, L]
135
+ logits = weighted_input.sum(dim=(2, 3)) + generated_b # [B, num_classes]
136
+
137
+ return logits, generated_w, generated_b.unsqueeze(-1)
138
+
139
+
140
+ class IMNLightning(pl.LightningModule):
141
+ """
142
+ PyTorch Lightning wrapper for ECG_IMN (categorical, 2-class).
143
+
144
+ Matches the training-time definition used in
145
+ `script_02022026_v7_IMN_GM_2_with_transition_net.py`
146
+ so `IMNLightning.load_from_checkpoint(...)` works for inference.
147
+ """
148
+
149
+ def __init__(
150
+ self,
151
+ input_channels: int,
152
+ signal_len: int,
153
+ dropout: float = 0.2,
154
+ lr: float = 1e-3,
155
+ weight_decay: float = 1e-4,
156
+ lambda_l1: float = 1e-4,
157
+ class_weights: list[float] | None = None,
158
+ scheduler_type: str | None = "cosine",
159
+ scheduler_params: dict | None = None,
160
+ ):
161
+ super().__init__()
162
+ self.save_hyperparameters()
163
+
164
+ self.model = ECG_IMN(
165
+ input_channels=input_channels,
166
+ signal_len=signal_len,
167
+ dropout=dropout,
168
+ )
169
+
170
+ self.lr = lr
171
+ self.weight_decay = weight_decay
172
+ self.lambda_l1 = lambda_l1
173
+ self._class_weights = class_weights
174
+
175
+ self.scheduler_type = scheduler_type
176
+ self.scheduler_params = scheduler_params or {}
177
+
178
+ self.val_probs: list[torch.Tensor] = []
179
+ self.val_y: list[torch.Tensor] = []
180
+
181
+ def configure_optimizers(self):
182
+ optimizer = torch.optim.AdamW(
183
+ self.parameters(),
184
+ lr=self.lr,
185
+ weight_decay=self.weight_decay,
186
+ )
187
+
188
+ if self.scheduler_type is None or self.scheduler_type == "none":
189
+ return optimizer
190
+
191
+ if self.scheduler_type == "cosine":
192
+ max_epochs = getattr(self.trainer, "max_epochs", None) or 100
193
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
194
+ optimizer,
195
+ T_max=max_epochs,
196
+ **self.scheduler_params,
197
+ )
198
+ elif self.scheduler_type == "step":
199
+ scheduler = torch.optim.lr_scheduler.StepLR(
200
+ optimizer,
201
+ step_size=self.scheduler_params.get("step_size", 10),
202
+ gamma=self.scheduler_params.get("gamma", 0.1),
203
+ **{k: v for k, v in self.scheduler_params.items() if k not in ["step_size", "gamma"]},
204
+ )
205
+ elif self.scheduler_type == "reduce_on_plateau":
206
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
207
+ optimizer,
208
+ mode="max",
209
+ factor=self.scheduler_params.get("factor", 0.5),
210
+ patience=self.scheduler_params.get("patience", 5),
211
+ **{k: v for k, v in self.scheduler_params.items() if k not in ["factor", "patience"]},
212
+ )
213
+ elif self.scheduler_type == "cosine_restarts":
214
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
215
+ optimizer,
216
+ T_0=self.scheduler_params.get("T_0", 10),
217
+ T_mult=self.scheduler_params.get("T_mult", 2),
218
+ **{k: v for k, v in self.scheduler_params.items() if k not in ["T_0", "T_mult"]},
219
+ )
220
+ else:
221
+ return optimizer
222
+
223
+ if self.scheduler_type == "reduce_on_plateau":
224
+ return {
225
+ "optimizer": optimizer,
226
+ "lr_scheduler": {
227
+ "scheduler": scheduler,
228
+ "monitor": "val_auc",
229
+ },
230
+ }
231
+ else:
232
+ return {
233
+ "optimizer": optimizer,
234
+ "lr_scheduler": scheduler,
235
+ }
236
+
237
+ def _ce_weight(self):
238
+ if self._class_weights is None:
239
+ return None
240
+ return torch.tensor(self._class_weights, dtype=torch.float32, device=self.device)
241
+
242
+ def training_step(self, batch, batch_idx):
243
+ x, y = batch
244
+ logits, gen_w, gen_b = self.model(x)
245
+
246
+ ce_loss = F.cross_entropy(logits, y, weight=self._ce_weight())
247
+ l1_loss = gen_w.abs().mean()
248
+ total_loss = ce_loss + (self.lambda_l1 * l1_loss)
249
+
250
+ pred = logits.argmax(dim=1)
251
+ acc = (pred == y).float().mean()
252
+
253
+ self.log("train_loss", total_loss, on_step=False, on_epoch=True, prog_bar=True)
254
+ self.log("train_ce", ce_loss, on_step=False, on_epoch=True)
255
+ self.log("train_l1", l1_loss, on_step=False, on_epoch=True)
256
+ self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
257
+ return total_loss
258
+
259
+ def validation_step(self, batch, batch_idx):
260
+ x, y = batch
261
+ logits, gen_w, _ = self.model(x)
262
+
263
+ ce_loss = F.cross_entropy(logits, y, weight=self._ce_weight())
264
+ prob = torch.softmax(logits, dim=1)[:, 1]
265
+ pred = logits.argmax(dim=1)
266
+ acc = (pred == y).float().mean()
267
+
268
+ self.log("val_loss", ce_loss, on_step=False, on_epoch=True, prog_bar=True)
269
+ self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
270
+
271
+ self.val_probs.append(prob.detach().cpu())
272
+ self.val_y.append(y.detach().cpu())
273
+
274
+ def on_validation_epoch_end(self):
275
+ if not self.val_y:
276
+ return
277
+ y_true = torch.cat(self.val_y)
278
+ y_score = torch.cat(self.val_probs)
279
+ auc = simple_auc_roc(y_true.float(), y_score.float())
280
+ self.log("val_auc", auc, on_step=False, on_epoch=True, prog_bar=True)
281
+ self.val_probs.clear()
282
+ self.val_y.clear()
283
+
284
+ def test_step(self, batch, batch_idx):
285
+ x, y = batch
286
+ logits, _, _ = self.model(x)
287
+ ce_loss = F.cross_entropy(logits, y, weight=self._ce_weight())
288
+ prob = torch.softmax(logits, dim=1)[:, 1]
289
+ pred = logits.argmax(dim=1)
290
+ acc = (pred == y).float().mean()
291
+
292
+ self.log("test_loss", ce_loss, on_step=False, on_epoch=True)
293
+ self.log("test_acc", acc, on_step=False, on_epoch=True)
294
+ return {"y": y.detach().cpu(), "p": prob.detach().cpu()}
295
+
296
+
297
+ @torch.no_grad()
298
+ def simple_auc_roc(y_true: torch.Tensor, y_score: torch.Tensor) -> float:
299
+ """
300
+ Simple AUROC helper, matching the training script.
301
+ """
302
+ y_true = y_true.detach().cpu().float()
303
+ y_score = y_score.detach().cpu().float()
304
+ if y_true.min() == y_true.max():
305
+ return float("nan")
306
+ return float(roc_auc_score(y_true.numpy(), y_score.numpy()))
307
+
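Reviewer note on `imn_weights_to_segments` above: the per-segment loop can also be written with NumPy's `sliding_window_view` (available since NumPy 1.20). This is an alternative sketch, not the code in this commit. `segments_vectorized` is a hypothetical name, and unlike the original it raises when the signal is shorter than the window instead of failing inside `np.zeros`.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def segments_vectorized(impact: np.ndarray, window: int, stride: int) -> np.ndarray:
    """Mean |impact| per sliding segment, max-normalized to [0, 1] per record.

    impact: [C, L] signed contributions (e.g. w * x).
    Returns: [C, T] with T = (L - window) // stride + 1.
    """
    if impact.ndim != 2 or impact.shape[1] < window:
        raise ValueError("impact must be [C, L] with L >= window")
    # All windows of length `window` along time: [C, L - window + 1, window]
    windows = sliding_window_view(np.abs(impact), window, axis=1)
    # Keep every `stride`-th window start, then average inside each window.
    seg = windows[:, ::stride].mean(axis=-1).astype(np.float32)
    return seg / (seg.max() + 1e-9)
```

Assuming 10-second records (1000 samples at 100 Hz, 5000 at 500 Hz) and the window/stride heuristic in `app.py` (50/25 and 250/125), both sampling rates give T = 39 segments per lead, so the heatmaps have the same width.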
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ gradio
2
+ torch
3
+ numpy
4
+ matplotlib
5
+ pytorch-lightning
6
+ scikit-learn
7
+ huggingface-hub
8
+
single_linear_imn_core.py ADDED
@@ -0,0 +1,307 @@
1
+ """
2
+ Core single-linear IMN model definition for mesomorphicECG.
3
+
4
+ This file is a lightweight subset of
5
+ `script_02022026_v7_IMN_GM_2_with_transition_net_with_one_linear_eq.py`
6
+ containing only the pieces needed for inference:
7
+
8
+ - ECG_IMN (single-linear hypernetwork)
9
+ - IMNLightning (PyTorch Lightning wrapper)
10
+ - imn_weights_to_segments (segment-wise aggregation helper)
11
+
12
+ These definitions are compatible with checkpoints uploaded to
13
+ `SEARCH-IHI/mesomorphicECG` under:
14
+ single_linear_imn_100hz/<task>/
15
+ single_linear_imn_500hz/<task>/
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import pytorch_lightning as pl
25
+ from sklearn.metrics import roc_auc_score
26
+
27
+
28
+ DEFAULT_LEAD_NAMES = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
29
+
30
+
31
+ def imn_weights_to_segments(impact_12L: np.ndarray, window: int, stride: int) -> np.ndarray:
32
+ """
33
+ Aggregates point-wise feature attribution (Impact) into segments.
34
+
35
+ impact_12L: [12, L] array of signed contributions (e.g. w * x).
36
+ Returns: [12, T] heatmap normalized to [0, 1] per-record.
37
+ """
38
+ assert impact_12L.ndim == 2
39
+ L = impact_12L.shape[1]
40
+ T = (L - window) // stride + 1
41
+ seg = np.zeros((12, T), dtype=np.float32)
42
+
43
+ for t in range(T):
44
+ s = t * stride
45
+ e = min(s + window, L)
46
+ seg[:, t] = np.abs(impact_12L[:, s:e]).mean(axis=1)
47
+
48
+ mx = seg.max() + 1e-9
49
+ seg = seg / mx
50
+ return seg
51
+
52
+
+ class ECG_IMN(nn.Module):
+     """
+     Interpretable Mesomorphic Network for ECG (single-linear output).
+
+     Generates ONE set of weights W [B, 1, 12, L] and ONE bias b [B, 1].
+     Prediction: logit = sum(W * x) + b, P(pos) = sigmoid(logit).
+     """
+
+     def __init__(self, input_channels: int = 12, signal_len: int = 1000, dropout: float = 0.2):
+         super().__init__()
+         self.C = input_channels
+         self.L = signal_len
+
+         output_dim = 1  # single linear output
+
+         # Hypernetwork backbone (encoder), input: [B, 1, 12, L]
+         self.conv1 = nn.Sequential(
+             nn.Conv2d(1, 16, kernel_size=(3, 15), padding=(1, 7), bias=False),
+             nn.BatchNorm2d(16),
+             nn.GELU(),
+         )  # -> [B, 16, 12, L]
+
+         self.conv2 = nn.Sequential(
+             nn.Conv2d(16, 32, kernel_size=(3, 15), padding=(1, 7), bias=False),
+             nn.BatchNorm2d(32),
+             nn.GELU(),
+             nn.MaxPool2d(kernel_size=(1, 2)),
+         )  # -> [B, 32, 12, L/2]
+
+         self.conv3 = nn.Sequential(
+             nn.Conv2d(32, 64, kernel_size=(3, 15), padding=(1, 7), bias=False),
+             nn.BatchNorm2d(64),
+             nn.GELU(),
+             nn.MaxPool2d(kernel_size=(1, 2)),
+         )  # -> [B, 64, 12, L/4]
+
+         self.dropout = nn.Dropout(dropout)
+
+         # Transition network: upsample to generate W [B, 1, 12, L]
+         self.transition = nn.Sequential(
+             # L/4 -> L/2
+             nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(32),
+             nn.GELU(),
+             nn.Upsample(scale_factor=(1, 2), mode="nearest"),
+             # L/2 -> L
+             nn.Conv2d(32, 16, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(16),
+             nn.GELU(),
+             nn.Upsample(scale_factor=(1, 2), mode="nearest"),
+             # Final projection to 1 channel (weights)
+             nn.Conv2d(16, output_dim, kernel_size=3, padding=1, bias=True),
+         )
+
+         # Bias generator: scalar bias from global pooled features
+         self.bias_pool = nn.AdaptiveAvgPool2d((1, 1))
+         self.bias_head = nn.Linear(64, output_dim)
+
+     def forward(self, x: torch.Tensor):
+         """
+         x: [B, 12, L]
+         Returns:
+             logits: [B, 1]
+             generated_w: [B, 1, 12, L]
+             generated_b: [B, 1, 1]
+         """
+         B, C, L = x.shape
+
+         feat = x.unsqueeze(1)  # [B, 1, 12, L]
+         feat = self.conv1(feat)
+         feat = self.conv2(feat)
+         feat = self.conv3(feat)
+         feat = self.dropout(feat)
+
+         # Weights W: [B, 1, 12, L]
+         generated_w = self.transition(feat)
+
+         # Bias b: [B, 1]
+         b_feat = self.bias_pool(feat).view(B, -1)
+         generated_b = self.bias_head(b_feat)
+
+         # Single-linear logit
+         x_expanded = x.unsqueeze(1)  # [B, 1, 12, L]
+         weighted_input = generated_w * x_expanded
+         logits = weighted_input.sum(dim=(2, 3)) + generated_b  # [B, 1]
+
+         return logits, generated_w, generated_b.unsqueeze(-1)
+
+
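The single-linear prediction rule in `ECG_IMN.forward` (`logit = sum(W * x) + b`) means the logit splits exactly into per-sample contributions. The sketch below illustrates that with plain NumPy. It is an illustration under assumptions, not the model itself: `w` and `b` are random stand-ins for the weights and bias the hypernetwork would generate, and `roundtrip_length` is a hypothetical helper that only does the length bookkeeping implied by two `MaxPool2d(kernel_size=(1, 2))` layers followed by two `Upsample(scale_factor=(1, 2))` layers.

```python
import numpy as np

rng = np.random.default_rng(0)
n_leads, length = 12, 1000

x = rng.normal(size=(n_leads, length))              # one ECG record, [12, L]
w = rng.normal(scale=0.01, size=(n_leads, length))  # stand-in for generated W
b = 0.1                                             # stand-in for generated b

impact = w * x                       # point-wise contribution, same shape as x
logit = impact.sum() + b             # single-linear logit
prob = 1.0 / (1.0 + np.exp(-logit))  # P(pos) = sigmoid(logit)

# The logit decomposes exactly: per-lead sums plus the bias recover it.
per_lead = impact.sum(axis=1)
assert np.isclose(per_lead.sum() + b, logit)


def roundtrip_length(signal_len: int) -> int:
    """Length of W after two floor-halvings and two 2x upsamplings."""
    return ((signal_len // 2) // 2) * 4


print(roundtrip_length(1000), roundtrip_length(5000), roundtrip_length(1001))
```

Under this bookkeeping, lengths of 1000 and 5000 samples (10 s at 100 Hz and 500 Hz) come back unchanged, while 1001 comes back as 1000. That suggests the generated weights only line up with the input for the element-wise product when the signal length is a multiple of 4.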
+ class IMNLightning(pl.LightningModule):
+     """
+     PyTorch Lightning wrapper for ECG_IMN (single-linear).
+
+     This class matches the training-time definition used for checkpoints
+     in `script_02022026_v7_IMN_GM_2_with_transition_net_with_one_linear_eq.py`,
+     so that `IMNLightning.load_from_checkpoint(...)` works for inference.
+     """
+
+     def __init__(
+         self,
+         input_channels: int,
+         signal_len: int,
+         dropout: float = 0.2,
+         lr: float = 1e-3,
+         weight_decay: float = 1e-4,
+         lambda_l1: float = 1e-4,
+         pos_weight: float | None = None,
+         scheduler_type: str | None = "cosine",
+         scheduler_params: dict | None = None,
+     ):
+         super().__init__()
+         self.save_hyperparameters()
+
+         self.model = ECG_IMN(
+             input_channels=input_channels,
+             signal_len=signal_len,
+             dropout=dropout,
+         )
+
+         self.lr = lr
+         self.weight_decay = weight_decay
+         self.lambda_l1 = lambda_l1
+         self.pos_weight_val = pos_weight
+
+         self.scheduler_type = scheduler_type
+         self.scheduler_params = scheduler_params or {}
+
+         self.val_probs: list[torch.Tensor] = []
+         self.val_y: list[torch.Tensor] = []
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(
+             self.parameters(),
+             lr=self.lr,
+             weight_decay=self.weight_decay,
+         )
+
+         if self.scheduler_type is None or self.scheduler_type == "none":
+             return optimizer
+
+         if self.scheduler_type == "cosine":
+             max_epochs = getattr(self.trainer, "max_epochs", None) or 100
+             scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                 optimizer,
+                 T_max=max_epochs,
+                 **self.scheduler_params,
+             )
+         elif self.scheduler_type == "step":
+             scheduler = torch.optim.lr_scheduler.StepLR(
+                 optimizer,
+                 step_size=self.scheduler_params.get("step_size", 10),
+                 gamma=self.scheduler_params.get("gamma", 0.1),
+                 **{k: v for k, v in self.scheduler_params.items() if k not in ["step_size", "gamma"]},
+             )
+         elif self.scheduler_type == "reduce_on_plateau":
+             scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                 optimizer,
+                 mode="max",
+                 factor=self.scheduler_params.get("factor", 0.5),
+                 patience=self.scheduler_params.get("patience", 5),
+                 **{k: v for k, v in self.scheduler_params.items() if k not in ["factor", "patience"]},
+             )
+         else:
+             return optimizer
+
+         if self.scheduler_type == "reduce_on_plateau":
+             return {
+                 "optimizer": optimizer,
+                 "lr_scheduler": {
+                     "scheduler": scheduler,
+                     "monitor": "val_auc",
+                 },
+             }
+         else:
+             return {
+                 "optimizer": optimizer,
+                 "lr_scheduler": scheduler,
+             }
+
+     def _get_pos_weight(self):
+         if self.pos_weight_val is None:
+             return None
+         return torch.tensor([self.pos_weight_val], device=self.device)
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch  # y: [B] float
+         logits, gen_w, gen_b = self.model(x)
+         logits = logits.squeeze(1)  # [B]
+
+         bce_loss = F.binary_cross_entropy_with_logits(logits, y, pos_weight=self._get_pos_weight())
+         l1_loss = gen_w.abs().mean()
+         total_loss = bce_loss + (self.lambda_l1 * l1_loss)
+
+         probs = torch.sigmoid(logits)
+         preds = (probs > 0.5).float()
+         acc = (preds == y).float().mean()
+
+         self.log("train_loss", total_loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log("train_bce", bce_loss, on_step=False, on_epoch=True)
+         self.log("train_l1", l1_loss, on_step=False, on_epoch=True)
+         self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
+         return total_loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         logits, gen_w, _ = self.model(x)
+         logits = logits.squeeze(1)
+
+         bce_loss = F.binary_cross_entropy_with_logits(logits, y, pos_weight=self._get_pos_weight())
+         prob = torch.sigmoid(logits)
+         pred = (prob > 0.5).float()
+         acc = (pred == y).float().mean()
+
+         self.log("val_loss", bce_loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
+
+         self.val_probs.append(prob.detach().cpu())
+         self.val_y.append(y.detach().cpu())
+
+     def on_validation_epoch_end(self):
+         if not self.val_y:
+             return
+         y_true = torch.cat(self.val_y)
+         y_score = torch.cat(self.val_probs)
+         auc = simple_auc_roc(y_true, y_score)
+         self.log("val_auc", auc, on_step=False, on_epoch=True, prog_bar=True)
+         self.val_probs.clear()
+         self.val_y.clear()
+
+     def test_step(self, batch, batch_idx):
+         x, y = batch
+         logits, _, _ = self.model(x)
+         logits = logits.squeeze(1)
+
+         bce_loss = F.binary_cross_entropy_with_logits(logits, y, pos_weight=self._get_pos_weight())
+         prob = torch.sigmoid(logits)
+         pred = (prob > 0.5).float()
+         acc = (pred == y).float().mean()
+
+         self.log("test_loss", bce_loss, on_step=False, on_epoch=True)
+         self.log("test_acc", acc, on_step=False, on_epoch=True)
+         return {"y": y.detach().cpu(), "p": prob.detach().cpu()}
+
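The objective in `training_step` is a binary cross-entropy on the logits plus an L1 penalty on the generated weights, scaled by `lambda_l1`. The sketch below restates that arithmetic in NumPy so the two terms can be inspected separately. It is a rough re-derivation under assumptions rather than the training code: `bce_with_logits` and `imn_objective` are hypothetical names, and the `pos_weight` handling follows the documented definition of PyTorch's `binary_cross_entropy_with_logits` (the positive term is multiplied by `pos_weight`).

```python
import numpy as np


def bce_with_logits(logits, targets, pos_weight=1.0):
    """Mean binary cross-entropy on raw logits, numerically stable."""
    log_p = -np.logaddexp(0.0, -logits)     # log(sigmoid(z))
    log_not_p = -np.logaddexp(0.0, logits)  # log(1 - sigmoid(z))
    per_sample = -(pos_weight * targets * log_p + (1.0 - targets) * log_not_p)
    return float(per_sample.mean())


def imn_objective(logits, targets, weights, lambda_l1=1e-4, pos_weight=1.0):
    """Total loss = BCE + lambda_l1 * mean(|W|), returned with both parts."""
    bce = bce_with_logits(logits, targets, pos_weight)
    l1 = float(np.abs(weights).mean())
    return bce + lambda_l1 * l1, bce, l1


logits = np.array([0.0, 2.0, -1.0])       # synthetic logits for 3 records
targets = np.array([1.0, 1.0, 0.0])       # binary labels as floats
weights = np.full((3, 1, 12, 1000), 0.5)  # stand-in for generated W
total, bce, l1 = imn_objective(logits, targets, weights)
print(total, bce, l1)
```

Since the penalty is the mean absolute weight rather than the sum, its scale does not grow with the signal length, so the same `lambda_l1` is at least comparable between the 100 Hz and 500 Hz variants.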
+
+ @torch.no_grad()
+ def simple_auc_roc(y_true: torch.Tensor, y_score: torch.Tensor) -> float:
+     """
+     Simple AUROC helper, matching the training script.
+     """
+     y_true = y_true.detach().cpu().float()
+     y_score = y_score.detach().cpu().float()
+     if y_true.min() == y_true.max():
+         return float("nan")
+     return float(roc_auc_score(y_true.numpy(), y_score.numpy()))
+
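`simple_auc_roc` returns NaN when only one class is present, because AUROC is undefined without both positives and negatives. The sketch below shows the same behaviour with a NumPy-only pairwise definition of AUROC. Note that `pairwise_auroc` is a hypothetical stand-in written for illustration. It does not call `roc_auc_score`, and its quadratic pairwise comparison is only sensible for small arrays.

```python
import numpy as np


def pairwise_auroc(y_true, y_score) -> float:
    """AUROC as the fraction of (positive, negative) pairs ranked correctly."""
    y_true = np.asarray(y_true, dtype=float)
    y_score = np.asarray(y_score, dtype=float)
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    if pos.size == 0 or neg.size == 0:
        return float("nan")  # undefined with a single class
    diff = pos[:, None] - neg[None, :]  # every positive against every negative
    # A correctly ordered pair counts 1, a tie counts 0.5, otherwise 0.
    return float(((diff > 0) + 0.5 * (diff == 0)).mean())


auc = pairwise_auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
single = pairwise_auroc([1, 1, 1], [0.2, 0.5, 0.9])
print(auc, single)
```

In the first call three of the four positive/negative pairs are ordered correctly, giving 0.75. The second call contains only positives and yields NaN, which mirrors the `y_true.min() == y_true.max()` guard above.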