File size: 9,913 Bytes
3f984f1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 | from __future__ import annotations
from typing import List
import torch
import torch.nn as nn
import torchxrayvision as xrv
# ---------------------------------------------------------------------------
# RAD-DINO wrapper
# ---------------------------------------------------------------------------
class RadDinoWrapper(nn.Module):
    """microsoft/rad-dino — DINOv2 ViT-B/14 pretrained on ~1 M chest X-rays.

    Wraps the HuggingFace model so it exposes the same ``.features`` /
    ``.classifier`` contract as every other backbone; the freeze helpers and
    the two-stage optimiser therefore work without modification.

    Architecture
    ────────────
    .features   – the full Dinov2Model (embeddings + 12 transformer blocks + layernorm)
    .classifier – MLP head on **[CLS ‖ mean(patch tokens)]**:
                  Linear(1536→256) → GELU → Dropout(0.3) → Linear(256→1)

    Forward pass
    ────────────
    x : (B, 3, H, W) MIMIC-CXR-normalised tensor, any multiple of 14 px.
        Recommended resolution: 518 × 518 (native: 37 × 37 patches at 14 px).
    Pooling: CLS token concatenated with the mean of the patch tokens (CLS excluded).
    Returns a (B, 1) logit tensor; ``cardio_logit`` squeezes it to (B,).

    Freeze / unfreeze
    ─────────────────
    freeze_backbone()   – freezes .features and sets _backbone_frozen=True so
                          .train() keeps the backbone in eval() mode.
    partial_unfreeze(N) – unfreezes the last (12 − N) blocks + layernorm;
                          embeddings + first N blocks stay frozen.
    """

    def __init__(self) -> None:
        super().__init__()
        # Deferred import: transformers is only needed when this backbone is chosen.
        from transformers import AutoModel

        backbone = AutoModel.from_pretrained("microsoft/rad-dino")
        self.features = backbone

        width = backbone.config.hidden_size  # 768 for ViT-B
        self._head_in = 2 * width  # CLS token ‖ mean of patch tokens
        self.classifier = nn.Sequential(
            nn.Linear(self._head_in, 256),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
        )
        # Fresh head: small truncated-normal weights, zero biases.
        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                nn.init.trunc_normal_(layer.weight, std=0.02)
                nn.init.zeros_(layer.bias)

        self._backbone_frozen: bool = False

    def train(self, mode: bool = True) -> "RadDinoWrapper":
        """Set train/eval mode, pinning a frozen backbone to eval().

        While ``_backbone_frozen`` is set, the backbone's mode-dependent layers
        (e.g. dropout) stay inactive during head warm-up even though the
        classifier is in training mode.
        """
        super().train(mode)
        pin_backbone = mode and self._backbone_frozen
        if pin_backbone:
            self.features.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return (B, 1) Cardiomegaly logits for the image batch *x*."""
        tokens = self.features(pixel_values=x).last_hidden_state  # (B, 1 + n_patches, hidden)
        cls_token, patch_tokens = tokens[:, 0], tokens[:, 1:]
        pooled = torch.cat((cls_token, patch_tokens.mean(dim=1)), dim=-1)
        return self.classifier(pooled)
# ---------------------------------------------------------------------------
# Backbone factory
# ---------------------------------------------------------------------------
def build_model(backbone: str | None = None) -> nn.Module:
    """Build a backbone model for Cardiomegaly classification.

    backbone options (falls back to CFG.backbone when not given):
      "densenet121"        – torchxrayvision DenseNet-121 ("densenet121-res224-all"
                             weights, also accepted as an alias), pretrained on ~1M
                             chest X-rays; outputs raw logits, the Cardiomegaly one
                             selected via pathology index.
      "rad-dino"           – microsoft/rad-dino, DINOv2 ViT-B/14 pretrained on ~1M
                             chest X-rays (HuggingFace); 518×518 recommended input.
      "mobilenet_v3_large" – torchvision MobileNetV3-Large (ImageNet); final linear
                             replaced with a single-output head.
      "efficientnet_b0"    – torchvision EfficientNet-B0 (ImageNet); same replacement.
      "efficientnet_b3"    – torchvision EfficientNet-B3 (ImageNet); same replacement.

    All returned models expose .features and .classifier so that freeze_backbone()
    and the two-stage optimizer in train_one_seed() work unchanged.
    Input tensor format differs by backbone — use dataset.get_normalize_fn(backbone).

    Raises:
        ValueError: if the backbone name is not one of the options above.
    """
    from src.config import CFG  # deferred to avoid a circular import at module load

    name = backbone or CFG.backbone

    if name in ("densenet121", "densenet121-res224-all"):
        xrv_net = xrv.models.DenseNet(weights="densenet121-res224-all")
        # Disable both output transforms so every column is a raw logit.
        xrv_net.op_threshs = None
        xrv_net.apply_sigmoid = False
        return xrv_net

    if name == "rad-dino":
        return RadDinoWrapper()

    import torchvision.models as tvm

    # ImageNet-pretrained torchvision constructors, called lazily so only the
    # requested model is instantiated.
    constructors = {
        "mobilenet_v3_large": lambda: tvm.mobilenet_v3_large(
            weights=tvm.MobileNet_V3_Large_Weights.IMAGENET1K_V2
        ),
        "efficientnet_b0": lambda: tvm.efficientnet_b0(
            weights=tvm.EfficientNet_B0_Weights.IMAGENET1K_V1
        ),
        "efficientnet_b3": lambda: tvm.efficientnet_b3(
            weights=tvm.EfficientNet_B3_Weights.IMAGENET1K_V1
        ),
    }
    if name not in constructors:
        raise ValueError(
            f"Unknown backbone: {name!r}. "
            "Choose from: densenet121, rad-dino, mobilenet_v3_large, efficientnet_b0, efficientnet_b3"
        )

    tv_net = constructors[name]()
    # Swap the final Linear for a single-logit head with the same input width.
    head_in = tv_net.classifier[-1].in_features
    tv_net.classifier[-1] = nn.Linear(head_in, 1)
    return tv_net
def cardio_logit(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run *model* on *x* and return a (B,) tensor of raw Cardiomegaly logits.

    torchxrayvision DenseNet emits one logit per pathology, so the Cardiomegaly
    column is selected by name from ``model.pathologies``. Every other backbone
    (MobileNet, EfficientNet, RadDinoWrapper) emits a (B, 1) tensor, which is
    squeezed to (B,).
    """
    logits = model(x)
    if not isinstance(model, xrv.models.DenseNet):
        return logits.squeeze(1)  # (B, 1) → (B,)
    # (B, num_pathologies): keep only the Cardiomegaly column.
    column = model.pathologies.index("Cardiomegaly")
    return logits[:, column]
# ---------------------------------------------------------------------------
# Backbone management helpers
# ---------------------------------------------------------------------------
def freeze_backbone(model: nn.Module) -> nn.Module:
    """Freeze every parameter in ``model.features``; keep ``model.classifier`` trainable.

    For a RadDinoWrapper the backbone is additionally switched to eval() and
    flagged via ``_backbone_frozen`` so later ``model.train()`` calls keep it
    in eval() mode.

    NOTE: for the other backbones only ``requires_grad`` changes; mode-dependent
    layers inside ``.features`` (e.g. BatchNorm running statistics) still update
    whenever the model is in train() mode.

    Returns the same *model*, for chaining.
    """
    for param in model.features.parameters():
        param.requires_grad_(False)
    for param in model.classifier.parameters():
        param.requires_grad_(True)

    if not isinstance(model, RadDinoWrapper):
        return model
    model._backbone_frozen = True
    model.features.eval()  # stop dropout etc. in the backbone while it is frozen
    return model
def unfreeze_all(model: nn.Module) -> nn.Module:
    """Make every parameter of *model* trainable.

    Kept for backwards compatibility; prefer ``partial_unfreeze``.

    Also clears the ``_backbone_frozen`` flag that ``freeze_backbone`` sets on
    RAD-DINO wrappers. Otherwise ``RadDinoWrapper.train()`` would keep forcing the
    now-trainable backbone into eval() mode, silently disabling its dropout
    during fine-tuning. The backbone's current mode is not touched here; the
    next ``model.train()`` call restores it.

    Returns the same *model*, for chaining.
    """
    for p in model.parameters():
        p.requires_grad = True
    # Duck-typed so models without the flag are left untouched.
    if hasattr(model, "_backbone_frozen"):
        model._backbone_frozen = False
    return model
# DenseNet-121 block groups: (block_name, transition_name) for blocks 1β4
_DENSENET_BLOCK_GROUPS = [
("denseblock1", "transition1"),
("denseblock2", "transition2"),
("denseblock3", "transition3"),
("denseblock4", "norm5"),
]
def _set_requires_grad(module: nn.Module, flag: bool) -> None:
    """Set ``requires_grad`` to *flag* on every parameter of *module*."""
    for p in module.parameters():
        p.requires_grad = flag


def _freeze_densenet_prefix(model: nn.Module, frozen_blocks: int) -> None:
    """Freeze the DenseNet stem plus its first *frozen_blocks* block groups."""
    # torchxrayvision's DenseNet names its stem conv0/norm0/relu0/pool0 (they run
    # before denseblock1). The stem must be frozen whenever any later block is;
    # otherwise e.g. frozen_blocks=4 would still train the stem instead of leaving
    # only the classifier trainable. Names absent from the model are skipped.
    frozen_names = {"conv0", "norm0", "relu0", "pool0"}
    for group in _DENSENET_BLOCK_GROUPS[:frozen_blocks]:
        frozen_names.update(group)
    for name, module in model.features.named_children():
        if name in frozen_names:
            _set_requires_grad(module, False)


def _freeze_rad_dino_prefix(model: "RadDinoWrapper", frozen_blocks: int) -> None:
    """Freeze RAD-DINO's embeddings plus its first *frozen_blocks* encoder blocks."""
    encoder_layers = model.features.encoder.layer
    if frozen_blocks >= len(encoder_layers):
        # Every transformer block stays frozen: also freeze the final layernorm so
        # only the classifier trains, and pin the backbone to eval() the same way
        # freeze_backbone() does.
        _set_requires_grad(model.features, False)
        model._backbone_frozen = True
        model.features.eval()
        return
    # Patch/position embeddings are always kept frozen.
    _set_requires_grad(model.features.embeddings, False)
    for block in encoder_layers[:frozen_blocks]:
        _set_requires_grad(block, False)
    # Some blocks are trainable, so .train() may put the backbone back in train mode.
    model._backbone_frozen = False


def partial_unfreeze(model: nn.Module, frozen_blocks: int = 0) -> nn.Module:
    """Selectively unfreeze the model for stage-2 fine-tuning.

    Every parameter is first made trainable; then the earliest feature blocks
    are frozen again.

    frozen_blocks — how many feature blocks to keep frozen:
        0 (or negative) → unfreeze everything (same as unfreeze_all)

        DenseNet-121 (4 dense block groups); the stem (conv0/norm0) is frozen
        whenever frozen_blocks >= 1:
            1  → keep stem + denseblock1 (+transition1) frozen
            2  → keep stem + denseblock1–2 frozen
            3  → keep stem + denseblock1–3 frozen
            ≥4 → keep the whole feature extractor frozen (only classifier trains)

        RAD-DINO / ViT-B (12 transformer blocks):
            1–11 → keep embeddings + first N transformer blocks frozen
                   (last 12−N blocks + layernorm are trainable)
            ≥12  → keep the whole backbone frozen (only classifier trains) and
                   pinned to eval(), as after freeze_backbone()

        torchvision models (MobileNet, EfficientNet):
            N → freeze the first N indexed children of model.features.

    For RadDinoWrapper the ``_backbone_frozen`` flag is kept in sync, so
    ``model.train()`` only pins the backbone to eval() while it is fully frozen.

    Returns the same *model*, for chaining.
    """
    _set_requires_grad(model, True)
    if isinstance(model, RadDinoWrapper):
        # Clear the flag before the early return below: with everything trainable,
        # RadDinoWrapper.train() must stop forcing the backbone into eval().
        model._backbone_frozen = False
    if frozen_blocks <= 0:
        return model

    if isinstance(model, xrv.models.DenseNet):
        _freeze_densenet_prefix(model, frozen_blocks)
    elif isinstance(model, RadDinoWrapper):
        _freeze_rad_dino_prefix(model, frozen_blocks)
    else:
        # torchvision models: freeze the first N indexed children of .features.
        for module in list(model.features.children())[:frozen_blocks]:
            _set_requires_grad(module, False)
    return model
def trainable_params(model: nn.Module) -> List[nn.Parameter]:
    """Return the parameters of *model* that currently require gradients.

    Meant for building an optimiser over only the unfrozen weights; the order
    follows ``model.parameters()``.
    """
    return list(filter(lambda param: param.requires_grad, model.parameters()))
|