import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import timm
except ImportError as e:
    raise ImportError(
        "timm is required for models/vit.py. Install with: pip install timm"
    ) from e


class ViTSegmentationModel(nn.Module):
    """
    Simple ViT segmentation model using a timm Vision Transformer backbone.

    The model:
        image -> ViT patch tokens -> reshape to feature map -> conv head -> upsample

    Output:
        logits of shape [B, num_classes, H, W]

    For binary vessel segmentation: num_classes = 1
    For multi-class lesion segmentation: num_classes = number of lesion/background classes
    """

    def __init__(
        self,
        model_name="vit_base_patch16_224",
        num_classes=1,
        pretrained=True,
        in_chans=3,
        img_size=512,
        decoder_dim=256,
        dropout=0.0,
    ):
        """
        Parameters
        ----------
        model_name:
            Any timm ViT model name, e.g. "vit_base_patch16_224".
        num_classes:
            Number of output segmentation channels.
        pretrained:
            Load ImageNet-pretrained backbone weights.
        in_chans:
            Number of input image channels.
        img_size:
            Square input resolution the backbone is configured for.
        decoder_dim:
            Hidden width of the convolutional decoder head.
        dropout:
            Dropout2d probability applied inside the decoder.
        """
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.img_size = img_size

        # num_classes=0 + global_pool="" makes timm return raw token features
        # instead of a classification head output.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool="",
            in_chans=in_chans,
            img_size=img_size,
        )

        self.embed_dim = self.backbone.num_features

        # timm usually stores patch_size as a (ph, pw) tuple; we assume square
        # patches and keep a single int for grid-size arithmetic.
        self.patch_size = self.backbone.patch_embed.patch_size
        if isinstance(self.patch_size, tuple):
            self.patch_size = self.patch_size[0]

        # Lightweight conv decoder: project embed_dim -> decoder_dim, refine
        # with a 3x3 conv, then map to per-class logits.
        self.decoder = nn.Sequential(
            nn.Conv2d(self.embed_dim, decoder_dim, kernel_size=1),
            nn.BatchNorm2d(decoder_dim),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(decoder_dim, decoder_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(decoder_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(decoder_dim, num_classes, kernel_size=1),
        )

    def forward_features_as_map(self, x):
        """
        Convert ViT patch tokens into a spatial feature map.

        Input:
            x: [B, C, H, W]
        Output:
            feature_map: [B, embed_dim, H // patch_size, W // patch_size]

        Raises
        ------
        RuntimeError
            If the backbone returns fewer tokens than the expected patch grid,
            or features of an unexpected rank.
        """
        b, _, h, w = x.shape
        tokens = self.backbone.forward_features(x)

        # Some timm models return a tuple/list. Usually the first item is token features.
        if isinstance(tokens, (tuple, list)):
            tokens = tokens[0]

        if tokens.ndim == 3:
            # Standard ViT: tokens are [B, num_prefix + num_patches, C].
            # num_prefix is 1 for a plain CLS token, but can be larger for
            # distilled (DeiT) or register-token variants — strip all of them.
            expected_num_patches = (h // self.patch_size) * (w // self.patch_size)
            num_prefix_tokens = tokens.shape[1] - expected_num_patches
            if num_prefix_tokens > 0:
                tokens = tokens[:, num_prefix_tokens:, :]
            elif num_prefix_tokens < 0:
                raise RuntimeError(
                    f"Backbone returned {tokens.shape[1]} tokens but "
                    f"{expected_num_patches} patches were expected for input "
                    f"{h}x{w} with patch size {self.patch_size}."
                )

            feature_h = h // self.patch_size
            feature_w = w // self.patch_size
            # [B, N, C] -> [B, C, N] -> [B, C, Hf, Wf]
            tokens = tokens.transpose(1, 2)
            feature_map = tokens.reshape(b, self.embed_dim, feature_h, feature_w)
        # Some backbones may already return [B, C, H, W].
        elif tokens.ndim == 4:
            feature_map = tokens
        else:
            raise RuntimeError(f"Unexpected ViT feature shape: {tokens.shape}")

        return feature_map

    def forward(self, x):
        """Run the full model: backbone -> decoder -> upsample to input size."""
        input_size = x.shape[-2:]
        feature_map = self.forward_features_as_map(x)
        logits = self.decoder(feature_map)
        # Bilinear upsampling from the patch grid back to pixel resolution.
        logits = F.interpolate(
            logits,
            size=input_size,
            mode="bilinear",
            align_corners=False,
        )
        return logits


def build_vit(
    variant="base",
    num_classes=1,
    pretrained=True,
    in_chans=3,
    img_size=512,
    decoder_dim=256,
    dropout=0.0,
):
    """
    Build a timm ViT segmentation model.

    Parameters
    ----------
    variant:
        One of:
            "tiny"
            "small"
            "base"
            "large"
        Or directly pass a timm model name, e.g.:
            "vit_base_patch16_224"
            "vit_small_patch16_224"
            "vit_large_patch16_224"
    num_classes:
        Number of output channels.
        Binary segmentation: num_classes=1
        Multi-class segmentation: num_classes=N
    pretrained:
        Whether to load ImageNet-pretrained timm weights.
    img_size:
        Input image size. For DRIVE, 512 is a reasonable default.

    Returns
    -------
    model:
        ViTSegmentationModel
    """
    variants = {
        "tiny": "vit_tiny_patch16_224",
        "small": "vit_small_patch16_224",
        "base": "vit_base_patch16_224",
        "large": "vit_large_patch16_224",
    }
    # Unknown keys fall through as literal timm model names.
    model_name = variants.get(variant, variant)

    model = ViTSegmentationModel(
        model_name=model_name,
        num_classes=num_classes,
        pretrained=pretrained,
        in_chans=in_chans,
        img_size=img_size,
        decoder_dim=decoder_dim,
        dropout=dropout,
    )
    return model


if __name__ == "__main__":
    # Smoke test:
    #   python models/vit.py
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = build_vit(
        variant="base",
        num_classes=1,
        pretrained=False,
        img_size=512,
    ).to(device)

    x = torch.randn(2, 3, 512, 512).to(device)
    with torch.no_grad():
        y = model(x)

    print("Model:", model.model_name)
    print("Input shape:", x.shape)
    print("Output shape:", y.shape)
    print("Output min/max:", y.min().item(), y.max().item())
    assert y.shape == (2, 1, 512, 512)
    print("Smoke test passed.")