import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import timm


# ---------------------------------------------------------
# Basic CNN Blocks
# ---------------------------------------------------------
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BN -> ReLU) stages; spatial size is preserved."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UpBlock(nn.Module):
    """
    Upsample (bilinear) + concat skip + DoubleConv.
    NO transposed convolutions -> no grid artifacts.
    """

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.conv = DoubleConv(in_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        # Resize to the skip's exact spatial size (robust to odd dimensions,
        # where a fixed scale_factor=2 could be off by one pixel).
        x = F.interpolate(
            x, size=skip.shape[2:], mode="bilinear", align_corners=False
        )
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)


# ---------------------------------------------------------
# SwinV2 + CNN Decoder
# ---------------------------------------------------------
class model(nn.Module):
    """SwinV2-tiny encoder + artifact-free CNN (U-Net-style) decoder.

    Args:
        in_channels: number of input image channels. When it equals 3, the
            pretrained patch-embedding stem is kept as-is; otherwise the
            pretrained stem kernel is adapted (channel-mean replication)
            rather than randomly re-initialized.
        num_classes: number of output segmentation classes.
        freeze_encoder: if True, all pretrained encoder parameters are
            frozen. An adapted stem (in_channels != 3) is installed AFTER
            freezing, so it remains trainable — it must learn the new
            input mapping.
    """

    def __init__(
        self,
        in_channels=3,
        num_classes=15,
        freeze_encoder=False,
    ):
        super().__init__()

        # -------------------------------
        # Encoder (SwinV2)
        # -------------------------------
        self.encoder = timm.create_model(
            "swinv2_tiny_window8_256",
            pretrained=True,
            features_only=True,
            out_indices=(0, 1, 2, 3),
        )

        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False

        # BUGFIX: the original code replaced the patch-embedding conv with a
        # randomly initialized one even for in_channels=3, silently throwing
        # away its pretrained weights. Keep the pretrained stem when the
        # channel count matches; otherwise adapt the pretrained kernel.
        if in_channels != 3:
            old_proj = self.encoder.patch_embed.proj
            new_proj = nn.Conv2d(
                in_channels=in_channels,
                out_channels=old_proj.out_channels,
                kernel_size=old_proj.kernel_size,
                stride=old_proj.stride,
                padding=old_proj.padding,
                bias=old_proj.bias is not None,
            )
            with torch.no_grad():
                # Average the pretrained RGB kernel over its input channels
                # and replicate it across the new channel count — preserves
                # the activation scale and reuses pretrained features
                # (same strategy timm uses for in_chans != 3).
                mean_w = old_proj.weight.mean(dim=1, keepdim=True)
                new_proj.weight.copy_(mean_w.repeat(1, in_channels, 1, 1))
                if old_proj.bias is not None:
                    new_proj.bias.copy_(old_proj.bias)
            self.encoder.patch_embed.proj = new_proj

        # Encoder channel sizes at strides 4 / 8 / 16 / 32.
        c0, c1, c2, c3 = self.encoder.feature_info.channels()

        # -------------------------------
        # CNN Decoder (artifact-free)
        # -------------------------------
        self.up3 = UpBlock(c3, c2, c2)  # 1/32 -> 1/16
        self.up2 = UpBlock(c2, c1, c1)  # 1/16 -> 1/8
        self.up1 = UpBlock(c1, c0, c0)  # 1/8  -> 1/4

        self.refine = nn.Sequential(
            nn.Conv2d(c0, c0, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c0, c0, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.head = nn.Conv2d(c0, num_classes, kernel_size=1)

    # ---------------------------------------------------------
    # Forward
    # ---------------------------------------------------------
    def forward(self, x):
        """Return per-pixel class logits of shape (B, num_classes, H, W)."""
        f0, f1, f2, f3 = self.encoder(x)

        # Swin features_only outputs are channel-last (B, H, W, C) in
        # recent timm versions — convert to the CNN-standard NCHW layout.
        # NOTE(review): assumes timm emits NHWC here; verify against the
        # installed timm version.
        f0 = rearrange(f0, "b h w c -> b c h w")
        f1 = rearrange(f1, "b h w c -> b c h w")
        f2 = rearrange(f2, "b h w c -> b c h w")
        f3 = rearrange(f3, "b h w c -> b c h w")

        # Decoder: progressively fuse deep features with shallower skips.
        d3 = self.up3(f3, f2)
        d2 = self.up2(d3, f1)
        d1 = self.up1(d2, f0)

        d1 = self.refine(d1)

        # Upsample features to the input resolution, then project to logits.
        out = F.interpolate(
            d1, size=x.shape[2:], mode="bilinear", align_corners=False
        )
        return self.head(out)