"""HyperClinical — src/demo_backend/modeling.py

Extract HF foundation embeddings and feed classifier.

Author: salmasoma
Date: original header garbled ("d242024") — presumably 2024; confirm with VCS history.
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
import logging
from .backbones import lightweight_backbones
from .constants import CLASS_NAMES
from .neurofusion.config import ModelConfig
from .neurofusion.neurofusion import NeuroFusionModel
logger = logging.getLogger(__name__)
def _checkpoint_prefers_lightweight(model_state_dict: Dict[str, torch.Tensor]) -> bool:
"""Infer whether checkpoint was trained with lightweight fallback backbones."""
if "visual_encoder.backbone.0.weight" in model_state_dict:
return True
visual_proj = model_state_dict.get("visual_encoder.projector.0.weight")
if hasattr(visual_proj, "shape") and len(visual_proj.shape) == 2:
# SigLIP/MedSigLIP models use much larger hidden dims (e.g., 768/1152+).
if int(visual_proj.shape[1]) <= 512:
return True
has_lm_backbone = any(k.startswith("clinical_encoder.lm_backbone.") for k in model_state_dict)
clinical_proj = model_state_dict.get("clinical_encoder.projector.0.weight")
if not has_lm_backbone and hasattr(clinical_proj, "shape") and len(clinical_proj.shape) == 2:
if int(clinical_proj.shape[1]) <= 1024:
return True
return False
def resolve_device(device: str = "auto") -> torch.device:
    """Resolve a device spec to a ``torch.device``.

    Any explicit spec (e.g. ``"cpu"``, ``"cuda:1"``) is passed straight to
    ``torch.device``; ``"auto"`` picks CUDA, then Apple MPS, then CPU.
    """
    if device != "auto":
        return torch.device(device)
    # Preference order when auto-selecting.
    for name, is_available in (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    ):
        if is_available():
            return torch.device(name)
    return torch.device("cpu")
def load_model_from_checkpoint(
    checkpoint_path: Path,
    device: torch.device,
    force_lightweight_backbones: bool = True,
):
    """Build a NeuroFusionModel and restore its weights from *checkpoint_path*.

    Returns ``(model, checkpoint_dict, model_config)`` with the model moved to
    *device* and put in eval mode. When ``force_lightweight_backbones`` is
    False but the checkpoint looks lightweight-trained, the lightweight
    backbones are used anyway (with a warning) so the strict load succeeds.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model_state = ckpt["model_state_dict"]

    # Start from default config, then overlay whatever model settings the
    # checkpoint recorded (unknown keys are ignored).
    model_cfg = ModelConfig()
    for key, value in ckpt.get("config", {}).get("model", {}).items():
        if hasattr(model_cfg, key):
            setattr(model_cfg, key, value)

    prefers_lightweight = _checkpoint_prefers_lightweight(model_state)
    use_lightweight = force_lightweight_backbones or prefers_lightweight
    if prefers_lightweight and not force_lightweight_backbones:
        logger.warning(
            "Checkpoint appears to use lightweight backbones; overriding foundation backbone mode for compatibility."
        )

    with lightweight_backbones(use_lightweight):
        model = NeuroFusionModel(model_cfg)
    try:
        model.load_state_dict(model_state, strict=True)
    except RuntimeError:
        # A strict mismatch with foundation backbones usually means the
        # weights were trained against the lightweight variants — retry once.
        if use_lightweight:
            raise
        logger.warning(
            "Strict checkpoint load failed with foundation backbones; retrying with lightweight backbones."
        )
        with lightweight_backbones(True):
            fallback = NeuroFusionModel(model_cfg)
            fallback.load_state_dict(model_state, strict=True)
        model = fallback

    model.to(device)
    model.eval()
    return model, ckpt, model_cfg
def _resize_embedding_to_dim(embedding: torch.Tensor, target_dim: int) -> torch.Tensor:
if embedding.ndim != 2:
raise ValueError(f"Expected 2D embedding tensor, got shape {tuple(embedding.shape)}")
if int(embedding.shape[1]) == int(target_dim):
return embedding
resized = F.interpolate(
embedding.unsqueeze(1),
size=int(target_dim),
mode="linear",
align_corners=False,
).squeeze(1)
return resized
@torch.no_grad()
def predict_single(
    model,
    mri: torch.Tensor,
    avra: torch.Tensor,
    clinical: torch.Tensor,
    narrative: str,
    siglib_embedding: Optional[torch.Tensor] = None,
    gemma_embedding: Optional[torch.Tensor] = None,
) -> Dict:
    """Run one forward pass and return the predicted class plus per-class probabilities.

    Precomputed SigLIP / Gemma embeddings take precedence over raw MRI slices
    and the clinical narrative; each embedding is resized to the width of the
    corresponding encoder's projector input when the dimensions differ.
    Assumes a batch of one (the first row of the logits is used).
    """
    forward_kwargs = {
        "avra_scores": avra,
        "clinical_features": clinical,
    }

    # Visual branch: prefer a precomputed embedding over raw MRI slices.
    if siglib_embedding is None:
        forward_kwargs["mri_slices"] = mri
    else:
        visual_width = int(model.visual_encoder.projector[0].in_features)
        forward_kwargs["siglib_embedding"] = _resize_embedding_to_dim(
            siglib_embedding.to(mri.device).float(), visual_width
        )

    # Clinical-text branch: precomputed embedding > zero embedding (when the
    # encoder has no LM backbone) > raw narrative text.
    if gemma_embedding is not None:
        clinical_width = int(model.clinical_encoder.projector[0].in_features)
        forward_kwargs["gemma_embedding"] = _resize_embedding_to_dim(
            gemma_embedding.to(mri.device).float(), clinical_width
        )
    elif getattr(model.clinical_encoder, "lm_backbone", None) is None:
        # No language model available: feed zeros of the encoder's embed width
        # so the fusion path still receives a clinical embedding.
        width = int(model.clinical_encoder.embed_dim)
        forward_kwargs["gemma_embedding"] = torch.zeros((mri.shape[0], width), device=mri.device)
    else:
        forward_kwargs["clinical_narratives"] = [narrative]

    outputs = model(**forward_kwargs)
    probs = F.softmax(outputs["logits"], dim=1)[0]
    pred_idx = int(torch.argmax(probs).item())
    return {
        "predicted_class_index": pred_idx,
        "predicted_class_name": CLASS_NAMES[pred_idx],
        "class_probabilities": {
            name: float(probs[i].item()) for i, name in enumerate(CLASS_NAMES)
        },
    }