cardio-deploy
Deploy CardioScan inference 2026-04-23T12:22:25Z
08a909f
from __future__ import annotations
from typing import List
import torch
import torch.nn as nn
import torchxrayvision as xrv
# ---------------------------------------------------------------------------
# RAD-DINO wrapper
# ---------------------------------------------------------------------------
class RadDinoWrapper(nn.Module):
    """Adapter around ``microsoft/rad-dino`` (DINOv2 ViT-B/14, chest X-ray pretrained).

    The HuggingFace model is stored under ``.features`` and a one-unit linear
    head under ``.classifier`` -- the same two attribute names the other
    backbones expose -- so the freeze helpers and the two-stage optimiser can
    treat this model like any other backbone.

    Layout
    ------
    .features   -- the complete Dinov2Model (embeddings, 12 transformer
                   blocks, final layernorm)
    .classifier -- nn.Linear(hidden_size, 1); hidden_size is 768 for ViT-B

    Forward pass
    ------------
    x : (B, 3, H, W) float tensor, ImageNet-normalised, with H and W multiples
        of 14 px.  Recommended resolution is 518 x 518 (37 x 37 patches of
        14 px each).
    Returns a (B, 1) logit tensor; ``cardio_logit`` squeezes it to (B,).

    Freezing
    --------
    freeze_backbone()   -> freezes everything under .features
    partial_unfreeze(N) -> keeps the embeddings and the first N transformer
                           blocks frozen
    """

    def __init__(self) -> None:
        """Load the pretrained encoder and attach a freshly initialised head."""
        super().__init__()
        # Imported here so `transformers` is only required when this backbone
        # is actually selected.
        from transformers import AutoModel

        encoder = AutoModel.from_pretrained("microsoft/rad-dino")

        # Single-logit head: small truncated-normal weights, zero bias.
        head = nn.Linear(encoder.config.hidden_size, 1)
        nn.init.trunc_normal_(head.weight, std=0.02)
        nn.init.zeros_(head.bias)

        # Register the encoder first so parameter order is features -> classifier.
        self.features = encoder
        self.classifier = head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and map the CLS token to one logit per image, shape (B, 1)."""
        hidden_states = self.features(pixel_values=x).last_hidden_state
        cls_token = hidden_states[:, 0]  # first token of every sequence
        return self.classifier(cls_token)
# ---------------------------------------------------------------------------
# Backbone factory
# ---------------------------------------------------------------------------
def build_model(backbone: str | None = None) -> nn.Module:
    """Build a backbone model for Cardiomegaly classification.

    Args:
        backbone: Name of the backbone to build.  When omitted (or falsy) the
            value of ``CFG.backbone`` from ``src.config`` is used.  Options:

            "densenet121" (alias "densenet121-res224-all")
                torchxrayvision DenseNet-121 pretrained on ~1M chest X-rays.
                Thresholding and sigmoid are switched off so the model emits
                raw logits; the Cardiomegaly logit is selected downstream by
                pathology index.
            "rad-dino"
                microsoft/rad-dino, DINOv2 ViT-B/14 pretrained on ~1M chest
                X-rays (HuggingFace); 518 x 518 input recommended.
            "mobilenet_v3_large"
                torchvision MobileNetV3-Large (ImageNet weights); the final
                linear layer is replaced with a single-output head.
            "efficientnet_b0" / "efficientnet_b3"
                torchvision EfficientNet-B0 / -B3 (ImageNet weights); same
                head replacement.

    Returns:
        A model exposing ``.features`` and ``.classifier`` so that
        ``freeze_backbone()`` and the two-stage optimiser work unchanged.
        The expected input normalisation differs per backbone -- use
        ``dataset.get_normalize_fn(backbone)``.

    Raises:
        ValueError: If ``backbone`` is not one of the names listed above.
    """
    if not backbone:
        # Lazy import avoids a circular import at module load, and is skipped
        # entirely when the caller names the backbone explicitly.
        from src.config import CFG

        backbone = CFG.backbone

    if backbone in ("densenet121", "densenet121-res224-all"):
        model = xrv.models.DenseNet(weights="densenet121-res224-all")
        model.op_threshs = None      # raw logits at every output
        model.apply_sigmoid = False  # belt + suspenders
        return model

    if backbone == "rad-dino":
        return RadDinoWrapper()

    # Reject unknown names before importing torchvision, so a typo produces
    # the intended ValueError rather than an unrelated import failure.
    if backbone not in ("mobilenet_v3_large", "efficientnet_b0", "efficientnet_b3"):
        raise ValueError(
            f"Unknown backbone: {backbone!r}. "
            "Choose from: densenet121, rad-dino, mobilenet_v3_large, efficientnet_b0, efficientnet_b3"
        )

    import torchvision.models as tvm

    if backbone == "mobilenet_v3_large":
        model = tvm.mobilenet_v3_large(weights=tvm.MobileNet_V3_Large_Weights.IMAGENET1K_V2)
    elif backbone == "efficientnet_b0":
        model = tvm.efficientnet_b0(weights=tvm.EfficientNet_B0_Weights.IMAGENET1K_V1)
    else:
        model = tvm.efficientnet_b3(weights=tvm.EfficientNet_B3_Weights.IMAGENET1K_V1)

    # Swap the ImageNet classification layer for a single-logit head.
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(in_features, 1)
    return model
def cardio_logit(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run ``model`` on ``x`` and return the raw Cardiomegaly logits, shape (B,).

    Single-output backbones (MobileNet, EfficientNet, RadDinoWrapper) produce
    a (B, 1) tensor whose trailing dimension is squeezed away.  The
    torchxrayvision DenseNet produces one column per pathology, so the
    "Cardiomegaly" column is selected instead.
    """
    if not isinstance(model, xrv.models.DenseNet):
        # (B, 1) -> (B,)
        return model(x).squeeze(1)

    all_logits = model(x)  # one column per entry in model.pathologies
    column = model.pathologies.index("Cardiomegaly")
    return all_logits[:, column]
# ---------------------------------------------------------------------------
# Backbone management helpers
# ---------------------------------------------------------------------------
def freeze_backbone(model: nn.Module) -> nn.Module:
    """Make ``model.features`` non-trainable and ``model.classifier`` trainable.

    Only the ``requires_grad`` flags are changed; the same model object is
    returned so calls can be chained.
    """
    model.features.requires_grad_(False)
    model.classifier.requires_grad_(True)
    return model
def unfreeze_all(model: nn.Module) -> nn.Module:
    """Mark every parameter of ``model`` as trainable and return the model.

    Retained for backwards compatibility; ``partial_unfreeze`` is the
    preferred entry point.
    """
    model.requires_grad_(True)
    return model
# DenseNet-121 freeze groups, in network order.  Entry i pairs dense block i+1
# with the child module that directly follows it inside ``model.features``:
# a transition layer for blocks 1-3 and the final ``norm5`` for block 4.
_DENSENET_BLOCK_GROUPS = list(
    zip(
        ("denseblock1", "denseblock2", "denseblock3", "denseblock4"),
        ("transition1", "transition2", "transition3", "norm5"),
    )
)
def partial_unfreeze(model: nn.Module, frozen_blocks: int = 0) -> nn.Module:
    """Selectively unfreeze the model for stage-2 fine-tuning.

    Every parameter is first made trainable; then the leading part of
    ``model.features`` is frozen again according to ``frozen_blocks``.

    Args:
        model: A backbone exposing ``.features`` (and ``.classifier``).
        frozen_blocks: How many feature blocks to keep frozen.

            <= 0 -> unfreeze everything (same as ``unfreeze_all``).

            DenseNet-121 (4 dense block groups); the stem in front of
            ``denseblock1`` is frozen together with the first group:
                1    -> stem + denseblock1 (+ transition1) frozen
                2    -> stem + denseblock1-2 frozen
                3    -> stem + denseblock1-3 frozen
                >= 4 -> all of ``.features`` frozen (only classifier trains)

            RAD-DINO / ViT-B (12 transformer blocks):
                1-11  -> embeddings + first N blocks frozen; the remaining
                         blocks and the final layernorm are trainable
                >= 12 -> embeddings, all blocks and the final layernorm
                         frozen (only classifier trains)

            torchvision models (MobileNet, EfficientNet):
                N -> the first N children of ``model.features`` are frozen.

    Returns:
        The same ``model`` object, with ``requires_grad`` flags updated.
    """
    # Start from a fully trainable model, then freeze the requested prefix.
    model.requires_grad_(True)
    if frozen_blocks <= 0:
        return model

    if isinstance(model, xrv.models.DenseNet):
        # Slicing clamps to the number of groups, so oversized values are fine.
        frozen_names = {
            name
            for group in _DENSENET_BLOCK_GROUPS[:frozen_blocks]
            for name in group
        }
        first_block = _DENSENET_BLOCK_GROUPS[0][0]
        in_stem = True
        for name, module in model.features.named_children():
            if name == first_block:
                in_stem = False
            # The stem (children ahead of the first dense block) feeds frozen
            # layers, so it is frozen as well.  Leaving it trainable would
            # contradict the "only classifier trains" contract and the way
            # the other backbones treat their leading layers.
            if in_stem or name in frozen_names:
                module.requires_grad_(False)
    elif isinstance(model, RadDinoWrapper):
        # Always freeze the patch/position embeddings.
        model.features.embeddings.requires_grad_(False)
        # Freeze the first `frozen_blocks` transformer blocks.
        encoder_layers = model.features.encoder.layer
        for block in encoder_layers[:frozen_blocks]:
            block.requires_grad_(False)
        # With every block frozen, the final layernorm is frozen too so that
        # only the classifier head remains trainable.
        if frozen_blocks >= len(encoder_layers):
            model.features.layernorm.requires_grad_(False)
    else:
        for module in list(model.features.children())[:frozen_blocks]:
            module.requires_grad_(False)
    return model
def trainable_params(model: nn.Module) -> List[nn.Parameter]:
    """Collect the parameters of ``model`` that currently require gradients.

    The result keeps the order of ``model.parameters()`` and is meant to be
    handed to an optimiser constructor.
    """
    return list(filter(lambda param: param.requires_grad, model.parameters()))