DiligentPenguinn committed on
Commit 370a2e0 · verified · 1 Parent(s): 56040ff

Add model details and instructions

Files changed (3):
  1. README.md +129 -3
  2. inference_loader.py +349 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -1,3 +1,129 @@
- ---
- license: apache-2.0
- ---

# ECG Image Classifier (MoE and MLP) on MedSigLIP Embeddings

This repository provides two PyTorch ECG classifier checkpoints trained on top of frozen MedSigLIP image embeddings:

- `moe_classifier_medsiglip.pt`: Mixture-of-Experts (MoE) classifier
- `mlp_classifier_medsiglip.pt`: Dense feedforward (MLP) classifier

These checkpoints expect embeddings produced by:

- `google/medsiglip-448`

The repository contains only the classifier heads. MedSigLIP weights are not included and must be obtained separately under Google’s license.

---

## Motivation

This work was developed as part of the Google MedGemma Impact Challenge:
https://www.kaggle.com/competitions/med-gemma-impact-challenge/overview

The goal is to build a lightweight, deployable ECG image classifier for chronic care screening, especially in low-resource clinical settings where ECG is often the most accessible diagnostic modality.

---

## Task and Data

We formulate a supervised multi-label image classification task on 12-lead ECGs with five diagnostic categories:

- NORM (normal)
- MI (myocardial infarction)
- STTC (ST-T changes)
- CD (conduction disturbances)
- HYP (hypertrophy)

Training data combines:

- PTB-XL, a large-scale dataset of raw 12-lead ECG waveforms in WFDB format with 16-bit precision
- A supplementary ECG image dataset

To enable image-based classification, raw PTB-XL waveforms are converted into realistic print-style ECG images using the open-source ECG image generator by Rahimi et al. This yields approximately 21,000 synthetic ECG images, which are combined with 713 real ECG images from the supplementary dataset.
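Because an ECG can carry several of these diagnoses at once, each image maps to a multi-hot target vector rather than a single class. A minimal sketch (the class ordering and helper function are illustrative, not taken from the training code):

```python
import numpy as np

# Assumed class ordering; the actual training pipeline may order labels differently.
CLASSES = ["NORM", "MI", "STTC", "CD", "HYP"]

def to_multi_hot(labels: list[str]) -> np.ndarray:
    """Encode a set of diagnostic labels as a 5-dim multi-hot vector."""
    vec = np.zeros(len(CLASSES), dtype=np.float32)
    for label in labels:
        vec[CLASSES.index(label)] = 1.0
    return vec

# An ECG annotated with both MI and HYP:
print(to_multi_hot(["MI", "HYP"]))  # [0. 1. 0. 0. 1.]
```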

---

## Model and Training

ECG images are first encoded using MedSigLIP to obtain fixed-dimensional visual embeddings. Two lightweight classifiers are trained on top of these embeddings:

- A dense feedforward network (MLP)
- A Mixture-of-Experts (MoE) classifier

The dataset is split 60/20/20 into training, validation, and test sets. Both models are optimized with Adam using a learning rate of 1e-4 and weight decay of 1e-5. The MoE model additionally uses a load-balancing regularization term with lambda set to 0.1.
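The exact form of the load-balancing term is not spelled out here; a common formulation penalizes deviation of the batch-averaged gate weights from uniform expert usage. A sketch under that assumption (the loss actually used in training may differ):

```python
import torch

def load_balancing_loss(gate_w: torch.Tensor) -> torch.Tensor:
    """Encourage uniform expert usage.

    gate_w: (batch, num_experts) softmax gate weights.
    Penalizes squared deviation of the mean per-expert load
    from the uniform target 1/num_experts.
    """
    num_experts = gate_w.shape[-1]
    mean_load = gate_w.mean(dim=0)                      # average weight per expert
    uniform = torch.full_like(mean_load, 1.0 / num_experts)
    return ((mean_load - uniform) ** 2).sum()

# total_loss = bce_loss + 0.1 * load_balancing_loss(gate_w)  # lambda = 0.1
```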

For multi-label prediction, a uniform decision threshold of 0.3 is applied across all classes.
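Concretely, every class whose sigmoid score clears 0.3 is emitted as a prediction; a minimal sketch (class names assumed as listed above):

```python
import torch

CLASSES = ["NORM", "MI", "STTC", "CD", "HYP"]
THRESHOLD = 0.3

def labels_from_logits(logits: torch.Tensor) -> list[str]:
    """Return every class whose sigmoid score meets the threshold."""
    scores = torch.sigmoid(logits)
    return [c for c, s in zip(CLASSES, scores.tolist()) if s >= THRESHOLD]

# Illustrative logits for one image:
print(labels_from_logits(torch.tensor([-2.0, 1.5, -0.5, -3.0, 0.1])))
```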

---

## Results

On the held-out test set, the MoE classifier consistently outperforms the MLP baseline across all metrics. It achieves:

- Lower Hamming loss: 0.167 vs 0.235
- Higher ROC-AUC:
  - Micro: 0.891 vs 0.827
  - Macro: 0.879 vs 0.808
- Higher F1 scores:
  - Micro: 0.70 vs 0.61
  - Macro: 0.67 vs 0.58

Per-class F1 improves across all five diagnostic categories, with the largest gains observed for myocardial infarction and hypertrophy. Confusion matrix analysis indicates that the MLP baseline tends to trade precision for recall, producing more false positives and a lower overall F1. For this reason, the MoE classifier is used in the final application.
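For readers reproducing the evaluation, Hamming loss and micro-averaged F1 over multi-hot labels can be computed directly; a numpy sketch on toy data (not the project's evaluation code):

```python
import numpy as np

def hamming_loss(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Fraction of label slots where prediction and truth disagree."""
    return float((y_true != y_pred).mean())

def micro_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """F1 with TP/FP/FN pooled over all classes before averaging."""
    tp = float(((y_true == 1) & (y_pred == 1)).sum())
    fp = float(((y_true == 0) & (y_pred == 1)).sum())
    fn = float(((y_true == 1) & (y_pred == 0)).sum())
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

y_true = np.array([[1, 0, 0], [0, 1, 1]])
y_pred = np.array([[1, 1, 0], [0, 1, 0]])
print(hamming_loss(y_true, y_pred))  # 2 disagreements / 6 slots
print(micro_f1(y_true, y_pred))      # 2*2 / (2*2 + 1 + 1)
```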

---

## Practical Implications

Compared to using MedGemma alone, the MedSigLIP plus classifier pipeline provides more structured and reliable ECG predictions. In addition to discrete labels, the classifier outputs per-class confidence scores. This supports threshold-based screening and triage, which is particularly useful in chronic care workflows and remote clinics where rapid ECG assessment can help prioritize referrals.

---

## How to Use

### 1) Install dependencies

```bash
pip install -r requirements.txt
```

### 2) Run inference

Single image with the MoE checkpoint:

```bash
python inference_loader.py \
  --ckpt ./moe_classifier_medsiglip.pt \
  --image ./sample_ecg.png \
  --out ./preds_moe.json
```

Batch inference on a folder with the MLP checkpoint:

```bash
python inference_loader.py \
  --ckpt ./mlp_classifier_medsiglip.pt \
  --folder ./images \
  --out ./preds_mlp.json
```

### 3) Optional arguments

- `--model_id` (default: `google/medsiglip-448`)
- `--device auto|cpu|cuda`
- `--batch_size 16`
- `--threshold 0.3` (overrides the checkpoint threshold)
- `--hf_token <token>` (or set `HF_TOKEN` as an environment variable)

### 4) Outputs

The inference script returns:

- `scores_by_class`: confidence scores for each diagnostic class
- `predicted_labels`: labels above the decision threshold
- `summary`: run metadata including checkpoint, model type, device, and embedding dimensions
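A downstream consumer can read the saved JSON and, for example, pick the highest-scoring class per image. A sketch assuming the schema described above (the field values below are illustrative, as if parsed from `preds_moe.json` with `json.load`):

```python
import json

# Mirrors the script's output shape: a run summary plus per-image results.
payload = json.loads(json.dumps({
    "summary": {"classifier_type": "moe", "threshold": 0.3},
    "results": [
        {
            "image_path": "./sample_ecg.png",
            "scores_by_class": {"NORM": 0.10, "MI": 0.80, "STTC": 0.40, "CD": 0.05, "HYP": 0.20},
            "predicted_labels": ["MI", "STTC"],
        }
    ],
}))

for result in payload["results"]:
    scores = result["scores_by_class"]
    top_class = max(scores, key=scores.get)  # class with the highest confidence
    print(result["image_path"], "->", top_class, result["predicted_labels"])
```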

---

## References

[1] PTB-XL dataset
[2] Supplementary ECG image dataset used in this project
[3] Rahimi et al., open-source ECG image generator
inference_loader.py ADDED
@@ -0,0 +1,349 @@
"""
Standalone inference loader for ECG classifier checkpoints.

Supports:
- MoE checkpoints (experts.* + gate.* keys)
- MLP checkpoints (fc1/fc2/out keys)

Usage examples:
  python inference_loader.py --ckpt ./moe_classifier_medsiglip.pt --image ./ecg.png
  python inference_loader.py --ckpt ./mlp_classifier_medsiglip.pt --folder ./images --out ./preds.json
"""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoImageProcessor, AutoModel


DEFAULT_MODEL_ID = "google/medsiglip-448"
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}


class ExpertMLP(nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden: tuple[int, ...] = (1028, 512, 256),
        dropout: tuple[float, ...] = (0.15, 0.15, 0.10),
    ):
        super().__init__()
        layers: list[nn.Module] = []
        prev = in_dim

        dropout_values = tuple(dropout)
        if len(dropout_values) < len(hidden):
            dropout_values = dropout_values + (0.0,) * (len(hidden) - len(dropout_values))
        elif len(dropout_values) > len(hidden):
            dropout_values = dropout_values[: len(hidden)]

        for h, p in zip(hidden, dropout_values):
            layers.append(nn.Linear(prev, h))
            layers.append(nn.LayerNorm(h))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(p))
            prev = h

        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class MoEClassifier(nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_experts: int = 5,
        gate_hidden: int = 512,
        temperature: float = 1.0,
        expert_hidden: tuple[int, ...] = (1028, 512, 256),
        expert_dropout: tuple[float, ...] = (0.15, 0.15, 0.10),
    ):
        super().__init__()
        self.temperature = temperature
        self.experts = nn.ModuleList(
            [
                ExpertMLP(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    hidden=expert_hidden,
                    dropout=expert_dropout,
                )
                for _ in range(num_experts)
            ]
        )
        self.gate = nn.Sequential(
            nn.Linear(in_dim, gate_hidden),
            nn.ReLU(),
            nn.Linear(gate_hidden, num_experts),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        gate_logits = self.gate(x) / self.temperature
        gate_w = torch.softmax(gate_logits, dim=-1)
        expert_logits = torch.stack([expert(x) for expert in self.experts], dim=1)
        mixed_logits = torch.sum(expert_logits * gate_w.unsqueeze(-1), dim=1)
        return mixed_logits, gate_w, expert_logits


class MLPClassifier(nn.Module):
    def __init__(self, in_dim: int, hidden_1: int, hidden_2: int, out_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_1)
        self.fc2 = nn.Linear(hidden_1, hidden_2)
        self.out = nn.Linear(hidden_2, out_dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.out(x)


def get_device(device_arg: str) -> str:
    if device_arg == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    if device_arg == "cuda" and not torch.cuda.is_available():
        return "cpu"
    return device_arg


def collect_images(image: str | None, images: list[str] | None, folder: str | None) -> list[str]:
    paths: list[str] = []
    if image:
        paths.append(image)
    if images:
        paths.extend(images)
    if folder:
        for name in sorted(os.listdir(folder)):
            p = os.path.join(folder, name)
            ext = os.path.splitext(p)[1].lower()
            if os.path.isfile(p) and ext in IMAGE_EXTS:
                paths.append(p)

    # Deduplicate while preserving order, using resolved absolute paths.
    out: list[str] = []
    seen: set[str] = set()
    for p in paths:
        ap = str(Path(p).resolve())
        if ap not in seen:
            seen.add(ap)
            out.append(ap)
    return out


def extract_features(output: Any) -> torch.Tensor:
    if isinstance(output, torch.Tensor):
        return output
    if hasattr(output, "pooler_output") and output.pooler_output is not None:
        return output.pooler_output
    if hasattr(output, "last_hidden_state") and output.last_hidden_state is not None:
        return output.last_hidden_state[:, 0, :]
    raise TypeError(f"Unexpected image feature output type: {type(output)}")


def build_classifier(ckpt: dict[str, Any]) -> tuple[nn.Module, str]:
    state_dict = ckpt.get("state_dict")
    if not isinstance(state_dict, dict) or not state_dict:
        raise RuntimeError("Checkpoint missing state_dict.")

    embed_dim = int(ckpt["embed_dim"])
    num_classes = int(ckpt["num_classes"])

    if any(key.startswith("experts.") for key in state_dict.keys()):
        num_experts = int(ckpt.get("num_experts", 5))
        # Infer the expert hidden sizes from the first expert's linear weights.
        expert_linear_layers: list[tuple[int, torch.Tensor]] = []
        for key, value in state_dict.items():
            if not (
                key.startswith("experts.0.net.")
                and key.endswith(".weight")
                and isinstance(value, torch.Tensor)
                and value.ndim == 2
            ):
                continue
            layer_index = int(key.split(".")[3])
            expert_linear_layers.append((layer_index, value))

        if len(expert_linear_layers) < 2:
            raise RuntimeError("Unable to infer expert architecture from checkpoint.")

        expert_linear_layers.sort(key=lambda item: item[0])
        expert_hidden = tuple(int(weight.shape[0]) for _, weight in expert_linear_layers[:-1])
        gate_hidden = int(state_dict["gate.0.weight"].shape[0]) if "gate.0.weight" in state_dict else 256

        model = MoEClassifier(
            in_dim=embed_dim,
            out_dim=num_classes,
            num_experts=num_experts,
            gate_hidden=gate_hidden,
            temperature=1.0,
            expert_hidden=expert_hidden,
            expert_dropout=tuple(0.0 for _ in expert_hidden),
        )
        model_type = "moe"
    elif {"fc1.weight", "fc2.weight", "out.weight"}.issubset(state_dict.keys()):
        hidden_1 = int(state_dict["fc1.weight"].shape[0])
        hidden_2 = int(state_dict["fc2.weight"].shape[0])
        model = MLPClassifier(
            in_dim=embed_dim,
            hidden_1=hidden_1,
            hidden_2=hidden_2,
            out_dim=num_classes,
        )
        model_type = "mlp"
    else:
        raise RuntimeError("Unsupported checkpoint format.")

    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model, model_type


@torch.no_grad()
def embed_images(
    embedder: AutoModel,
    processor: AutoImageProcessor,
    image_paths: list[str],
    batch_size: int,
) -> tuple[np.ndarray, list[str]]:
    embs: list[np.ndarray] = []
    kept: list[str] = []
    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i : i + batch_size]
        batch_images = []
        batch_kept: list[str] = []
        for p in batch_paths:
            try:
                batch_images.append(Image.open(p).convert("RGB"))
                batch_kept.append(p)
            except Exception as exc:  # pragma: no cover
                print(f"[skip] {p}: {exc}")

        if not batch_images:
            continue

        inputs = processor(images=batch_images, return_tensors="pt").to(embedder.device)
        out = embedder.get_image_features(**inputs)
        feats = extract_features(out)

        embs.append(feats.detach().cpu().numpy().astype(np.float32))
        kept.extend(batch_kept)

    if not embs:
        return np.zeros((0, 0), dtype=np.float32), []
    return np.concatenate(embs, axis=0), kept


@torch.no_grad()
def predict(
    classifier: nn.Module,
    X_emb: np.ndarray,
    device: str,
) -> np.ndarray:
    xb = torch.from_numpy(X_emb).to(device)
    logits_output = classifier(xb)
    # MoE forward returns (mixed_logits, gate_w, expert_logits); MLP returns logits only.
    logits = logits_output[0] if isinstance(logits_output, tuple) else logits_output
    probs = torch.sigmoid(logits).detach().cpu().numpy()
    return probs


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True, help="Path to checkpoint .pt file")
    parser.add_argument("--model_id", type=str, default=DEFAULT_MODEL_ID, help="HF model id for embedder")
    parser.add_argument("--hf_token", type=str, default=None, help="HF token (or set HF_TOKEN env var)")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "cuda"])
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--threshold", type=float, default=None, help="Override checkpoint threshold")

    parser.add_argument("--image", type=str, default=None, help="Single image path")
    parser.add_argument("--images", nargs="*", default=None, help="Multiple image paths")
    parser.add_argument("--folder", type=str, default=None, help="Folder with images")
    parser.add_argument("--out", type=str, default=None, help="Optional output JSON path")
    args = parser.parse_args()

    image_paths = collect_images(args.image, args.images, args.folder)
    if not image_paths:
        raise SystemExit("No images found. Use --image, --images, or --folder.")

    device = get_device(args.device)
    ckpt = torch.load(args.ckpt, map_location="cpu")
    if not isinstance(ckpt, dict):
        raise SystemExit("Checkpoint must be a dict.")

    classifier, model_type = build_classifier(ckpt)
    classifier.to(device)
    classifier.eval()

    embed_dim = int(ckpt["embed_dim"])
    num_classes = int(ckpt["num_classes"])
    threshold = float(args.threshold) if args.threshold is not None else float(ckpt.get("threshold", 0.5))
    classes = ckpt.get("classes")
    if not isinstance(classes, list) or len(classes) != num_classes:
        classes = [f"class_{i}" for i in range(num_classes)]

    token = args.hf_token or os.getenv("HF_TOKEN")
    embedder = AutoModel.from_pretrained(args.model_id, token=token)
    processor = AutoImageProcessor.from_pretrained(args.model_id, token=token)
    embedder.to(device)
    embedder.eval()

    X_emb, kept_paths = embed_images(embedder, processor, image_paths, args.batch_size)
    if X_emb.shape[0] == 0:
        raise SystemExit("No images could be processed.")
    if X_emb.shape[1] != embed_dim:
        raise SystemExit(
            f"Embedding dim mismatch: produced {X_emb.shape[1]} but checkpoint expects {embed_dim}."
        )

    probs = predict(classifier, X_emb, device=device)
    preds = (probs >= threshold).astype(int)

    summary = {
        "checkpoint": str(Path(args.ckpt).resolve()),
        "classifier_type": model_type,
        "model_id": args.model_id,
        "device": device,
        "num_images": len(kept_paths),
        "embed_dim": embed_dim,
        "num_classes": num_classes,
        "threshold": threshold,
    }
    if "num_experts" in ckpt:
        summary["num_experts"] = int(ckpt["num_experts"])

    results = []
    for image_path, row_prob, row_pred in zip(kept_paths, probs, preds):
        results.append(
            {
                "image_path": image_path,
                "scores_by_class": {label: float(score) for label, score in zip(classes, row_prob)},
                "predicted_labels": [label for label, y in zip(classes, row_pred) if int(y) == 1],
            }
        )

    payload = {"summary": summary, "results": results}
    print(json.dumps(summary, indent=2))
    print(json.dumps(results[:3], indent=2))

    if args.out:
        out_path = Path(args.out).resolve()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        print(f"Saved output to: {out_path}")


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
torch
transformers
pillow
numpy
tqdm