| from __future__ import annotations |
|
|
| from typing import List |
|
|
| import torch |
| import torch.nn as nn |
| import torchxrayvision as xrv |
|
|
|
|
| |
| |
| |
class RadDinoWrapper(nn.Module):
    """microsoft/rad-dino -- DINOv2 ViT-B/14 pretrained on ~1M chest X-rays.

    Wraps the HuggingFace model so it exposes the same ``.features`` /
    ``.classifier`` contract as every other backbone, which lets the freeze
    helpers and the two-stage optimiser run unmodified.

    Architecture
    ------------
    .features   -- the full Dinov2Model (embeddings + 12 transformer blocks
                   + final layernorm).
    .classifier -- MLP head over [CLS || mean(patch tokens)]:
                   Linear(1536 -> 256) -> GELU -> Dropout(0.3) ->
                   Linear(256 -> 1).

    Forward pass
    ------------
    x : (B, 3, H, W) MIMIC-CXR-normalised tensor, any multiple of 14 px.
    Recommended resolution: 518 x 518 (native: 37 x 37 patches at 14 px).
    Pooling: CLS token concatenated with the mean of the patch tokens (the
    mean excludes CLS). Returns a (B, 1) logit tensor; ``cardio_logit``
    squeezes it to (B,).

    Freeze / unfreeze
    -----------------
    freeze_backbone()   -- freezes .features and sets _backbone_frozen=True
                           so .train() keeps the backbone in eval() mode.
    partial_unfreeze(N) -- unfreezes the last (12 - N) blocks + layernorm;
                           embeddings and the first N blocks stay frozen.
    """

    def __init__(self) -> None:
        super().__init__()
        # Imported lazily so transformers is only needed for this backbone.
        from transformers import AutoModel
        backbone = AutoModel.from_pretrained("microsoft/rad-dino")
        self.features = backbone
        hidden_size = backbone.config.hidden_size
        # The head consumes [CLS || mean(patch tokens)], hence twice the width.
        self._head_in = hidden_size * 2
        self.classifier = nn.Sequential(
            nn.Linear(self._head_in, 256),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
        )
        # ViT-style init for the fresh head: truncated-normal weights, zero bias.
        for layer in self.classifier.modules():
            if isinstance(layer, nn.Linear):
                nn.init.trunc_normal_(layer.weight, std=0.02)
                nn.init.zeros_(layer.bias)
        self._backbone_frozen: bool = False

    def train(self, mode: bool = True) -> "RadDinoWrapper":
        """Switch train/eval mode, but pin a frozen backbone to eval()."""
        super().train(mode)
        # A frozen backbone must keep eval() behaviour (e.g. dropout off)
        # even while the classifier head trains.
        if mode and self._backbone_frozen:
            self.features.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (B, 1) raw logit for each image in the batch."""
        hidden = self.features(pixel_values=x).last_hidden_state
        cls_token = hidden[:, 0]
        patch_tokens = hidden[:, 1:]
        pooled = torch.cat((cls_token, patch_tokens.mean(dim=1)), dim=1)
        return self.classifier(pooled)
|
|
|
|
| |
| |
| |
def build_model(backbone: str | None = None) -> nn.Module:
    """Build a backbone model for Cardiomegaly classification.

    backbone options (also set via CFG.backbone):
      "densenet121"       -- torchxrayvision DenseNet-121, pretrained on ~1M
                             chest X-rays; outputs raw Cardiomegaly logit via
                             pathology index.
      "rad-dino"          -- microsoft/rad-dino, DINOv2 ViT-B/14 pretrained on
                             ~1M chest X-rays (HuggingFace); 518x518
                             recommended input.
      "mobilenet_v3_large" -- torchvision MobileNetV3-Large (ImageNet); final
                             linear replaced with a single-output head.
      "efficientnet_b0"   -- torchvision EfficientNet-B0 (ImageNet); same
                             replacement.
      "efficientnet_b3"   -- torchvision EfficientNet-B3 (ImageNet); same
                             replacement.

    All returned models expose .features and .classifier so freeze_backbone()
    and the two-stage optimizer in train_one_seed() work unchanged. Input
    tensor format differs by backbone -- use dataset.get_normalize_fn(backbone).
    """
    from src.config import CFG
    chosen = backbone or CFG.backbone

    if chosen in ("densenet121", "densenet121-res224-all"):
        densenet = xrv.models.DenseNet(weights="densenet121-res224-all")
        # Expose raw logits: disable the library's calibration and sigmoid.
        densenet.op_threshs = None
        densenet.apply_sigmoid = False
        return densenet

    if chosen == "rad-dino":
        return RadDinoWrapper()

    # torchvision is only needed for the remaining ImageNet backbones.
    import torchvision.models as tvm

    def _with_single_logit(model: nn.Module) -> nn.Module:
        # Swap the final linear layer for a single-output Cardiomegaly head.
        last = model.classifier[-1]
        model.classifier[-1] = nn.Linear(last.in_features, 1)
        return model

    if chosen == "mobilenet_v3_large":
        return _with_single_logit(
            tvm.mobilenet_v3_large(weights=tvm.MobileNet_V3_Large_Weights.IMAGENET1K_V2)
        )

    if chosen == "efficientnet_b0":
        return _with_single_logit(
            tvm.efficientnet_b0(weights=tvm.EfficientNet_B0_Weights.IMAGENET1K_V1)
        )

    if chosen == "efficientnet_b3":
        return _with_single_logit(
            tvm.efficientnet_b3(weights=tvm.EfficientNet_B3_Weights.IMAGENET1K_V1)
        )

    raise ValueError(
        f"Unknown backbone: {chosen!r}. "
        "Choose from: densenet121, rad-dino, mobilenet_v3_large, efficientnet_b0, efficientnet_b3"
    )
|
|
|
|
def cardio_logit(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Forward pass returning a (B,) tensor of raw logits for Cardiomegaly.

    torchxrayvision DenseNets emit one logit per pathology, so the
    Cardiomegaly column is selected by index. Every other backbone
    (MobileNet, EfficientNet, RadDinoWrapper) emits (B, 1), which is
    squeezed to (B,).
    """
    if not isinstance(model, xrv.models.DenseNet):
        return model(x).squeeze(1)
    logits = model(x)
    cardio_idx = model.pathologies.index("Cardiomegaly")
    return logits[:, cardio_idx]
|
|
|
|
| |
| |
| |
def freeze_backbone(model: nn.Module) -> nn.Module:
    """Freeze all params in .features; keep .classifier trainable."""
    # Table-driven: (parameter iterator, target requires_grad flag).
    for params, trainable in (
        (model.features.parameters(), False),
        (model.classifier.parameters(), True),
    ):
        for param in params:
            param.requires_grad = trainable
    if isinstance(model, RadDinoWrapper):
        # Record the frozen state so .train() keeps the backbone in eval().
        model._backbone_frozen = True
        model.features.eval()
    return model
|
|
|
|
def unfreeze_all(model: nn.Module) -> nn.Module:
    """Make every parameter trainable.

    Kept for backwards compatibility; prefer partial_unfreeze.
    """
    for param in model.parameters():
        param.requires_grad = True
    return model
|
|
|
|
| |
# DenseNet-121 feature sub-modules grouped per dense block: each tuple pairs a
# dense block with the module that follows it in ``model.features`` (its
# transition layer, or ``norm5`` after the final block). ``partial_unfreeze``
# freezes these groups in order, first to last.
_DENSENET_BLOCK_GROUPS = [
    ("denseblock1", "transition1"),
    ("denseblock2", "transition2"),
    ("denseblock3", "transition3"),
    ("denseblock4", "norm5"),
]
|
|
|
|
def partial_unfreeze(model: nn.Module, frozen_blocks: int = 0) -> nn.Module:
    """Selectively unfreeze the model for stage-2 fine-tuning.

    frozen_blocks -- how many feature blocks to keep frozen:
        0 -> unfreeze everything (same as unfreeze_all)

    DenseNet-121 (4 dense block groups):
        1 -> keep denseblock1 (+transition1) frozen
        2 -> keep denseblock1-2 frozen
        3 -> keep denseblock1-3 frozen
        4 -> keep all dense blocks frozen (only classifier trains)

    RAD-DINO / ViT-B (12 transformer blocks):
        1-12 -> keep embeddings + first N transformer blocks frozen
                (last 12-N blocks + layernorm are unfrozen)
        >=12 -> keep all transformer blocks frozen (only classifier trains)

    torchvision models (MobileNet, EfficientNet):
        N -> freeze first N indexed children of model.features.
    """
    # Start from a fully trainable model, then re-freeze the requested prefix.
    for param in model.parameters():
        param.requires_grad = True

    if frozen_blocks <= 0:
        return model

    def _freeze(module: nn.Module) -> None:
        # Freeze every parameter of one sub-module.
        for param in module.parameters():
            param.requires_grad = False

    if isinstance(model, xrv.models.DenseNet):
        frozen_names: set[str] = {
            name
            for group in _DENSENET_BLOCK_GROUPS[:frozen_blocks]
            for name in group
        }
        for name, child in model.features.named_children():
            if name in frozen_names:
                _freeze(child)
    elif isinstance(model, RadDinoWrapper):
        # Patch/positional embeddings always stay frozen with the prefix.
        _freeze(model.features.embeddings)
        for block in model.features.encoder.layer[:frozen_blocks]:
            _freeze(block)
        # Part of the backbone now trains, so .train() must stop forcing eval().
        model._backbone_frozen = False
    else:
        children = list(model.features.children())
        for child in children[:frozen_blocks]:
            _freeze(child)

    return model
|
|
|
|
def trainable_params(model: nn.Module) -> List[nn.Parameter]:
    """Collect the parameters with `requires_grad=True` (for optimiser construction)."""
    return list(filter(lambda param: param.requires_grad, model.parameters()))
|
|