|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
import json
|
|
|
from dataclasses import dataclass, asdict
|
|
|
from pathlib import Path
|
|
|
from typing import Dict, List, Tuple, Optional
|
|
|
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
import nibabel as nib
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import open_clip
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# open_clip model identifier; the "hf-hub:" prefix makes open_clip resolve it
# on the Hugging Face Hub.
REPO_ID = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"

# Class name -> class index (position in the classifier's output vector).
LABEL2IDX: Dict[str, int] = {"CN": 0, "MCI": 1, "Dementia": 2}

# Reverse lookup: class index -> class name.
IDX2LABEL: Dict[int, str] = dict(zip(LABEL2IDX.values(), LABEL2IDX.keys()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Hyperparameters needed to rebuild a ``BiomedClipClassifier``.

    Serialized to and from a JSON file (``config.json``) so a saved model
    directory can be reloaded with the same architecture.
    """

    # open_clip model identifier passed to ``open_clip.create_model_and_transforms``.
    model_id: str = REPO_ID
    # Number of output classes of the classification head.
    num_classes: int = 3
    # Hidden width of the classification head.
    proj_dim: int = 512
    # Whether the CLIP encoder parameters were frozen (requires_grad=False).
    freeze_encoders: bool = False
    # Class-name -> class-index mapping recorded with the model; may be None.
    label2idx: Optional[Dict[str, int]] = None

    def to_json(self) -> str:
        """Return the config as a JSON string (indent=2)."""
        return json.dumps(asdict(self), indent=2)

    @classmethod
    def from_json(cls, path: str | Path) -> "ModelConfig":
        """Load a config from the JSON file at *path*.

        The file is read as UTF-8 (the JSON standard encoding) rather than the
        platform default.

        Raises:
            FileNotFoundError: if *path* does not exist.
            json.JSONDecodeError: if the file is not valid JSON.
            TypeError: if the JSON object has keys that are not fields of this class.
        """
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        return cls(**data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def center_crop_or_pad(vol: np.ndarray, target_shape: Tuple[int, ...]) -> np.ndarray:
    """Center-crop or zero-pad an array to ``target_shape``.

    Each axis is handled independently:
    - An axis longer than its target is cropped around its center. When the
      excess is odd, the extra voxel is removed from the end.
    - An axis shorter than its target is zero-padded around its center. When
      the padding is odd, the extra zero goes at the end.

    Typically used with 3D volumes ``(D, H, W)``, but any rank works as long
    as ``len(target_shape) == vol.ndim``.

    Args:
        vol: input array.
        target_shape: desired output shape, one entry per axis of ``vol``.

    Returns:
        A new array of shape ``target_shape`` with dtype ``vol.dtype``.

    Raises:
        ValueError: if ``target_shape`` does not have one entry per axis of ``vol``.
    """
    if len(target_shape) != vol.ndim:
        raise ValueError(
            f"target_shape {tuple(target_shape)} has rank {len(target_shape)}, "
            f"but vol has rank {vol.ndim} (shape {vol.shape})"
        )

    out = np.zeros(target_shape, dtype=vol.dtype)

    src_slices = []
    dst_slices = []
    for size, target in zip(vol.shape, target_shape):
        keep = min(size, target)
        # Offset into the source (non-zero when cropping) and into the
        # output (non-zero when padding); exactly one of them can be > 0.
        src_start = max(0, (size - target) // 2)
        dst_start = max(0, (target - size) // 2)
        src_slices.append(slice(src_start, src_start + keep))
        dst_slices.append(slice(dst_start, dst_start + keep))

    out[tuple(dst_slices)] = vol[tuple(src_slices)]
    return out
|
|
|
|
|
|
|
|
|
def volume_to_triptych(volume_1d: torch.Tensor, out_size: int = 224) -> Image.Image:
    """
    Render a single-channel 3D volume as one RGB image of its three mid-slices.

    The middle axial (fixed D), coronal (fixed H) and sagittal (fixed W)
    slices are packed into the R, G and B channels of the same image. They are
    overlaid as channels, not placed side by side. Each slice is min-max
    normalized to 0..255 independently, so the absolute intensity range of the
    input does not matter.

    The coronal (D x W) and sagittal (D x H) slices are bilinearly resized to
    the axial slice's H x W before stacking. The result is then resized to
    ``out_size`` x ``out_size``.

    Args:
        volume_1d: tensor of shape [1, D, H, W] (any device; detached and copied
            to CPU here).
        out_size: side length in pixels of the returned square image.

    Returns:
        A PIL RGB image of size (out_size, out_size).

    Raises:
        ValueError: if ``volume_1d`` is not shaped [1, D, H, W].
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if volume_1d.ndim != 4 or volume_1d.shape[0] != 1:
        raise ValueError(
            f"expected a tensor of shape [1, D, H, W], got {tuple(volume_1d.shape)}"
        )

    _, D, H, W = volume_1d.shape
    # detach() so tensors that require grad can still be converted to NumPy.
    v = volume_1d[0].detach().cpu().numpy()

    d_mid, h_mid, w_mid = D // 2, H // 2, W // 2
    axial = v[d_mid, :, :]    # (H, W)
    coronal = v[:, h_mid, :]  # (D, W)
    sagitt = v[:, :, w_mid]   # (D, H)

    def norm_to_uint8(x: np.ndarray) -> np.ndarray:
        # Per-slice min-max to [0, 1]; the epsilon keeps constant slices at 0.
        x = (x - x.min()) / (x.max() - x.min() + 1e-8)
        return (x * 255.0).astype(np.uint8)

    axial_img = Image.fromarray(norm_to_uint8(axial))
    # PIL sizes are (width, height), so (W, H) yields an H x W array.
    coronal_img = Image.fromarray(norm_to_uint8(coronal)).resize((W, H), Image.BILINEAR)
    sagitt_img = Image.fromarray(norm_to_uint8(sagitt)).resize((W, H), Image.BILINEAR)

    rgb = np.stack([np.array(axial_img), np.array(coronal_img), np.array(sagitt_img)], axis=-1)
    return Image.fromarray(rgb.astype(np.uint8)).resize((out_size, out_size), Image.BILINEAR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BiomedClipClassifier(nn.Module):
    """
    Encodes MRI triptych (image) + clinical text with BiomedCLIP (open_clip),
    concatenates L2-normalized embeddings, then classifies them with a small
    MLP head into ``num_classes`` classes (3 by default).

    Attributes set in ``__init__``:
        clip: the open_clip model providing ``encode_image`` / ``encode_text``.
        preprocess_train / preprocess_val: open_clip image transforms.
        tokenizer_fn: open_clip tokenizer for ``model_id``.
        head: Linear -> ReLU -> Dropout(0.2) -> Linear classifier.
    """

    def __init__(
        self,
        model_id: str = REPO_ID,
        num_classes: int = 3,
        proj_dim: int = 512,
        freeze_encoders: bool = False,
        device: str = "cpu",
    ):
        super().__init__()

        # Keep the constructor hyperparameters so save_pretrained() can write a
        # config.json that rebuilds exactly this architecture.
        self.model_id = model_id
        self.num_classes = num_classes
        self.proj_dim = proj_dim
        self.freeze_encoders = freeze_encoders

        self.clip, self.preprocess_train, self.preprocess_val = open_clip.create_model_and_transforms(model_id)
        self.tokenizer_fn = open_clip.get_tokenizer(model_id)
        self.clip.to(device)

        if freeze_encoders:
            for p in self.clip.parameters():
                p.requires_grad = False

        # Probe both encoders once with dummy inputs to discover their
        # embedding widths, which size the head's first layer.
        with torch.no_grad():
            dummy_img = torch.zeros(1, 3, 224, 224, device=device)
            dummy_txt = self.tokenizer_fn(["test"]).to(device)
            dim_i = self.clip.encode_image(dummy_img).shape[-1]
            dim_t = self.clip.encode_text(dummy_txt).shape[-1]

        self.head = nn.Sequential(
            nn.Linear(dim_i + dim_t, proj_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(proj_dim, num_classes),
        )
        # The head is built after self.clip.to(device); move it as well,
        # otherwise forward() mixes CPU and `device` tensors when device != "cpu".
        self.head.to(device)

    def forward(self, images: torch.Tensor, texts_tok: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape [batch, num_classes].

        Args:
            images: image batch, presumably produced by ``preprocess_val`` /
                ``preprocess_train`` and stacked (TODO confirm shape for
                non-default model_ids).
            texts_tok: token ids produced by ``tokenizer_fn``.
        """
        img_f = F.normalize(self.clip.encode_image(images), dim=-1)
        txt_f = F.normalize(self.clip.encode_text(texts_tok), dim=-1)
        return self.head(torch.cat([img_f, txt_f], dim=-1))

    def save_pretrained(self, save_directory: str | Path, config: Optional[ModelConfig] = None):
        """Write ``pytorch_model.bin`` (state dict) and ``config.json`` to *save_directory*.

        If *config* is None, the config is built from this instance's own
        constructor hyperparameters (plus ``LABEL2IDX``). It no longer uses the
        ModelConfig defaults, so a non-default model can be reloaded with
        ``from_pretrained``.
        """
        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)

        torch.save(self.state_dict(), save_dir / "pytorch_model.bin")

        if config is None:
            config = ModelConfig(
                model_id=self.model_id,
                num_classes=self.num_classes,
                proj_dim=self.proj_dim,
                freeze_encoders=self.freeze_encoders,
                label2idx=dict(LABEL2IDX),
            )
        (save_dir / "config.json").write_text(config.to_json(), encoding="utf-8")

    @classmethod
    def from_pretrained(cls, load_directory: str | Path, device: str = "cpu") -> "BiomedClipClassifier":
        """Rebuild a classifier from *load_directory* and load its weights.

        Lookup order:
        1. ``config.json`` for hyperparameters; defaults are used if it is absent.
        2. ``pytorch_model.bin`` for weights; if absent, the lexicographically
           first ``*.pt`` file in the directory.

        Weights are loaded with ``strict=False``. A warning is emitted when keys
        are missing or unexpected, or when no weight file is found at all. In
        that last case the classification head is untrained.

        Returns:
            The model, on *device*, in eval mode.
        """
        import warnings

        load_dir = Path(load_directory)
        cfg_path = load_dir / "config.json"

        if cfg_path.exists():
            cfg = ModelConfig.from_json(cfg_path)
        else:
            cfg = ModelConfig(label2idx=LABEL2IDX)

        model = cls(
            model_id=cfg.model_id,
            num_classes=cfg.num_classes,
            proj_dim=cfg.proj_dim,
            freeze_encoders=cfg.freeze_encoders,
            device=device,
        )

        state_path: Optional[Path] = load_dir / "pytorch_model.bin"
        if not state_path.exists():
            # min() makes the choice deterministic when several *.pt files exist.
            state_path = min(load_dir.glob("*.pt"), default=None)

        if state_path is None:
            warnings.warn(
                f"No weights found in {load_dir} (expected pytorch_model.bin or *.pt); "
                "returning a model with an untrained classification head.",
                stacklevel=2,
            )
        else:
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from trusted sources (consider weights_only=True on torch >= 1.13).
            state = torch.load(state_path, map_location=device)
            result = model.load_state_dict(state, strict=False)
            if result.missing_keys or result.unexpected_keys:
                warnings.warn(
                    f"Loading {state_path.name}: {len(result.missing_keys)} missing keys "
                    f"(e.g. {result.missing_keys[:3]}), {len(result.unexpected_keys)} "
                    f"unexpected keys (e.g. {result.unexpected_keys[:3]}).",
                    stacklevel=2,
                )

        model.eval()
        return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def predict_from_paths(
    model: BiomedClipClassifier,
    mri_path: str | Path,
    text: str,
    device: str = "cpu",
    use_val_preprocess: bool = True,
    target_shape: Tuple[int, int, int] = (128, 128, 128),
) -> Tuple[str, List[float]]:
    """
    Convenience function to run inference on one NIfTI + text string.

    Pipeline:
    1. Load the volume and z-score it, then min-max it to [0, 1].
    2. Center-crop or zero-pad it to ``target_shape``.
    3. Render the mid-slice RGB image with ``volume_to_triptych``.
    4. Apply the model's open_clip transform and tokenize ``text``.
    5. Run the classifier.

    Args:
        model: classifier to run. It is put in eval mode for the forward pass,
            and its previous train/eval mode is restored afterwards.
        mri_path: path to a NIfTI file readable by ``nibabel.load``. A trailing
            singleton 4th axis (X, Y, Z, 1) is dropped.
        text: free-text description fed to the text encoder.
        device: device the image/text inputs are moved to. It should match the
            device of the model's parameters.
        use_val_preprocess: use ``model.preprocess_val`` (default) instead of
            ``model.preprocess_train``.
        target_shape: shape the volume is cropped/padded to.

    Returns:
        (pred_label, class_probs): the label of the highest-probability class
        (via ``IDX2LABEL``) and the softmax probabilities in class-index order.
    """
    mri_path = Path(mri_path)

    # get_fdata(dtype=float32) avoids materializing a float64 copy first.
    vol = nib.load(str(mri_path)).get_fdata(dtype=np.float32)
    # Some files store a 3D scan as (X, Y, Z, 1); drop that singleton axis so
    # the 3D crop/pad below applies.
    if vol.ndim == 4 and vol.shape[-1] == 1:
        vol = vol[..., 0]
    # A single NaN/inf voxel would otherwise make every normalized value NaN.
    vol = np.nan_to_num(vol, nan=0.0, posinf=0.0, neginf=0.0)

    # z-score, then min-max to [0, 1]. The min-max step alone yields the same
    # result up to the epsilons; the z-score is kept as-is (presumably to
    # mirror the training preprocessing — confirm before removing).
    v = (vol - vol.mean()) / (vol.std() + 1e-8)
    v = (v - v.min()) / (v.max() - v.min() + 1e-8)
    v = center_crop_or_pad(v, target_shape)

    # [1, D, H, W] -> RGB PIL image -> CLIP-preprocessed batch of one.
    img_t = torch.from_numpy(v).unsqueeze(0)
    trip_pil = volume_to_triptych(img_t)
    preprocess = model.preprocess_val if use_val_preprocess else model.preprocess_train
    img_clip = preprocess(trip_pil).unsqueeze(0).to(device)

    txt_tok = model.tokenizer_fn([text]).to(device)

    # Run in eval mode (disables the head's dropout) but restore the caller's
    # mode afterwards, so calling this mid-training has no lasting side effect.
    was_training = model.training
    model.eval()
    try:
        logits = model(img_clip, txt_tok)
    finally:
        model.train(was_training)

    probs_t = torch.softmax(logits, dim=-1)[0]
    pred_idx = int(probs_t.argmax().item())
    return IDX2LABEL[pred_idx], probs_t.cpu().tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Example invocation: edit these three values before running.
    weights_dir = "./"
    nifti_path = "/path/to/sample_brain.nii.gz"
    text_input = "Patient shows mild memory impairment and hippocampal atrophy."

    # Fail fast on a missing scan, before the (slow) model construction/loading.
    if not Path(nifti_path).is_file():
        raise SystemExit(f"NIfTI file not found: {nifti_path}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = BiomedClipClassifier.from_pretrained(weights_dir, device=device)

    pred, probs = predict_from_paths(model, nifti_path, text_input, device=device)

    # Derive the column header from IDX2LABEL instead of hard-coding the class
    # order, so the printout cannot drift from the label mapping.
    class_names = ", ".join(IDX2LABEL.get(i, str(i)) for i in range(len(probs)))
    print("Prediction:", pred)
    print(f"Probabilities [{class_names}]:", [round(p, 4) for p in probs])
|
|
|
|