# model.py from __future__ import annotations import json from dataclasses import dataclass, asdict from pathlib import Path from typing import Dict, List, Tuple, Optional import numpy as np from PIL import Image import nibabel as nib import torch import torch.nn as nn import torch.nn.functional as F import open_clip # pip install open_clip_torch # ----------------------------- # Constants (match your training) # ----------------------------- REPO_ID = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224" # You trained with "Dementia" as class-3 name (not "AD") LABEL2IDX: Dict[str, int] = {"CN": 0, "MCI": 1, "Dementia": 2} IDX2LABEL: Dict[int, str] = {v: k for k, v in LABEL2IDX.items()} # ----------------------------- # Small config to save with model # ----------------------------- @dataclass class ModelConfig: model_id: str = REPO_ID num_classes: int = 3 proj_dim: int = 512 freeze_encoders: bool = False label2idx: Dict[str, int] = None def to_json(self) -> str: d = asdict(self) return json.dumps(d, indent=2) @staticmethod def from_json(path: str | Path) -> "ModelConfig": data = json.loads(Path(path).read_text()) return ModelConfig(**data) # ----------------------------- # 3D→2D Triptych utilities # ----------------------------- def center_crop_or_pad(vol: np.ndarray, target_shape: Tuple[int, int, int]) -> np.ndarray: """Center-crop or zero-pad a 3D volume to target_shape=(D,H,W).""" D, H, W = vol.shape tD, tH, tW = target_shape out = np.zeros(target_shape, dtype=vol.dtype) d0 = max(0, (D - tD) // 2); d1 = d0 + min(D, tD) h0 = max(0, (H - tH) // 2); h1 = h0 + min(H, tH) w0 = max(0, (W - tW) // 2); w1 = w0 + min(W, tW) td0 = max(0, (tD - D) // 2); td1 = td0 + (d1 - d0) th0 = max(0, (tH - H) // 2); th1 = th0 + (h1 - h0) tw0 = max(0, (tW - W) // 2); tw1 = tw0 + (w1 - w0) out[td0:td1, th0:th1, tw0:tw1] = vol[d0:d1, h0:h1, w0:w1] return out def volume_to_triptych(volume_1d: torch.Tensor, out_size: int = 224) -> Image.Image: """ volume_1d: torch tensor [1, D, H, W] in [0,1]. Returns a PIL RGB image (triptych of axial/coronal/sagittal mid-slices). """ assert volume_1d.ndim == 4 and volume_1d.shape[0] == 1 _, D, H, W = volume_1d.shape v = volume_1d[0].cpu().numpy() # [D,H,W] d_mid, h_mid, w_mid = D // 2, H // 2, W // 2 axial = v[d_mid, :, :] # [H,W] coronal = v[:, h_mid, :] # [D,W] -> resize to [H,W] sagitt = v[:, :, w_mid] # [D,H] -> resize to [H,W] def norm_to_uint8(x: np.ndarray) -> np.ndarray: x = (x - x.min()) / (x.max() - x.min() + 1e-8) return (x * 255.0).astype(np.uint8) axial_img = Image.fromarray(norm_to_uint8(axial)) coronal_img = Image.fromarray(norm_to_uint8(coronal)).resize((W, H), Image.BILINEAR) sagitt_img = Image.fromarray(norm_to_uint8(sagitt)).resize((W, H), Image.BILINEAR) rgb = np.stack([np.array(axial_img), np.array(coronal_img), np.array(sagitt_img)], axis=-1) pil = Image.fromarray(rgb.astype(np.uint8)).resize((out_size, out_size), Image.BILINEAR) return pil # ----------------------------- # The model (same as training) # ----------------------------- class BiomedClipClassifier(nn.Module): """ Encodes MRI triptych (image) + clinical text with BiomedCLIP (open_clip), concatenates L2-normalized embeddings, then classifies into 3 classes. """ def __init__( self, model_id: str = REPO_ID, num_classes: int = 3, proj_dim: int = 512, freeze_encoders: bool = False, device: str = "cpu", ): super().__init__() # Load CLIP model & transforms self.clip, self.preprocess_train, self.preprocess_val = open_clip.create_model_and_transforms(model_id) self.tokenizer_fn = open_clip.get_tokenizer(model_id) self.clip.to(device) if freeze_encoders: for p in self.clip.parameters(): p.requires_grad = False # Infer feature dims with torch.no_grad(): dummy_img = torch.zeros(1, 3, 224, 224, device=device) dummy_txt = self.tokenizer_fn(["test"]).to(device) dim_i = self.clip.encode_image(dummy_img).shape[-1] dim_t = self.clip.encode_text(dummy_txt).shape[-1] in_dim = dim_i + dim_t self.head = nn.Sequential( nn.Linear(in_dim, proj_dim), nn.ReLU(), nn.Dropout(0.2), nn.Linear(proj_dim, num_classes), ) def forward(self, images: torch.Tensor, texts_tok: torch.Tensor) -> torch.Tensor: img_f = F.normalize(self.clip.encode_image(images), dim=-1) txt_f = F.normalize(self.clip.encode_text(texts_tok), dim=-1) return self.head(torch.cat([img_f, txt_f], dim=-1)) # ------------- HF-style save/load ------------- def save_pretrained(self, save_directory: str | Path, config: Optional[ModelConfig] = None): save_dir = Path(save_directory) save_dir.mkdir(parents=True, exist_ok=True) # state dict torch.save(self.state_dict(), save_dir / "pytorch_model.bin") # minimal config if config is None: config = ModelConfig(label2idx=LABEL2IDX) (save_dir / "config.json").write_text(config.to_json()) @staticmethod def from_pretrained(load_directory: str | Path, device: str = "cpu") -> "BiomedClipClassifier": load_dir = Path(load_directory) cfg_path = load_dir / "config.json" state_path = load_dir / "pytorch_model.bin" if cfg_path.exists(): cfg = ModelConfig.from_json(cfg_path) else: # fallback if only a state dict is present cfg = ModelConfig(label2idx=LABEL2IDX) model = BiomedClipClassifier( model_id=cfg.model_id, num_classes=cfg.num_classes, proj_dim=cfg.proj_dim, freeze_encoders=cfg.freeze_encoders, device=device, ) if state_path.exists(): state = torch.load(state_path, map_location=device) model.load_state_dict(state, strict=False) else: # Also allow people to pass a raw .pt file path as directory # e.g., repo contains 'biomedclip_best.pt' pt_fallback = next(load_dir.glob("*.pt"), None) if pt_fallback is not None: state = torch.load(pt_fallback, map_location=device) model.load_state_dict(state, strict=False) model.eval() return model # ----------------------------- # Simple single-sample inference helpers # ----------------------------- @torch.no_grad() def predict_from_paths( model: BiomedClipClassifier, mri_path: str | Path, text: str, device: str = "cpu", use_val_preprocess: bool = True, target_shape: Tuple[int, int, int] = (128, 128, 128), ) -> Tuple[str, List[float]]: """ Convenience function to run inference on one NIfTI + text string. Returns (pred_label, class_probs). """ model.eval() mri_path = Path(mri_path) # Load & normalize volume vol = nib.load(str(mri_path)).get_fdata().astype(np.float32) v = (vol - vol.mean()) / (vol.std() + 1e-8) v = (v - v.min()) / (v.max() - v.min() + 1e-8) v = center_crop_or_pad(v, target_shape) # Triptych -> preprocess img_t = torch.from_numpy(v).unsqueeze(0) # [1,D,H,W] trip_pil = volume_to_triptych(img_t) # PIL RGB 224x224 preprocess = model.preprocess_val if use_val_preprocess else model.preprocess_train img_clip = preprocess(trip_pil).unsqueeze(0).to(device) # Tokenize text tokenizer = model.tokenizer_fn txt_tok = tokenizer([text]).to(device) # Forward logits = model(img_clip, txt_tok) probs = torch.softmax(logits, dim=-1)[0].cpu().tolist() pred_idx = int(torch.argmax(logits, dim=-1).item()) pred_label = IDX2LABEL[pred_idx] return pred_label, probs # ----------------------------- # Minimal example (optional) # ----------------------------- if __name__ == "__main__": # Example: load a local folder with 'pytorch_model.bin' (or a .pt) and run one inference. # Set paths before running. weights_dir = "./" # folder containing pytorch_model.bin or a *.pt nifti_path = "/path/to/sample_brain.nii.gz" text_input = "Patient shows mild memory impairment and hippocampal atrophy." device = "cuda" if torch.cuda.is_available() else "cpu" model = BiomedClipClassifier.from_pretrained(weights_dir, device=device) pred, probs = predict_from_paths(model, nifti_path, text_input, device=device) print("Prediction:", pred) print("Probabilities [CN, MCI, Dementia]:", [round(p, 4) for p in probs])