import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import timm
except ImportError as e:
    raise ImportError(
        "timm is required for models/vit.py. Install with: pip install timm"
    ) from e


class ViTSegmentationModel(nn.Module):
    """
    Simple ViT segmentation model using a timm Vision Transformer backbone.

    The model:
        image -> ViT patch tokens -> reshape to feature map -> conv head -> upsample

    Output:
        logits of shape [B, num_classes, H, W]

    For binary vessel segmentation:
        num_classes = 1
    For multi-class lesion segmentation:
        num_classes = number of lesion/background classes
    """

    def __init__(
        self,
        model_name="vit_base_patch16_224",
        num_classes=1,
        pretrained=True,
        in_chans=3,
        img_size=512,
        decoder_dim=256,
        dropout=0.0,
    ):
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.img_size = img_size
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,       # no classification head
            global_pool="",      # keep per-token features
            in_chans=in_chans,
            img_size=img_size,
        )
        self.embed_dim = self.backbone.num_features
        # timm usually stores patch_size as an (h, w) tuple; assume square patches.
        self.patch_size = self.backbone.patch_embed.patch_size
        if isinstance(self.patch_size, tuple):
            self.patch_size = self.patch_size[0]
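        # Lightweight conv head (a design note, not from timm): a 1x1 projection
        # to decoder_dim, one 3x3 refinement, then a 1x1 classifier. It predicts
        # at the patch-grid resolution; forward() upsamples to the input size.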
        self.decoder = nn.Sequential(
            nn.Conv2d(self.embed_dim, decoder_dim, kernel_size=1),
            nn.BatchNorm2d(decoder_dim),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(decoder_dim, decoder_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(decoder_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(decoder_dim, num_classes, kernel_size=1),
        )

    def forward_features_as_map(self, x):
        """
        Convert ViT patch tokens into a spatial feature map.

        Input:
            x: [B, C, H, W]
        Output:
            feature_map: [B, embed_dim, H // patch_size, W // patch_size]
        """
        b, _, h, w = x.shape
        tokens = self.backbone.forward_features(x)
        # Some timm models return a tuple/list; the first item is usually the
        # token features.
        if isinstance(tokens, (tuple, list)):
            tokens = tokens[0]
        if tokens.ndim == 3:
            # Standard ViT: tokens are [B, num_prefix + num_patches, C], where
            # prefix tokens (CLS, and possibly others) precede the patch tokens.
            feature_h = h // self.patch_size
            feature_w = w // self.patch_size
            num_prefix = tokens.shape[1] - feature_h * feature_w
            if num_prefix > 0:
                tokens = tokens[:, num_prefix:, :]  # drop CLS/prefix tokens
            tokens = tokens.transpose(1, 2)  # [B, C, num_patches]
            feature_map = tokens.reshape(b, self.embed_dim, feature_h, feature_w)
        elif tokens.ndim == 4:
            # Some backbones already return a [B, C, H, W] feature map.
            feature_map = tokens
        else:
            raise RuntimeError(f"Unexpected ViT feature shape: {tokens.shape}")
        return feature_map

    def forward(self, x):
        input_size = x.shape[-2:]
        feature_map = self.forward_features_as_map(x)
        logits = self.decoder(feature_map)
        logits = F.interpolate(
            logits,
            size=input_size,
            mode="bilinear",
            align_corners=False,
        )
        return logits


def build_vit(
    variant="base",
    num_classes=1,
    pretrained=True,
    in_chans=3,
    img_size=512,
    decoder_dim=256,
    dropout=0.0,
):
    """
    Build a timm ViT segmentation model.

    Parameters
    ----------
    variant:
        One of "tiny", "small", "base", "large", or directly a timm model
        name, e.g. "vit_base_patch16_224", "vit_small_patch16_224",
        "vit_large_patch16_224".
    num_classes:
        Number of output channels. Binary segmentation: num_classes=1.
        Multi-class segmentation: num_classes=N.
    pretrained:
        Whether to load ImageNet-pretrained timm weights.
    in_chans:
        Number of input image channels.
    img_size:
        Input image size. For DRIVE, 512 is a reasonable default.
    decoder_dim:
        Hidden width of the convolutional decoder head.
    dropout:
        Dropout2d probability inside the decoder head.

    Returns
    -------
    model:
        ViTSegmentationModel
    """
    variants = {
        "tiny": "vit_tiny_patch16_224",
        "small": "vit_small_patch16_224",
        "base": "vit_base_patch16_224",
        "large": "vit_large_patch16_224",
    }
    model_name = variants.get(variant, variant)
    model = ViTSegmentationModel(
        model_name=model_name,
        num_classes=num_classes,
        pretrained=pretrained,
        in_chans=in_chans,
        img_size=img_size,
        decoder_dim=decoder_dim,
        dropout=dropout,
    )
    return model
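

# Usage sketch (illustrative values, not project defaults): any timm ViT name
# also works in place of the shorthand variant.
#
#   model = build_vit(variant="small", num_classes=5, img_size=448)
#   logits = model(torch.randn(1, 3, 448, 448))  # -> [1, 5, 448, 448]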


if __name__ == "__main__":
    # Smoke test:
    #   python models/vit.py
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = build_vit(
        variant="base",
        num_classes=1,
        pretrained=False,
        img_size=512,
    ).to(device)
    x = torch.randn(2, 3, 512, 512).to(device)
    with torch.no_grad():
        y = model(x)
    print("Model:", model.model_name)
    print("Input shape:", x.shape)
    print("Output shape:", y.shape)
    print("Output min/max:", y.min().item(), y.max().item())
    assert y.shape == (2, 1, 512, 512)
    print("Smoke test passed.")