cardio-deploy
Deploy CardioScan inference 2026-04-23T18:11:52Z
3f984f1
from __future__ import annotations
from typing import List
import torch
import torch.nn as nn
import torchxrayvision as xrv
# ---------------------------------------------------------------------------
# RAD-DINO wrapper
# ---------------------------------------------------------------------------
class RadDinoWrapper(nn.Module):
"""microsoft/rad-dino β€” DINOv2 ViT-B/14 pretrained on ~1 M chest X-rays.
Wraps the HuggingFace model to expose the same ``.features`` / ``.classifier``
contract used by every other backbone, so freeze helpers and the two-stage
optimiser work without modification.
Architecture
────────────
.features β€” the full Dinov2Model (embeddings + 12 transformer blocks + layernorm)
.classifier β€” MLP head on **[CLS βˆ₯ mean(patch tokens)]** (1536β†’256) β†’ GELU β†’
Dropout(0.3) β†’ Linear(256β†’1)
Forward pass
────────────
x : (B, 3, H, W) MIMIC-CXR-normalised tensor, any multiple of 14 px.
Recommended resolution: 518 Γ— 518 (native: 37 Γ— 37 patches at 14 px).
Pooling: CLS token concatenated with mean of patch tokens (excludes CLS).
Returns (B, 1) logit tensor; ``cardio_logit`` squeezes to (B,).
Freeze / unfreeze
─────────────────
freeze_backbone() β†’ freezes .features; sets _backbone_frozen=True so
.train() keeps the backbone in eval() mode.
partial_unfreeze(N) β†’ unfreeze last (12 βˆ’ N) blocks + layernorm;
embeddings + first N blocks stay frozen.
"""
def __init__(self) -> None:
super().__init__()
from transformers import AutoModel # lazy β€” only loaded when this backbone is used
dinov2 = AutoModel.from_pretrained("microsoft/rad-dino")
self.features = dinov2
hidden = dinov2.config.hidden_size # 768 for ViT-B
self._head_in = hidden * 2 # CLS + mean(patch tokens)
self.classifier = nn.Sequential(
nn.Linear(self._head_in, 256),
nn.GELU(),
nn.Dropout(0.3),
nn.Linear(256, 1),
)
for m in self.classifier.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.zeros_(m.bias)
self._backbone_frozen: bool = False
def train(self, mode: bool = True) -> "RadDinoWrapper":
super().train(mode)
# While the backbone is frozen keep it in eval() so its internal
# Dropout / LayerScale layers don't change during head warmup.
if mode and self._backbone_frozen:
self.features.eval()
return self
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.features(pixel_values=x) # Dinov2ModelOutput
h = out.last_hidden_state # (B, 1 + n_patches, 768)
cls = h[:, 0]
patch_mean = h[:, 1:].mean(dim=1)
z = torch.cat([cls, patch_mean], dim=-1)
return self.classifier(z) # (B, 1)
# ---------------------------------------------------------------------------
# Backbone factory
# ---------------------------------------------------------------------------
def build_model(backbone: str | None = None) -> nn.Module:
"""Build a backbone model for Cardiomegaly classification.
backbone options (also set via CFG.backbone):
"densenet121" β€” torchxrayvision DenseNet-121, pretrained on ~1M chest
X-rays; outputs raw Cardiomegaly logit via pathology index.
"rad-dino" β€” microsoft/rad-dino, DINOv2 ViT-B/14 pretrained on ~1M
chest X-rays (HuggingFace); 518Γ—518 recommended input.
"mobilenet_v3_large" β€” torchvision MobileNetV3-Large (ImageNet); final linear
replaced with a single-output head.
"efficientnet_b0" β€” torchvision EfficientNet-B0 (ImageNet); same replacement.
"efficientnet_b3" β€” torchvision EfficientNet-B3 (ImageNet); same replacement.
All returned models expose .features and .classifier so that freeze_backbone()
and the two-stage optimizer in train_one_seed() work unchanged.
Input tensor format differs by backbone β€” use dataset.get_normalize_fn(backbone).
"""
from src.config import CFG # lazy to avoid circular import at module load
backbone = backbone or CFG.backbone
if backbone in ("densenet121", "densenet121-res224-all"):
model = xrv.models.DenseNet(weights="densenet121-res224-all")
model.op_threshs = None # raw logits at every output
model.apply_sigmoid = False # belt + suspenders
return model
if backbone == "rad-dino":
return RadDinoWrapper()
import torchvision.models as tvm
if backbone == "mobilenet_v3_large":
model = tvm.mobilenet_v3_large(weights=tvm.MobileNet_V3_Large_Weights.IMAGENET1K_V2)
in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features, 1)
return model
if backbone in ("efficientnet_b0", "efficientnet_b3"):
if backbone == "efficientnet_b0":
model = tvm.efficientnet_b0(weights=tvm.EfficientNet_B0_Weights.IMAGENET1K_V1)
else:
model = tvm.efficientnet_b3(weights=tvm.EfficientNet_B3_Weights.IMAGENET1K_V1)
in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features, 1)
return model
raise ValueError(
f"Unknown backbone: {backbone!r}. "
"Choose from: densenet121, rad-dino, mobilenet_v3_large, efficientnet_b0, efficientnet_b3"
)
def cardio_logit(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
"""Forward pass returning a (B,) tensor of raw logits for Cardiomegaly.
For torchxrayvision DenseNet the logit is extracted from the pathology head.
For all other backbones (MobileNet, EfficientNet, RadDinoWrapper) the model
outputs (B, 1) which is squeezed to (B,).
"""
if isinstance(model, xrv.models.DenseNet):
out = model(x) # (B, num_pathologies)
idx = model.pathologies.index("Cardiomegaly")
return out[:, idx]
return model(x).squeeze(1) # (B, 1) β†’ (B,)
# ---------------------------------------------------------------------------
# Backbone management helpers
# ---------------------------------------------------------------------------
def freeze_backbone(model: nn.Module) -> nn.Module:
"""Freeze all params in .features; keep .classifier trainable."""
for p in model.features.parameters():
p.requires_grad = False
for p in model.classifier.parameters():
p.requires_grad = True
if isinstance(model, RadDinoWrapper):
model._backbone_frozen = True
model.features.eval() # prevent LayerScale/Dropout updates while frozen
return model
def unfreeze_all(model: nn.Module) -> nn.Module:
"""Unfreeze every parameter. Kept for backwards compatibility; prefer partial_unfreeze."""
for p in model.parameters():
p.requires_grad = True
return model
# DenseNet-121 block groups: (block_name, transition_name) for blocks 1–4
_DENSENET_BLOCK_GROUPS = [
("denseblock1", "transition1"),
("denseblock2", "transition2"),
("denseblock3", "transition3"),
("denseblock4", "norm5"),
]
def partial_unfreeze(model: nn.Module, frozen_blocks: int = 0) -> nn.Module:
"""Selectively unfreeze the model for stage-2 fine-tuning.
frozen_blocks β€” how many feature blocks to keep frozen:
0 β†’ unfreeze everything (same as unfreeze_all)
DenseNet-121 (4 dense block groups):
1 β†’ keep denseblock1 (+transition1) frozen
2 β†’ keep denseblock1–2 frozen
3 β†’ keep denseblock1–3 frozen
4 β†’ keep all dense blocks frozen (only classifier trains)
RAD-DINO / ViT-B (12 transformer blocks):
1–12 β†’ keep embeddings + first N transformer blocks frozen
(last 12βˆ’N blocks + layernorm are unfrozen)
β‰₯12 β†’ keep all transformer blocks frozen (only classifier trains)
torchvision models (MobileNet, EfficientNet):
N β†’ freeze first N indexed children of model.features.
"""
for p in model.parameters():
p.requires_grad = True
if frozen_blocks <= 0:
return model
if isinstance(model, xrv.models.DenseNet):
frozen_names: set[str] = set()
for i in range(min(frozen_blocks, len(_DENSENET_BLOCK_GROUPS))):
frozen_names.update(_DENSENET_BLOCK_GROUPS[i])
for name, module in model.features.named_children():
if name in frozen_names:
for p in module.parameters():
p.requires_grad = False
elif isinstance(model, RadDinoWrapper):
# Always freeze the patch/position embeddings
for p in model.features.embeddings.parameters():
p.requires_grad = False
# Freeze the first `frozen_blocks` transformer blocks
encoder_layers = model.features.encoder.layer
for block in encoder_layers[:frozen_blocks]:
for p in block.parameters():
p.requires_grad = False
# Some blocks are now trainable β€” allow backbone to go back into train()
model._backbone_frozen = False
else:
for module in list(model.features.children())[:frozen_blocks]:
for p in module.parameters():
p.requires_grad = False
return model
def trainable_params(model: nn.Module) -> List[nn.Parameter]:
"""List of parameters with `requires_grad=True` (for optimiser construction)."""
return [p for p in model.parameters() if p.requires_grad]