Yasmine97 committed on
Commit 6cb7230 · verified · 1 Parent(s): 2511d1f

Upload model.py

Files changed (1): model.py (+247, -0)

model.py ADDED
@@ -0,0 +1,247 @@
# model.py

from __future__ import annotations
import json
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
from PIL import Image
import nibabel as nib

import torch
import torch.nn as nn
import torch.nn.functional as F
import open_clip  # pip install open_clip_torch


# -----------------------------
# Constants (match your training)
# -----------------------------
REPO_ID = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"

# You trained with "Dementia" (not "AD") as the label of the third class (index 2)
LABEL2IDX: Dict[str, int] = {"CN": 0, "MCI": 1, "Dementia": 2}
IDX2LABEL: Dict[int, str] = {v: k for k, v in LABEL2IDX.items()}


# -----------------------------
# Small config to save with the model
# -----------------------------
@dataclass
class ModelConfig:
    model_id: str = REPO_ID
    num_classes: int = 3
    proj_dim: int = 512
    freeze_encoders: bool = False
    label2idx: Optional[Dict[str, int]] = None  # Optional, since None is the default

    def to_json(self) -> str:
        d = asdict(self)
        return json.dumps(d, indent=2)

    @staticmethod
    def from_json(path: str | Path) -> "ModelConfig":
        data = json.loads(Path(path).read_text())
        return ModelConfig(**data)

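# Illustrative round-trip for the config (the file name here is just an example):
#   cfg = ModelConfig(label2idx=LABEL2IDX)
#   Path("config.json").write_text(cfg.to_json())
#   assert ModelConfig.from_json("config.json").label2idx == LABEL2IDX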

# -----------------------------
# 3D→2D Triptych utilities
# -----------------------------
def center_crop_or_pad(vol: np.ndarray, target_shape: Tuple[int, int, int]) -> np.ndarray:
    """Center-crop or zero-pad a 3D volume to target_shape=(D, H, W)."""
    D, H, W = vol.shape
    tD, tH, tW = target_shape
    out = np.zeros(target_shape, dtype=vol.dtype)

    # Source window: where to read from when an axis is larger than the target
    d0 = max(0, (D - tD) // 2); d1 = d0 + min(D, tD)
    h0 = max(0, (H - tH) // 2); h1 = h0 + min(H, tH)
    w0 = max(0, (W - tW) // 2); w1 = w0 + min(W, tW)

    # Destination window: where to write when an axis is smaller than the target
    td0 = max(0, (tD - D) // 2); td1 = td0 + (d1 - d0)
    th0 = max(0, (tH - H) // 2); th1 = th0 + (h1 - h0)
    tw0 = max(0, (tW - W) // 2); tw1 = tw0 + (w1 - w0)

    out[td0:td1, th0:th1, tw0:tw1] = vol[d0:d1, h0:h1, w0:w1]
    return out

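# Quick shape check (illustrative, synthetic data):
#   center_crop_or_pad(np.zeros((160, 100, 120), np.float32), (128, 128, 128)).shape
#   -> (128, 128, 128): depth 160 is center-cropped, height 100 and width 120 are zero-padded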

def volume_to_triptych(volume_1d: torch.Tensor, out_size: int = 224) -> Image.Image:
    """
    volume_1d: torch tensor [1, D, H, W] with values in [0, 1].
    Returns a PIL RGB image whose R/G/B channels hold the axial, coronal,
    and sagittal mid-slices (the "triptych" is stacked channel-wise).
    """
    assert volume_1d.ndim == 4 and volume_1d.shape[0] == 1
    _, D, H, W = volume_1d.shape
    v = volume_1d[0].cpu().numpy()  # [D, H, W]

    d_mid, h_mid, w_mid = D // 2, H // 2, W // 2
    axial = v[d_mid, :, :]    # [H, W]
    coronal = v[:, h_mid, :]  # [D, W] -> resized to [H, W]
    sagitt = v[:, :, w_mid]   # [D, H] -> resized to [H, W]

    def norm_to_uint8(x: np.ndarray) -> np.ndarray:
        # Rescale each slice independently to the full 0-255 range
        x = (x - x.min()) / (x.max() - x.min() + 1e-8)
        return (x * 255.0).astype(np.uint8)

    axial_img = Image.fromarray(norm_to_uint8(axial))
    coronal_img = Image.fromarray(norm_to_uint8(coronal)).resize((W, H), Image.BILINEAR)
    sagitt_img = Image.fromarray(norm_to_uint8(sagitt)).resize((W, H), Image.BILINEAR)

    rgb = np.stack([np.array(axial_img), np.array(coronal_img), np.array(sagitt_img)], axis=-1)
    pil = Image.fromarray(rgb.astype(np.uint8)).resize((out_size, out_size), Image.BILINEAR)
    return pil

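# Illustrative call on a random volume (no scanner data needed):
#   trip = volume_to_triptych(torch.rand(1, 128, 128, 128))
#   trip.size, trip.mode  # -> (224, 224), "RGB"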

# -----------------------------
# The model (same as training)
# -----------------------------
class BiomedClipClassifier(nn.Module):
    """
    Encodes an MRI triptych (image) and clinical text with BiomedCLIP (open_clip),
    concatenates the L2-normalized embeddings, then classifies into 3 classes.
    """
    def __init__(
        self,
        model_id: str = REPO_ID,
        num_classes: int = 3,
        proj_dim: int = 512,
        freeze_encoders: bool = False,
        device: str = "cpu",
    ):
        super().__init__()
        # Load the CLIP model and its train/val image transforms
        self.clip, self.preprocess_train, self.preprocess_val = open_clip.create_model_and_transforms(model_id)
        self.tokenizer_fn = open_clip.get_tokenizer(model_id)
        self.clip.to(device)

        if freeze_encoders:
            for p in self.clip.parameters():
                p.requires_grad = False

        # Infer feature dims by running dummy inputs through both encoders
        with torch.no_grad():
            dummy_img = torch.zeros(1, 3, 224, 224, device=device)
            dummy_txt = self.tokenizer_fn(["test"]).to(device)
            dim_i = self.clip.encode_image(dummy_img).shape[-1]
            dim_t = self.clip.encode_text(dummy_txt).shape[-1]

        in_dim = dim_i + dim_t
        self.head = nn.Sequential(
            nn.Linear(in_dim, proj_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(proj_dim, num_classes),
        )
        self.head.to(device)  # keep the head on the same device as the encoders

    def forward(self, images: torch.Tensor, texts_tok: torch.Tensor) -> torch.Tensor:
        img_f = F.normalize(self.clip.encode_image(images), dim=-1)
        txt_f = F.normalize(self.clip.encode_text(texts_tok), dim=-1)
        return self.head(torch.cat([img_f, txt_f], dim=-1))

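    # For this BiomedCLIP checkpoint both encoders typically project to 512-d,
    # so the head would receive a 1024-d fused vector; the dims are probed at
    # runtime in __init__ in case a different checkpoint is used.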
    # ------------- HF-style save/load -------------
    def save_pretrained(self, save_directory: str | Path, config: Optional[ModelConfig] = None):
        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)
        # State dict
        torch.save(self.state_dict(), save_dir / "pytorch_model.bin")
        # Minimal config
        if config is None:
            config = ModelConfig(label2idx=LABEL2IDX)
        (save_dir / "config.json").write_text(config.to_json())

    @staticmethod
    def from_pretrained(load_directory: str | Path, device: str = "cpu") -> "BiomedClipClassifier":
        load_dir = Path(load_directory)
        cfg_path = load_dir / "config.json"
        state_path = load_dir / "pytorch_model.bin"

        if cfg_path.exists():
            cfg = ModelConfig.from_json(cfg_path)
        else:
            # Fallback if only a state dict is present
            cfg = ModelConfig(label2idx=LABEL2IDX)

        model = BiomedClipClassifier(
            model_id=cfg.model_id,
            num_classes=cfg.num_classes,
            proj_dim=cfg.proj_dim,
            freeze_encoders=cfg.freeze_encoders,
            device=device,
        )
        if state_path.exists():
            state = torch.load(state_path, map_location=device)
            model.load_state_dict(state, strict=False)
        else:
            # Also accept a directory that holds a raw *.pt state dict instead,
            # e.g. a repo containing 'biomedclip_best.pt'
            pt_fallback = next(load_dir.glob("*.pt"), None)
            if pt_fallback is not None:
                state = torch.load(pt_fallback, map_location=device)
                model.load_state_dict(state, strict=False)

        model.eval()
        return model

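# Illustrative save/load round-trip (the directory name is just a placeholder):
#   model = BiomedClipClassifier(device="cpu")
#   model.save_pretrained("./checkpoint")
#   model = BiomedClipClassifier.from_pretrained("./checkpoint", device="cpu")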

# -----------------------------
# Simple single-sample inference helpers
# -----------------------------
@torch.no_grad()
def predict_from_paths(
    model: BiomedClipClassifier,
    mri_path: str | Path,
    text: str,
    device: str = "cpu",
    use_val_preprocess: bool = True,
    target_shape: Tuple[int, int, int] = (128, 128, 128),
) -> Tuple[str, List[float]]:
    """
    Convenience function to run inference on one NIfTI volume + text string.
    Returns (pred_label, class_probs).
    """
    model.eval()
    mri_path = Path(mri_path)

    # Load the volume, z-score it, then rescale to [0, 1]
    vol = nib.load(str(mri_path)).get_fdata().astype(np.float32)
    v = (vol - vol.mean()) / (vol.std() + 1e-8)
    v = (v - v.min()) / (v.max() - v.min() + 1e-8)
    v = center_crop_or_pad(v, target_shape)

    # Triptych -> CLIP preprocessing
    img_t = torch.from_numpy(v).unsqueeze(0)  # [1, D, H, W]
    trip_pil = volume_to_triptych(img_t)      # PIL RGB 224x224
    preprocess = model.preprocess_val if use_val_preprocess else model.preprocess_train
    img_clip = preprocess(trip_pil).unsqueeze(0).to(device)

    # Tokenize text
    txt_tok = model.tokenizer_fn([text]).to(device)

    # Forward
    logits = model(img_clip, txt_tok)
    probs = torch.softmax(logits, dim=-1)[0].cpu().tolist()
    pred_idx = int(torch.argmax(logits, dim=-1).item())
    pred_label = IDX2LABEL[pred_idx]
    return pred_label, probs

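# A batched variant is not part of the original file; the sketch below is an
# illustrative addition (name and signature are my own) that reuses the same
# preprocessing and stacks several samples into one forward pass.
@torch.no_grad()
def predict_batch(
    model: BiomedClipClassifier,
    mri_paths: List[str | Path],
    texts: List[str],
    device: str = "cpu",
    target_shape: Tuple[int, int, int] = (128, 128, 128),
) -> List[Tuple[str, List[float]]]:
    imgs = []
    for p in mri_paths:
        # Same per-volume normalization as predict_from_paths
        vol = nib.load(str(p)).get_fdata().astype(np.float32)
        v = (vol - vol.mean()) / (vol.std() + 1e-8)
        v = (v - v.min()) / (v.max() - v.min() + 1e-8)
        v = center_crop_or_pad(v, target_shape)
        trip = volume_to_triptych(torch.from_numpy(v).unsqueeze(0))
        imgs.append(model.preprocess_val(trip))
    img_batch = torch.stack(imgs).to(device)        # [B, 3, 224, 224]
    txt_tok = model.tokenizer_fn(texts).to(device)  # [B, T]
    probs = torch.softmax(model(img_batch, txt_tok), dim=-1).cpu()
    return [(IDX2LABEL[int(p.argmax())], p.tolist()) for p in probs]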

# -----------------------------
# Minimal example (optional)
# -----------------------------
if __name__ == "__main__":
    # Example: load a local folder with 'pytorch_model.bin' (or a *.pt) and run one inference.
    # Set the paths below before running.
    weights_dir = "./"  # folder containing pytorch_model.bin or a *.pt
    nifti_path = "/path/to/sample_brain.nii.gz"
    text_input = "Patient shows mild memory impairment and hippocampal atrophy."

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = BiomedClipClassifier.from_pretrained(weights_dir, device=device)

    pred, probs = predict_from_paths(model, nifti_path, text_input, device=device)
    print("Prediction:", pred)
    print("Probabilities [CN, MCI, Dementia]:", [round(p, 4) for p in probs])