Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Dict, Optional, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| import logging | |
| from .backbones import lightweight_backbones | |
| from .constants import CLASS_NAMES | |
| from .neurofusion.config import ModelConfig | |
| from .neurofusion.neurofusion import NeuroFusionModel | |
| logger = logging.getLogger(__name__) | |
| def _checkpoint_prefers_lightweight(model_state_dict: Dict[str, torch.Tensor]) -> bool: | |
| """Infer whether checkpoint was trained with lightweight fallback backbones.""" | |
| if "visual_encoder.backbone.0.weight" in model_state_dict: | |
| return True | |
| visual_proj = model_state_dict.get("visual_encoder.projector.0.weight") | |
| if hasattr(visual_proj, "shape") and len(visual_proj.shape) == 2: | |
| # SigLIP/MedSigLIP models use much larger hidden dims (e.g., 768/1152+). | |
| if int(visual_proj.shape[1]) <= 512: | |
| return True | |
| has_lm_backbone = any(k.startswith("clinical_encoder.lm_backbone.") for k in model_state_dict) | |
| clinical_proj = model_state_dict.get("clinical_encoder.projector.0.weight") | |
| if not has_lm_backbone and hasattr(clinical_proj, "shape") and len(clinical_proj.shape) == 2: | |
| if int(clinical_proj.shape[1]) <= 1024: | |
| return True | |
| return False | |
def resolve_device(device: str = "auto") -> torch.device:
    """Turn a device string into a ``torch.device``.

    An explicit spec (e.g. ``"cpu"``, ``"cuda:1"``) is passed straight through;
    ``"auto"`` picks the best available backend: CUDA, then MPS, then CPU.
    """
    if device == "auto":
        if torch.cuda.is_available():
            chosen = "cuda"
        elif torch.backends.mps.is_available():
            chosen = "mps"
        else:
            chosen = "cpu"
        return torch.device(chosen)
    return torch.device(device)
def load_model_from_checkpoint(
    checkpoint_path: Path,
    device: torch.device,
    force_lightweight_backbones: bool = True,
):
    """Build a :class:`NeuroFusionModel` and load weights from *checkpoint_path*.

    Returns a ``(model, checkpoint, model_config)`` tuple; the model is moved to
    *device* and switched to eval mode. If a strict load against foundation
    backbones fails, the load is retried once with lightweight backbones.
    """
    # NOTE(review): weights_only=False deserializes arbitrary pickle data —
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint["model_state_dict"]

    # Overlay any model options recorded in the checkpoint onto a default config.
    config = ModelConfig()
    for name, value in checkpoint.get("config", {}).get("model", {}).items():
        if hasattr(config, name):
            setattr(config, name, value)

    prefers_lightweight = _checkpoint_prefers_lightweight(state_dict)
    use_lightweight = force_lightweight_backbones or prefers_lightweight
    if prefers_lightweight and not force_lightweight_backbones:
        logger.warning(
            "Checkpoint appears to use lightweight backbones; overriding foundation backbone mode for compatibility."
        )

    with lightweight_backbones(use_lightweight):
        model = NeuroFusionModel(config)

    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Already on lightweight backbones — there is no further fallback.
        if use_lightweight:
            raise
        logger.warning(
            "Strict checkpoint load failed with foundation backbones; retrying with lightweight backbones."
        )
        with lightweight_backbones(True):
            fallback = NeuroFusionModel(config)
            fallback.load_state_dict(state_dict, strict=True)
        model = fallback

    model.to(device)
    model.eval()
    return model, checkpoint, config
| def _resize_embedding_to_dim(embedding: torch.Tensor, target_dim: int) -> torch.Tensor: | |
| if embedding.ndim != 2: | |
| raise ValueError(f"Expected 2D embedding tensor, got shape {tuple(embedding.shape)}") | |
| if int(embedding.shape[1]) == int(target_dim): | |
| return embedding | |
| resized = F.interpolate( | |
| embedding.unsqueeze(1), | |
| size=int(target_dim), | |
| mode="linear", | |
| align_corners=False, | |
| ).squeeze(1) | |
| return resized | |
def predict_single(
    model,
    mri: torch.Tensor,
    avra: torch.Tensor,
    clinical: torch.Tensor,
    narrative: str,
    siglib_embedding: Optional[torch.Tensor] = None,
    gemma_embedding: Optional[torch.Tensor] = None,
) -> Dict:
    """Run one forward pass and return the predicted class with per-class probabilities."""
    inputs = {
        "avra_scores": avra,
        "clinical_features": clinical,
    }

    # Visual branch: prefer a precomputed SigLIP embedding over raw MRI slices,
    # resized to the projector's expected input width.
    if siglib_embedding is None:
        inputs["mri_slices"] = mri
    else:
        visual_width = int(model.visual_encoder.projector[0].in_features)
        inputs["siglib_embedding"] = _resize_embedding_to_dim(
            siglib_embedding.to(mri.device).float(), visual_width
        )

    # Clinical branch: precomputed Gemma embedding if given; otherwise a zero
    # stand-in when the model carries no LM backbone, else the narrative text.
    if gemma_embedding is not None:
        clinical_width = int(model.clinical_encoder.projector[0].in_features)
        inputs["gemma_embedding"] = _resize_embedding_to_dim(
            gemma_embedding.to(mri.device).float(), clinical_width
        )
    elif getattr(model.clinical_encoder, "lm_backbone", None) is None:
        embed_width = int(model.clinical_encoder.embed_dim)
        inputs["gemma_embedding"] = torch.zeros(
            (mri.shape[0], embed_width), device=mri.device
        )
    else:
        inputs["clinical_narratives"] = [narrative]

    outputs = model(**inputs)
    probabilities = F.softmax(outputs["logits"], dim=1)[0]
    winner = int(torch.argmax(probabilities).item())
    class_probs = {
        name: float(probabilities[i].item()) for i, name in enumerate(CLASS_NAMES)
    }
    return {
        "predicted_class_index": winner,
        "predicted_class_name": CLASS_NAMES[winner],
        "class_probabilities": class_probs,
    }