Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import List | |
| import torch | |
| import torch.nn as nn | |
| import torchxrayvision as xrv | |
| # --------------------------------------------------------------------------- | |
| # RAD-DINO wrapper | |
| # --------------------------------------------------------------------------- | |
class RadDinoWrapper(nn.Module):
    """microsoft/rad-dino — DINOv2 ViT-B/14 pretrained on ~1 M chest X-rays.

    Wraps the HuggingFace model to expose the same ``.features`` / ``.classifier``
    contract used by every other backbone, so freeze helpers and the two-stage
    optimiser work without modification.

    Architecture
    ------------
    .features   — the full Dinov2Model (embeddings + 12 transformer blocks + layernorm)
    .classifier — nn.Linear(hidden_size=768, out_features=1)

    Forward pass
    ------------
    x : (B, 3, H, W) float tensor — ImageNet-normalised, any multiple of 14 px.
        Recommended resolution: 518 × 518 (native: 37 × 37 patches at 14 px).
    Returns a (B, 1) logit tensor computed from the CLS token; ``cardio_logit``
    squeezes it to shape (B,).

    Freeze / unfreeze
    -----------------
    freeze_backbone()   — freezes all of .features (embeddings, every transformer
                          block and the final layernorm); only .classifier trains.
    partial_unfreeze(N) — for 0 < N < 12: unfreeze the last (12 − N) blocks +
                          layernorm; keep embeddings + first N blocks frozen.
    """

    def __init__(self) -> None:
        super().__init__()
        # Lazy import: transformers is only required when this backbone is used.
        from transformers import AutoModel

        dinov2 = AutoModel.from_pretrained("microsoft/rad-dino")
        self.features = dinov2
        # Single-logit head on the CLS embedding, initialised with small-std
        # truncated-normal weights and a zero bias.
        self.classifier = nn.Linear(dinov2.config.hidden_size, 1)
        nn.init.trunc_normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the (B, 1) Cardiomegaly logit for a batch of images ``x``."""
        out = self.features(pixel_values=x)      # HF model output object
        cls_token = out.last_hidden_state[:, 0]  # CLS token, (B, hidden_size)
        return self.classifier(cls_token)        # (B, 1)
| # --------------------------------------------------------------------------- | |
| # Backbone factory | |
| # --------------------------------------------------------------------------- | |
def build_model(backbone: str | None = None) -> nn.Module:
    """Build a backbone model for Cardiomegaly classification.

    backbone options (when ``backbone`` is None or empty, CFG.backbone is used):
      "densenet121"        — torchxrayvision DenseNet-121 with the
                             "densenet121-res224-all" weights (that name is also
                             accepted as an alias), pretrained on ~1M chest X-rays.
                             Outputs raw logits for every pathology; cardio_logit()
                             selects the Cardiomegaly column.
      "rad-dino"           — microsoft/rad-dino, DINOv2 ViT-B/14 pretrained on ~1M
                             chest X-rays (HuggingFace); 518×518 recommended input.
      "mobilenet_v3_large" — torchvision MobileNetV3-Large (ImageNet); final linear
                             replaced with a single-output head.
      "efficientnet_b0"    — torchvision EfficientNet-B0 (ImageNet); same replacement.
      "efficientnet_b3"    — torchvision EfficientNet-B3 (ImageNet); same replacement.

    All returned models expose .features and .classifier so that freeze_backbone()
    and the two-stage optimizer in train_one_seed() work unchanged.
    Input tensor format differs by backbone — use dataset.get_normalize_fn(backbone).

    Raises:
        ValueError: if the resolved backbone name is not one of the options above.
            The name is validated before torchvision is imported and before any
            weights are loaded.
    """
    if not backbone:
        from src.config import CFG  # lazy to avoid circular import at module load

        backbone = CFG.backbone

    densenet_names = ("densenet121", "densenet121-res224-all")
    torchvision_names = ("mobilenet_v3_large", "efficientnet_b0", "efficientnet_b3")
    # Fail fast on a bad name instead of importing torchvision first.
    if backbone not in (*densenet_names, "rad-dino", *torchvision_names):
        raise ValueError(
            f"Unknown backbone: {backbone!r}. "
            "Choose from: densenet121, rad-dino, mobilenet_v3_large, efficientnet_b0, efficientnet_b3"
        )

    if backbone in densenet_names:
        model = xrv.models.DenseNet(weights="densenet121-res224-all")
        model.op_threshs = None      # raw logits at every output
        model.apply_sigmoid = False  # belt + suspenders
        return model

    if backbone == "rad-dino":
        return RadDinoWrapper()

    import torchvision.models as tvm  # lazy — only needed for ImageNet backbones

    if backbone == "mobilenet_v3_large":
        model = tvm.mobilenet_v3_large(weights=tvm.MobileNet_V3_Large_Weights.IMAGENET1K_V2)
    elif backbone == "efficientnet_b0":
        model = tvm.efficientnet_b0(weights=tvm.EfficientNet_B0_Weights.IMAGENET1K_V1)
    else:  # "efficientnet_b3" — the only name left after validation
        model = tvm.efficientnet_b3(weights=tvm.EfficientNet_B3_Weights.IMAGENET1K_V1)

    # Replace the last layer of the classifier head with a single-logit linear layer.
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(in_features, 1)
    return model
def cardio_logit(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Return a (B,) tensor with the raw Cardiomegaly logit for every sample in ``x``.

    torchxrayvision DenseNet models emit one column per pathology, so the
    Cardiomegaly column is picked by its position in ``model.pathologies``.
    Every other backbone (MobileNet, EfficientNet, RadDinoWrapper) is expected
    to emit (B, 1), which is squeezed along dim 1.
    """
    if not isinstance(model, xrv.models.DenseNet):
        single = model(x)       # (B, 1)
        return single.squeeze(1)  # -> (B,)

    per_pathology = model(x)  # (B, num_pathologies)
    cardio_col = model.pathologies.index("Cardiomegaly")
    return per_pathology[:, cardio_col]
| # --------------------------------------------------------------------------- | |
| # Backbone management helpers | |
| # --------------------------------------------------------------------------- | |
def freeze_backbone(model: nn.Module) -> nn.Module:
    """Turn off gradients for ``model.features`` and turn them on for ``model.classifier``.

    The model is modified in place and returned so the call can be chained.
    """
    for attr_name, trainable in (("features", False), ("classifier", True)):
        for param in getattr(model, attr_name).parameters():
            param.requires_grad = trainable
    return model
def unfreeze_all(model: nn.Module) -> nn.Module:
    """Make every parameter of ``model`` trainable, in place, and return the model.

    Kept for backwards compatibility; prefer partial_unfreeze.
    """
    model.requires_grad_(True)
    return model
# DenseNet-121 freeze groups, one per dense block (1–4). Each pair names the
# children of ``model.features`` that partial_unfreeze() freezes together: the
# dense block and the module that follows it (transition1–3, then norm5).
_DENSENET_BLOCK_GROUPS = list(
    zip(
        ("denseblock1", "denseblock2", "denseblock3", "denseblock4"),
        ("transition1", "transition2", "transition3", "norm5"),
    )
)
def partial_unfreeze(model: nn.Module, frozen_blocks: int = 0) -> nn.Module:
    """Selectively unfreeze the model for stage-2 fine-tuning.

    Every parameter is first made trainable; then the earliest feature layers
    are frozen again according to ``frozen_blocks``. The model is modified in
    place and returned.

    frozen_blocks — how many feature blocks to keep frozen:
        0 or negative → unfreeze everything (same as unfreeze_all)

    DenseNet-121 (4 dense block groups); the stem (conv0/norm0, before
    denseblock1) is frozen whenever frozen_blocks ≥ 1:
        1  → keep stem + denseblock1 (+transition1) frozen
        2  → keep stem + denseblock1–2 frozen
        3  → keep stem + denseblock1–3 frozen
        ≥4 → keep the whole feature extractor frozen (only classifier trains)

    RAD-DINO / ViT-B (12 transformer blocks); the embeddings are frozen
    whenever frozen_blocks ≥ 1:
        1–11 → keep embeddings + first N transformer blocks frozen
               (last 12 − N blocks + final layernorm are unfrozen)
        ≥12  → keep embeddings, all blocks and the final layernorm frozen
               (only classifier trains)

    torchvision models (MobileNet, EfficientNet):
        N → freeze the first N children of model.features.
    """

    def _freeze(module: nn.Module) -> None:
        # Disable gradients for every parameter of ``module``.
        for p in module.parameters():
            p.requires_grad = False

    for p in model.parameters():
        p.requires_grad = True
    if frozen_blocks <= 0:
        return model

    if isinstance(model, xrv.models.DenseNet):
        # The stem precedes denseblock1; leaving it trainable while later
        # blocks are frozen would contradict "keep the first N blocks frozen".
        # Parameter-free entries (relu0, pool0) are harmless no-ops.
        frozen_names = {"conv0", "norm0", "relu0", "pool0"}
        for group in _DENSENET_BLOCK_GROUPS[:frozen_blocks]:
            frozen_names.update(group)
        for name, module in model.features.named_children():
            if name in frozen_names:
                _freeze(module)
    elif isinstance(model, RadDinoWrapper):
        dinov2 = model.features
        # Always freeze the patch/position embeddings.
        _freeze(dinov2.embeddings)
        # Freeze the first `frozen_blocks` transformer blocks.
        encoder_layers = dinov2.encoder.layer
        for block in encoder_layers[:frozen_blocks]:
            _freeze(block)
        if frozen_blocks >= len(encoder_layers):
            # With every block frozen, the final layernorm would otherwise be
            # the only trainable backbone part; freeze it so only the head trains.
            _freeze(dinov2.layernorm)
    else:
        for module in list(model.features.children())[:frozen_blocks]:
            _freeze(module)
    return model
def trainable_params(model: nn.Module) -> List[nn.Parameter]:
    """Return the parameters of ``model`` whose ``requires_grad`` flag is set.

    Order follows ``model.parameters()``; used when constructing the optimiser.
    """
    return list(filter(lambda param: param.requires_grad, model.parameters()))