# CFPVesselSeg/models/vit.py
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
import timm
except ImportError as e:
raise ImportError(
"timm is required for models/vit.py. Install with: pip install timm"
) from e
class ViTSegmentationModel(nn.Module):
"""
Simple ViT segmentation model using a timm Vision Transformer backbone.
The model:
image -> ViT patch tokens -> reshape to feature map -> conv head -> upsample
Output:
logits of shape [B, num_classes, H, W]
For binary vessel segmentation:
num_classes = 1
For multi-class lesion segmentation:
num_classes = number of lesion/background classes
"""
def __init__(
self,
model_name="vit_base_patch16_224",
num_classes=1,
pretrained=True,
in_chans=3,
img_size=512,
decoder_dim=256,
dropout=0.0,
):
super().__init__()
self.model_name = model_name
self.num_classes = num_classes
self.img_size = img_size
self.backbone = timm.create_model(
model_name,
pretrained=pretrained,
num_classes=0,
global_pool="",
in_chans=in_chans,
img_size=img_size,
)
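        # num_classes=0 and global_pool="" ask timm for raw token features
        # rather than a pooled classifier output; passing img_size lets timm
        # resample the pretrained position embeddings for the new resolution.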
self.embed_dim = self.backbone.num_features
self.patch_size = self.backbone.patch_embed.patch_size
if isinstance(self.patch_size, tuple):
self.patch_size = self.patch_size[0]
self.decoder = nn.Sequential(
nn.Conv2d(self.embed_dim, decoder_dim, kernel_size=1),
nn.BatchNorm2d(decoder_dim),
nn.ReLU(inplace=True),
nn.Dropout2d(dropout),
nn.Conv2d(decoder_dim, decoder_dim, kernel_size=3, padding=1),
nn.BatchNorm2d(decoder_dim),
nn.ReLU(inplace=True),
nn.Conv2d(decoder_dim, num_classes, kernel_size=1),
)
def forward_features_as_map(self, x):
"""
Convert ViT patch tokens into a spatial feature map.
Input:
x: [B, C, H, W]
Output:
feature_map: [B, embed_dim, H // patch_size, W // patch_size]
"""
b, _, h, w = x.shape
tokens = self.backbone.forward_features(x)
# Some timm models return a tuple/list. Usually the first item is token features.
if isinstance(tokens, (tuple, list)):
tokens = tokens[0]
        # For a standard ViT, tokens are [B, num_prefix + num_patches, C],
        # where the prefix tokens are CLS and, for some variants,
        # distillation or register tokens.
        if tokens.ndim == 3:
            expected_num_patches = (h // self.patch_size) * (w // self.patch_size)
            num_prefix = tokens.shape[1] - expected_num_patches
            if num_prefix > 0:
                tokens = tokens[:, num_prefix:, :]  # drop CLS/other prefix tokens
            feature_h = h // self.patch_size
            feature_w = w // self.patch_size
            tokens = tokens.transpose(1, 2)  # [B, C, num_patches]
            feature_map = tokens.reshape(b, self.embed_dim, feature_h, feature_w)
# Some backbones may already return [B, C, H, W].
elif tokens.ndim == 4:
feature_map = tokens
else:
raise RuntimeError(f"Unexpected ViT feature shape: {tokens.shape}")
return feature_map
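
    # Shape walk-through for the default config (vit_base_patch16, 512x512
    # input): 512 // 16 = 32 patches per side, so tokens arrive as
    # [B, 1 + 32 * 32, 768] = [B, 1025, 768], lose the CLS token, and are
    # reshaped to a [B, 768, 32, 32] feature map.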
def forward(self, x):
input_size = x.shape[-2:]
feature_map = self.forward_features_as_map(x)
logits = self.decoder(feature_map)
logits = F.interpolate(
logits,
size=input_size,
mode="bilinear",
align_corners=False,
)
return logits
def build_vit(
variant="base",
num_classes=1,
pretrained=True,
in_chans=3,
img_size=512,
decoder_dim=256,
dropout=0.0,
):
"""
Build a timm ViT segmentation model.
Parameters
----------
variant:
One of:
"tiny"
"small"
"base"
"large"
Or directly pass a timm model name, e.g.:
"vit_base_patch16_224"
"vit_small_patch16_224"
"vit_large_patch16_224"
num_classes:
Number of output channels.
Binary segmentation:
num_classes=1
Multi-class segmentation:
num_classes=N
pretrained:
Whether to load ImageNet-pretrained timm weights.
img_size:
Input image size. For DRIVE, 512 is a reasonable default.
Returns
-------
model:
ViTSegmentationModel
"""
variants = {
"tiny": "vit_tiny_patch16_224",
"small": "vit_small_patch16_224",
"base": "vit_base_patch16_224",
"large": "vit_large_patch16_224",
}
model_name = variants.get(variant, variant)
model = ViTSegmentationModel(
model_name=model_name,
num_classes=num_classes,
pretrained=pretrained,
in_chans=in_chans,
img_size=img_size,
decoder_dim=decoder_dim,
dropout=dropout,
)
return model
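

# Usage sketch for the multi-class case (hedged: the class count below is
# illustrative, not defined by this module; masks are expected as long
# tensors of shape [B, H, W]):
#
#   model = build_vit(variant="small", num_classes=5, pretrained=False)
#   logits = model(images)                 # [B, 5, H, W]
#   loss = F.cross_entropy(logits, masks)  # standard multi-class pairing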
if __name__ == "__main__":
# Smoke test:
# python models/vit.py
device = "cuda" if torch.cuda.is_available() else "cpu"
model = build_vit(
variant="base",
num_classes=1,
pretrained=False,
img_size=512,
).to(device)
x = torch.randn(2, 3, 512, 512).to(device)
with torch.no_grad():
y = model(x)
print("Model:", model.model_name)
print("Input shape:", x.shape)
print("Output shape:", y.shape)
print("Output min/max:", y.min().item(), y.max().item())
assert y.shape == (2, 1, 512, 512)
print("Smoke test passed.")