Add model details and instructions
- README.md +129 -3
- inference_loader.py +349 -0
- requirements.txt +5 -0
README.md
CHANGED
@@ -1,3 +1,129 @@
# ECG Image Classifier (MoE and MLP) on MedSigLIP Embeddings

This repository provides two PyTorch ECG classifier checkpoints trained on top of frozen MedSigLIP image embeddings:

- `moe_classifier_medsiglip.pt`: Mixture-of-Experts (MoE) classifier
- `mlp_classifier_medsiglip.pt`: Dense feedforward (MLP) classifier

These checkpoints expect embeddings produced by:

- `google/medsiglip-448`

The repository contains only the classifier heads. MedSigLIP weights are not included and must be obtained separately under Google’s license.

---

## Motivation

This work was developed as part of the Google MedGemma Impact Challenge:
https://www.kaggle.com/competitions/med-gemma-impact-challenge/overview

The goal is to build a lightweight, deployable ECG image classifier for chronic care screening, especially in low-resource clinical settings where ECG is often the most accessible diagnostic modality.

---

## Task and Data

We formulate a supervised multi-label image classification task on 12-lead ECGs with five diagnostic categories:

- NORM (normal)
- MI (myocardial infarction)
- STTC (ST-T changes)
- CD (conduction disturbances)
- HYP (hypertrophy)

Training data combines:

- PTB-XL, a large-scale dataset of raw 12-lead ECG waveforms in WFDB format with 16-bit precision
- A supplementary ECG image dataset

To enable image-based classification, raw PTB-XL waveforms are converted into realistic print-style ECG images using the open-source ECG image generator by Rahimi et al. This yields approximately 21,000 synthetic ECG images, which are combined with 713 real ECG images from the supplementary dataset.

---

## Model and Training

ECG images are first encoded using MedSigLIP to obtain fixed-dimensional visual embeddings. Two lightweight classifiers are trained on top of these embeddings:

- A dense feedforward network (MLP)
- A Mixture-of-Experts (MoE) classifier
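
The MoE head combines per-expert logits with softmax gate weights (this mirrors `MoEClassifier.forward` in `inference_loader.py`); a minimal numpy sketch of the mixing step:

```python
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def moe_mix(gate_logits: np.ndarray, expert_logits: np.ndarray) -> np.ndarray:
    """Blend per-expert logits with softmax gate weights.

    gate_logits:   (batch, num_experts) raw gate scores
    expert_logits: (batch, num_experts, num_classes) per-expert class logits
    """
    gate_w = softmax(gate_logits, axis=-1)                   # (B, E)
    return (expert_logits * gate_w[:, :, None]).sum(axis=1)  # (B, C)

# Two experts with a uniform gate reduce to a simple average of expert logits.
gate = np.zeros((1, 2))
experts = np.array([[[2.0, 0.0], [0.0, 2.0]]])  # (1 sample, 2 experts, 2 classes)
print(moe_mix(gate, experts))  # [[1. 1.]]
```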

The dataset is split 60/20/20 into training, validation, and test sets. Both models are optimized with Adam (learning rate 1e-4, weight decay 1e-5). The MoE model additionally uses a load-balancing regularization term weighted by lambda = 0.1.
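
The exact form of the load-balancing term is not specified here. One common importance-style penalty, shown purely as an illustrative sketch (the function name and formula are assumptions, not taken from the training code), drives the batch-averaged gate weights toward uniform expert usage; the training objective would then add 0.1 times this term to the classification loss:

```python
import numpy as np

def load_balance_penalty(gate_w: np.ndarray) -> float:
    """Assumed importance-style balance term: num_experts * sum(mean_gate^2).

    gate_w: (batch, num_experts) softmax gate weights (rows sum to 1).
    The minimum value is 1.0, attained exactly when experts are used uniformly.
    """
    mean_p = gate_w.mean(axis=0)  # average gate weight per expert over the batch
    return float(len(mean_p) * (mean_p ** 2).sum())

uniform = np.full((4, 5), 0.2)       # 5 experts, perfectly balanced usage
collapsed = np.eye(5)[[0, 0, 0, 0]]  # all gate weight on expert 0
print(round(load_balance_penalty(uniform), 6))    # 1.0
print(round(load_balance_penalty(collapsed), 6))  # 5.0
```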

For multi-label prediction, a uniform decision threshold of 0.3 is applied across all classes; at inference time the script uses the threshold stored in the checkpoint unless `--threshold` overrides it.
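
The decision rule mirrors `preds = (probs >= threshold)` in the inference script; applied to the five classes above:

```python
import numpy as np

CLASSES = ["NORM", "MI", "STTC", "CD", "HYP"]

def labels_above_threshold(probs: np.ndarray, threshold: float = 0.3) -> list[str]:
    """Return every class whose sigmoid score clears the decision threshold."""
    return [c for c, p in zip(CLASSES, probs) if p >= threshold]

# Illustrative scores, not real model output.
scores = np.array([0.82, 0.41, 0.12, 0.29, 0.33])
print(labels_above_threshold(scores))  # ['NORM', 'MI', 'HYP']
```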

---

## Results

On the held-out test set, the MoE classifier consistently outperforms the MLP baseline across all metrics. It achieves:

- Lower Hamming loss: 0.167 vs 0.235
- Higher ROC-AUC:
  - Micro: 0.891 vs 0.827
  - Macro: 0.879 vs 0.808
- Higher F1 scores:
  - Micro: 0.70 vs 0.61
  - Macro: 0.67 vs 0.58

Per-class F1 improves across all five diagnostic categories, with the largest gains observed for myocardial infarction and hypertrophy. Confusion-matrix analysis indicates that the MLP baseline tends to trade precision for recall, producing more false positives and a lower overall F1. For this reason, the MoE classifier is used in the final application.
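
For reference, the multi-label metrics reported above reduce to simple counts over the (sample, class) grid; a self-contained numpy sketch (the toy matrices are illustrative, not the actual test data):

```python
import numpy as np

def hamming_loss(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Fraction of (sample, class) slots predicted wrongly."""
    return float((y_true != y_pred).mean())

def micro_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """F1 with true/false positives and negatives pooled over all slots."""
    tp = float(((y_true == 1) & (y_pred == 1)).sum())
    fp = float(((y_true == 0) & (y_pred == 1)).sum())
    fn = float(((y_true == 1) & (y_pred == 0)).sum())
    return 2 * tp / (2 * tp + fp + fn)

y_true = np.array([[1, 0, 0], [0, 1, 1]])
y_pred = np.array([[1, 1, 0], [0, 1, 0]])
print(hamming_loss(y_true, y_pred))  # 2 wrong slots out of 6 -> 0.333...
print(micro_f1(y_true, y_pred))      # tp=2, fp=1, fn=1 -> 0.666...
```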

---

## Practical Implications

Compared to using MedGemma alone, the MedSigLIP-plus-classifier pipeline provides more structured and reliable ECG predictions. In addition to discrete labels, the classifier outputs a confidence score for each class. This supports threshold-based screening and triage, which is particularly useful in chronic care workflows and remote clinics where rapid ECG assessment can help prioritize referrals.

---

## How to Use

### 1) Install dependencies

```bash
pip install -r requirements.txt
```

### 2) Run inference

Single image with the MoE checkpoint:

```bash
python inference_loader.py \
  --ckpt ./moe_classifier_medsiglip.pt \
  --image ./sample_ecg.png \
  --out ./preds_moe.json
```

Batch inference on a folder with the MLP checkpoint:

```bash
python inference_loader.py \
  --ckpt ./mlp_classifier_medsiglip.pt \
  --folder ./images \
  --out ./preds_mlp.json
```

### 3) Optional arguments

- `--model_id` (default: `google/medsiglip-448`)
- `--device auto|cpu|cuda`
- `--batch_size 16`
- `--threshold 0.3` (overrides the checkpoint threshold)
- `--hf_token <token>` (or set `HF_TOKEN` as an environment variable)

### 4) Outputs

The inference script prints, and optionally writes to `--out`, a JSON payload with:

- `scores_by_class`: per-image confidence scores for each diagnostic class
- `predicted_labels`: per-image labels above the decision threshold
- `summary`: run metadata including the checkpoint path, classifier type, device, and embedding dimension
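
The saved JSON can be post-processed for triage; a sketch assuming the payload shape that `inference_loader.py` writes (the scores below are made-up examples):

```python
import json

# Example payload in the shape inference_loader.py writes to --out.
raw = """
{
  "summary": {"classifier_type": "moe", "threshold": 0.3},
  "results": [
    {"image_path": "/data/ecg_001.png",
     "scores_by_class": {"NORM": 0.12, "MI": 0.71, "STTC": 0.40, "CD": 0.10, "HYP": 0.20},
     "predicted_labels": ["MI", "STTC"]}
  ]
}
"""
payload = json.loads(raw)

# Triage: surface any image with a non-NORM label above threshold.
flagged = [
    (r["image_path"], r["predicted_labels"])
    for r in payload["results"]
    if set(r["predicted_labels"]) - {"NORM"}
]
print(flagged)  # [('/data/ecg_001.png', ['MI', 'STTC'])]
```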

---

## References

[1] PTB-XL dataset
[2] Supplementary ECG image dataset used in this project
[3] Rahimi et al., open-source ECG image generator
inference_loader.py
ADDED
@@ -0,0 +1,349 @@
"""
Standalone inference loader for ECG classifier checkpoints.

Supports:
- MoE checkpoints (experts.* + gate.* keys)
- MLP checkpoints (fc1/fc2/out keys)

Usage examples:
  python inference_loader.py --ckpt ./moe_classifier_medsiglip.pt --image ./ecg.png
  python inference_loader.py --ckpt ./mlp_classifier_medsiglip.pt --folder ./images --out ./preds.json
"""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

DEFAULT_MODEL_ID = "google/medsiglip-448"
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}


class ExpertMLP(nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden: tuple[int, ...] = (1028, 512, 256),
        dropout: tuple[float, ...] = (0.15, 0.15, 0.10),
    ):
        super().__init__()
        layers: list[nn.Module] = []
        prev = in_dim

        # Pad or trim dropout rates so they pair one-to-one with the hidden layers.
        dropout_values = tuple(dropout)
        if len(dropout_values) < len(hidden):
            dropout_values = dropout_values + (0.0,) * (len(hidden) - len(dropout_values))
        elif len(dropout_values) > len(hidden):
            dropout_values = dropout_values[: len(hidden)]

        for h, p in zip(hidden, dropout_values):
            layers.append(nn.Linear(prev, h))
            layers.append(nn.LayerNorm(h))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(p))
            prev = h

        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class MoEClassifier(nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_experts: int = 5,
        gate_hidden: int = 512,
        temperature: float = 1.0,
        expert_hidden: tuple[int, ...] = (1028, 512, 256),
        expert_dropout: tuple[float, ...] = (0.15, 0.15, 0.10),
    ):
        super().__init__()
        self.temperature = temperature
        self.experts = nn.ModuleList(
            [
                ExpertMLP(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    hidden=expert_hidden,
                    dropout=expert_dropout,
                )
                for _ in range(num_experts)
            ]
        )
        self.gate = nn.Sequential(
            nn.Linear(in_dim, gate_hidden),
            nn.ReLU(),
            nn.Linear(gate_hidden, num_experts),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Mix per-expert logits with softmax gate weights.
        gate_logits = self.gate(x) / self.temperature
        gate_w = torch.softmax(gate_logits, dim=-1)
        expert_logits = torch.stack([expert(x) for expert in self.experts], dim=1)
        mixed_logits = torch.sum(expert_logits * gate_w.unsqueeze(-1), dim=1)
        return mixed_logits, gate_w, expert_logits


class MLPClassifier(nn.Module):
    def __init__(self, in_dim: int, hidden_1: int, hidden_2: int, out_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_1)
        self.fc2 = nn.Linear(hidden_1, hidden_2)
        self.out = nn.Linear(hidden_2, out_dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.out(x)


def get_device(device_arg: str) -> str:
    if device_arg == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    if device_arg == "cuda" and not torch.cuda.is_available():
        return "cpu"
    return device_arg


def collect_images(image: str | None, images: list[str] | None, folder: str | None) -> list[str]:
    paths: list[str] = []
    if image:
        paths.append(image)
    if images:
        paths.extend(images)
    if folder:
        for name in sorted(os.listdir(folder)):
            p = os.path.join(folder, name)
            ext = os.path.splitext(p)[1].lower()
            if os.path.isfile(p) and ext in IMAGE_EXTS:
                paths.append(p)

    # Deduplicate by resolved path while preserving order.
    out: list[str] = []
    seen: set[str] = set()
    for p in paths:
        ap = str(Path(p).resolve())
        if ap not in seen:
            seen.add(ap)
            out.append(ap)
    return out


def extract_features(output: Any) -> torch.Tensor:
    if isinstance(output, torch.Tensor):
        return output
    if hasattr(output, "pooler_output") and output.pooler_output is not None:
        return output.pooler_output
    if hasattr(output, "last_hidden_state") and output.last_hidden_state is not None:
        return output.last_hidden_state[:, 0, :]
    raise TypeError(f"Unexpected image feature output type: {type(output)}")


def build_classifier(ckpt: dict[str, Any]) -> tuple[nn.Module, str]:
    state_dict = ckpt.get("state_dict")
    if not isinstance(state_dict, dict) or not state_dict:
        raise RuntimeError("Checkpoint missing state_dict.")

    embed_dim = int(ckpt["embed_dim"])
    num_classes = int(ckpt["num_classes"])

    if any(key.startswith("experts.") for key in state_dict.keys()):
        # Infer the expert architecture from the first expert's linear weights.
        num_experts = int(ckpt.get("num_experts", 5))
        expert_linear_layers: list[tuple[int, torch.Tensor]] = []
        for key, value in state_dict.items():
            if not (
                key.startswith("experts.0.net.")
                and key.endswith(".weight")
                and isinstance(value, torch.Tensor)
                and value.ndim == 2
            ):
                continue
            layer_index = int(key.split(".")[3])
            expert_linear_layers.append((layer_index, value))

        if len(expert_linear_layers) < 2:
            raise RuntimeError("Unable to infer expert architecture from checkpoint.")

        expert_linear_layers.sort(key=lambda item: item[0])
        expert_hidden = tuple(int(weight.shape[0]) for _, weight in expert_linear_layers[:-1])
        gate_hidden = int(state_dict["gate.0.weight"].shape[0]) if "gate.0.weight" in state_dict else 256

        model = MoEClassifier(
            in_dim=embed_dim,
            out_dim=num_classes,
            num_experts=num_experts,
            gate_hidden=gate_hidden,
            temperature=1.0,
            expert_hidden=expert_hidden,
            expert_dropout=tuple(0.0 for _ in expert_hidden),  # dropout is inactive at inference
        )
        model_type = "moe"
    elif {"fc1.weight", "fc2.weight", "out.weight"}.issubset(state_dict.keys()):
        hidden_1 = int(state_dict["fc1.weight"].shape[0])
        hidden_2 = int(state_dict["fc2.weight"].shape[0])
        model = MLPClassifier(
            in_dim=embed_dim,
            hidden_1=hidden_1,
            hidden_2=hidden_2,
            out_dim=num_classes,
        )
        model_type = "mlp"
    else:
        raise RuntimeError("Unsupported checkpoint format.")

    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model, model_type


@torch.no_grad()
def embed_images(
    embedder: AutoModel,
    processor: AutoImageProcessor,
    image_paths: list[str],
    batch_size: int,
) -> tuple[np.ndarray, list[str]]:
    embs: list[np.ndarray] = []
    kept: list[str] = []
    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i : i + batch_size]
        batch_images = []
        batch_kept: list[str] = []
        for p in batch_paths:
            try:
                batch_images.append(Image.open(p).convert("RGB"))
                batch_kept.append(p)
            except Exception as exc:  # pragma: no cover
                # Skip unreadable files instead of aborting the whole run.
                print(f"[skip] {p}: {exc}")

        if not batch_images:
            continue

        inputs = processor(images=batch_images, return_tensors="pt").to(embedder.device)
        out = embedder.get_image_features(**inputs)
        feats = extract_features(out)

        embs.append(feats.detach().cpu().numpy().astype(np.float32))
        kept.extend(batch_kept)

    if not embs:
        return np.zeros((0, 0), dtype=np.float32), []
    return np.concatenate(embs, axis=0), kept


@torch.no_grad()
def predict(
    classifier: nn.Module,
    X_emb: np.ndarray,
    device: str,
) -> np.ndarray:
    xb = torch.from_numpy(X_emb).to(device)
    logits_output = classifier(xb)
    # MoE returns (mixed_logits, gate_weights, expert_logits); MLP returns logits only.
    logits = logits_output[0] if isinstance(logits_output, tuple) else logits_output
    probs = torch.sigmoid(logits).detach().cpu().numpy()
    return probs


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True, help="Path to checkpoint .pt file")
    parser.add_argument("--model_id", type=str, default=DEFAULT_MODEL_ID, help="HF model id for embedder")
    parser.add_argument("--hf_token", type=str, default=None, help="HF token (or set HF_TOKEN env var)")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "cuda"])
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--threshold", type=float, default=None, help="Override checkpoint threshold")

    parser.add_argument("--image", type=str, default=None, help="Single image path")
    parser.add_argument("--images", nargs="*", default=None, help="Multiple image paths")
    parser.add_argument("--folder", type=str, default=None, help="Folder with images")
    parser.add_argument("--out", type=str, default=None, help="Optional output JSON path")
    args = parser.parse_args()

    image_paths = collect_images(args.image, args.images, args.folder)
    if not image_paths:
        raise SystemExit("No images found. Use --image, --images, or --folder.")

    device = get_device(args.device)
    ckpt = torch.load(args.ckpt, map_location="cpu")
    if not isinstance(ckpt, dict):
        raise SystemExit("Checkpoint must be a dict.")

    classifier, model_type = build_classifier(ckpt)
    classifier.to(device)
    classifier.eval()

    embed_dim = int(ckpt["embed_dim"])
    num_classes = int(ckpt["num_classes"])
    threshold = float(args.threshold) if args.threshold is not None else float(ckpt.get("threshold", 0.5))
    classes = ckpt.get("classes")
    if not isinstance(classes, list) or len(classes) != num_classes:
        classes = [f"class_{i}" for i in range(num_classes)]

    token = args.hf_token or os.getenv("HF_TOKEN")
    embedder = AutoModel.from_pretrained(args.model_id, token=token)
    processor = AutoImageProcessor.from_pretrained(args.model_id, token=token)
    embedder.to(device)
    embedder.eval()

    X_emb, kept_paths = embed_images(embedder, processor, image_paths, args.batch_size)
    if X_emb.shape[0] == 0:
        raise SystemExit("No images could be processed.")
    if X_emb.shape[1] != embed_dim:
        raise SystemExit(
            f"Embedding dim mismatch: produced {X_emb.shape[1]} but checkpoint expects {embed_dim}."
        )

    probs = predict(classifier, X_emb, device=device)
    preds = (probs >= threshold).astype(int)

    summary = {
        "checkpoint": str(Path(args.ckpt).resolve()),
        "classifier_type": model_type,
        "model_id": args.model_id,
        "device": device,
        "num_images": len(kept_paths),
        "embed_dim": embed_dim,
        "num_classes": num_classes,
        "threshold": threshold,
    }
    if "num_experts" in ckpt:
        summary["num_experts"] = int(ckpt["num_experts"])

    results = []
    for image_path, row_prob, row_pred in zip(kept_paths, probs, preds):
        results.append(
            {
                "image_path": image_path,
                "scores_by_class": {label: float(score) for label, score in zip(classes, row_prob)},
                "predicted_labels": [label for label, y in zip(classes, row_pred) if int(y) == 1],
            }
        )

    payload = {"summary": summary, "results": results}
    print(json.dumps(summary, indent=2))
    print(json.dumps(results[:3], indent=2))

    if args.out:
        out_path = Path(args.out).resolve()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        print(f"Saved output to: {out_path}")


if __name__ == "__main__":
    main()
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
transformers
pillow
numpy
tqdm