"""
CASWiT Self-Supervised Learning (SSL) Module

Implements SimMIM-based self-supervised pre-training for CASWiT using masked
image modeling with dual-branch HR/LR processing.
"""

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import UperNetForSemanticSegmentation
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()


def random_masking_with_tokens(x: torch.Tensor, mask_ratio: float = 0.75,
                               mask_token: Optional[torch.Tensor] = None):
    """
    Random masking at token level with learned mask token.

    Each sample draws its own random permutation; the first
    ``int(N * (1 - mask_ratio))`` tokens of that permutation stay visible and
    all others are replaced by ``mask_token``. Tokens keep their original
    positions in the output.

    Args:
        x: Input tokens [B, N, C]
        mask_ratio: Ratio of tokens to mask
        mask_token: Learnable mask token with C elements (e.g. [1, C]).
            Defaults to a zero vector.

    Returns:
        x_masked: Masked tokens [B, N, C]
        mask: Binary mask [B, N] where 0=visible, 1=masked
        ids_restore: Indices to restore original order [B, N]
    """
    B, N, C = x.shape
    len_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]

    x_keep = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, C))

    if mask_token is None:
        mask_token = torch.zeros((1, C), device=x.device)
    # Match the token dtype/device (e.g. fp16/bf16 under autocast) so the
    # concatenation below does not promote the whole sequence.
    mask_token = mask_token.to(device=x.device, dtype=x.dtype)
    m_tok = mask_token.view(1, 1, C).expand(B, N - len_keep, C)
    x_cat = torch.cat([x_keep, m_tok], dim=1)
    # Un-shuffle so visible and mask tokens sit at their original positions.
    x_masked = torch.gather(x_cat, 1, ids_restore.unsqueeze(-1).expand(-1, -1, C))

    mask = torch.ones(B, N, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_masked, mask, ids_restore


def center_masking_with_tokens(x: torch.Tensor, mask_token: Optional[torch.Tensor] = None,
                               mask_ratio: float = 0.5):
    """
    Deterministic centered square mask.

    The N tokens are viewed as a square H x H grid; a centered square of side
    ``round(H * sqrt(mask_ratio))`` is replaced by ``mask_token``. The same
    mask is used for every sample in the batch.

    Args:
        x: Input tokens [B, N, C]; N must be a perfect square
        mask_token: Learnable mask token with C elements. Defaults to zeros.
        mask_ratio: Ratio of tokens to mask

    Returns:
        x_masked: Masked tokens [B, N, C]
        mask: Binary mask [B, N] (1=masked) in the dtype of ``x_masked``
        ids_restore: Identity indices [B, N] (no shuffling happens here)
    """
    B, N, C = x.shape
    H = W = int(N ** 0.5)
    assert H * W == N, "N must be a perfect square"

    side = int(round(H * (mask_ratio ** 0.5)))
    start = (H - side) // 2
    end = start + side

    mask_2d = torch.zeros(H, W, device=x.device, dtype=torch.bool)
    mask_2d[start:end, start:end] = True
    mask = mask_2d.view(1, -1).expand(B, -1)  # (B, N)

    if mask_token is None:
        mask_token = torch.zeros(C, device=x.device)
    mask_token = mask_token.reshape(1, 1, C).to(device=x.device, dtype=x.dtype)
    # Select instead of blending with x * 0 so NaN/Inf in masked positions
    # cannot leak into the output; gradients still reach mask_token.
    x_masked = torch.where(mask.unsqueeze(-1), mask_token, x)

    ids_restore = torch.arange(N, device=x.device).unsqueeze(0).expand(B, N)
    return x_masked, mask.to(x_masked.dtype), ids_restore


class CrossAttentionBlock(nn.Module):
    """
    Simplified cross-attention block for SSL.

    HR tokens (queries) attend to LR tokens (keys/values). The attention
    output is added residually and layer-normalised, then refined by a
    pre-norm MLP with its own residual connection.

    Args:
        C_hr: Channels of the HR feature map (attention embedding dim)
        C_lr: Channels of the LR feature map (key/value dim)
        num_heads: Number of attention heads; must divide ``C_hr``
        dropout: Attention dropout probability
    """

    def __init__(self, C_hr, C_lr, num_heads=8, dropout=0.0):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
            dropout=dropout, batch_first=True,
        )
        self.norm = nn.LayerNorm(C_hr)
        self.mlp = nn.Sequential(
            nn.LayerNorm(C_hr),
            nn.Linear(C_hr, C_hr * 4),
            nn.GELU(),
            nn.Linear(C_hr * 4, C_hr),
        )

    def forward(self, x_hr, x_lr):
        """
        Fuse LR context into HR features.

        Args:
            x_hr: HR feature map [B, C_hr, H_hr, W_hr]
            x_lr: LR feature map [B, C_lr, H_lr, W_lr]

        Returns:
            Fused HR feature map [B, C_hr, H_hr, W_hr]
        """
        B, C_hr, H_hr, W_hr = x_hr.shape
        q = x_hr.flatten(2).transpose(1, 2)   # (B, N_hr, C_hr)
        kv = x_lr.flatten(2).transpose(1, 2)  # (B, N_lr, C_lr)
        # Attention weights are discarded, so don't compute them.
        attn_out, _ = self.cross_attn(q, kv, kv, need_weights=False)
        y = self.norm(q + attn_out)
        y = y + self.mlp(y)
        return y.transpose(1, 2).reshape(B, C_hr, H_hr, W_hr)


class CASWiT_SSL(nn.Module):
    """
    CASWiT Self-Supervised Learning model using SimMIM.

    Encoder: Dual Swin backbones with cross-attention blocks (HR <- LR fusion
        after every stage)
    Decoder: Conv1x1 + PixelShuffle for reconstruction of the HR image
    Masking: HR random masking, LR center masking

    Args:
        model_name: HuggingFace model identifier
        mask_ratio_hr: Masking ratio for HR branch
        mask_ratio_lr: Masking ratio for LR branch
        patch_size: Patch size for masking; should equal the backbone's
            patch-embedding size so token masks map onto pixels correctly
        encoder_stride: Encoder stride for decoder
        xa_heads: Number of cross-attention heads per stage
    """

    def __init__(self,
                 model_name: str = "openmmlab/upernet-swin-base",
                 mask_ratio_hr: float = 0.75,
                 mask_ratio_lr: float = 0.5,
                 patch_size: int = 4,
                 encoder_stride: int = 32,
                 xa_heads: Tuple[int, int, int, int] = (8, 8, 8, 8)):
        super().__init__()
        self.mask_ratio_hr = mask_ratio_hr
        self.mask_ratio_lr = mask_ratio_lr
        self.patch_size = patch_size
        self.encoder_stride = encoder_stride

        # Load two UPerNet (Swin) backbones
        model_hr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, ignore_mismatched_sizes=True
        )
        model_lr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, ignore_mismatched_sizes=True
        )

        self.embeddings_hr = model_hr.backbone.embeddings
        self.encoder_layers_hr = model_hr.backbone.encoder.layers
        self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms

        self.embeddings_lr = model_lr.backbone.embeddings
        self.encoder_layers_lr = model_lr.backbone.encoder.layers
        self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms

        # Per-stage channel widths: Swin doubles the width at every stage.
        # Read them from the backbone config so other Swin sizes work; fall
        # back to the Swin-Base widths [128, 256, 512, 1024].
        backbone_cfg = getattr(model_hr.config, "backbone_config", None)
        embed_dim = getattr(backbone_cfg, "embed_dim", 128)
        num_stages = len(getattr(backbone_cfg, "depths", (2, 2, 18, 2)))
        dims = [int(embed_dim * 2 ** i) for i in range(num_stages)]

        self.cross_attn_blocks = nn.ModuleList([
            CrossAttentionBlock(d, d, num_heads=h) for d, h in zip(dims, xa_heads)
        ])

        # Learnable mask tokens (applied right after patch embedding)
        self.mask_token_hr = nn.Parameter(torch.zeros(1, dims[0]))
        self.mask_token_lr = nn.Parameter(torch.zeros(1, dims[0]))

        # SimMIM decoder: Conv1x1 -> PixelShuffle(stride).
        # Built eagerly (the last stage has no patch merging, so its width is
        # dims[-1]) so the decoder's parameters exist before the optimizer,
        # state_dict loading and DDP wrapping see the model.
        self.decoder_conv = nn.Conv2d(
            dims[-1], (self.encoder_stride ** 2) * 3, kernel_size=1
        )
        self.decoder_shuffle = nn.PixelShuffle(self.encoder_stride)

        # Store masks for visualization
        self.last_mask_hr = None
        self.last_mask_lr = None

    @staticmethod
    def _run_blocks(stage, x_seq, dims):
        """Apply every Swin block of ``stage``; blocks may return tuples."""
        for block in stage.blocks:
            x_seq = block(x_seq, dims)
            if isinstance(x_seq, tuple):
                x_seq = x_seq[0]
        return x_seq

    @staticmethod
    def _downsample(stage, x_seq, dims):
        """Apply the stage's patch merging (if any); return tokens and new grid."""
        if stage.downsample is None:
            return x_seq, dims
        x_seq = stage.downsample(x_seq, dims)
        # Patch merging pads odd grids to even before halving, so round up.
        return x_seq, ((dims[0] + 1) // 2, (dims[1] + 1) // 2)

    @staticmethod
    def _seq_to_map(x_seq, dims):
        """Reshape a [B, H*W, C] token sequence into a [B, C, H, W] map."""
        batch, _, channels = x_seq.shape
        return x_seq.transpose(1, 2).reshape(batch, channels, dims[0], dims[1])

    @staticmethod
    def _patch_mask_to_pixels(mask, grid, patch_size):
        """
        Expand a token mask to pixel resolution.

        Args:
            mask: Token mask [B, H*W] on an (H, W) token grid
            grid: (H, W) token grid size
            patch_size: Pixels per token side

        Returns:
            mask_patch: [B, H, W] token mask
            mask_pix: [B, 1, H*patch_size, W*patch_size] pixel mask
        """
        mask_patch = mask.reshape(mask.shape[0], grid[0], grid[1])
        mask_pix = (
            mask_patch.repeat_interleave(patch_size, 1)
            .repeat_interleave(patch_size, 2)
            .unsqueeze(1)
            .contiguous()
        )
        return mask_patch, mask_pix

    def _run_stages(self, x_hr_seq, x_lr_seq, dims_hr, dims_lr):
        """
        Walk the paired encoder stages, fusing HR <- LR after every stage.

        Returns:
            Last-stage HR token sequence [B, N, C_last] and its (H, W) grid.
        """
        stages = zip(self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks)
        for idx, (stage_hr, stage_lr, ca) in enumerate(stages):
            norm_key = f"stage{idx+1}"

            x_hr_seq = self._run_blocks(stage_hr, x_hr_seq, dims_hr)
            x_lr_seq = self._run_blocks(stage_lr, x_lr_seq, dims_lr)

            x_hr_seq = self.hidden_states_norms_hr[norm_key](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[norm_key](x_lr_seq)

            # Cross-fuse HR <- LR on 2-D feature maps, then back to tokens.
            fused_hr = ca(self._seq_to_map(x_hr_seq, dims_hr),
                          self._seq_to_map(x_lr_seq, dims_lr))
            x_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            x_hr_seq, dims_hr = self._downsample(stage_hr, x_hr_seq, dims_hr)
            x_lr_seq, dims_lr = self._downsample(stage_lr, x_lr_seq, dims_lr)
        return x_hr_seq, dims_hr

    def _encode(self, x_hr: torch.Tensor, x_lr: torch.Tensor):
        """
        Mask both inputs, run the fused dual-branch encoder and decode HR.

        Args:
            x_hr: HR image batch [B, C, H, W] (the decoder reconstructs 3 channels)
            x_lr: LR image batch

        Returns:
            x_rec: HR reconstruction [B, 3, H_tok_last*stride, W_tok_last*stride]
            target_img: ``x_hr`` (reconstruction target)
            mask_pix_hr: HR pixel mask [B, 1, ., .], 1 = masked
            target_lr: ``x_lr``
            mask_pix_lr: LR pixel mask [B, 1, ., .], 1 = masked
        """
        # Patch embeddings; the second output is the token grid size (H, W).
        x_hr_seq, grid_hr = self.embeddings_hr(x_hr)  # (B, N_hr, C1)
        x_lr_seq, grid_lr = self.embeddings_lr(x_lr)  # (B, N_lr, C1)
        grid_hr, grid_lr = tuple(grid_hr), tuple(grid_lr)

        # Masking
        x_hr_seq, mask_hr, _ = random_masking_with_tokens(
            x_hr_seq, self.mask_ratio_hr, self.mask_token_hr
        )
        x_lr_seq, mask_lr, _ = center_masking_with_tokens(
            x_lr_seq, self.mask_token_lr, mask_ratio=self.mask_ratio_lr
        )

        x_hr_seq, dims_hr = self._run_stages(x_hr_seq, x_lr_seq, grid_hr, grid_lr)

        # Last-stage feature map z (B, C_last, H/stride, W/stride) -> pixels
        z = self._seq_to_map(x_hr_seq, dims_hr)
        x_rec = self.decoder_shuffle(self.decoder_conv(z))

        # Convert token masks to pixel masks
        mask_patch_hr, mask_pix_hr = self._patch_mask_to_pixels(mask_hr, grid_hr, self.patch_size)
        mask_patch_lr, mask_pix_lr = self._patch_mask_to_pixels(mask_lr, grid_lr, self.patch_size)

        self.last_mask_hr = mask_patch_hr
        self.last_mask_lr = mask_patch_lr

        return x_rec, x_hr, mask_pix_hr, x_lr, mask_pix_lr

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for SSL training.

        Returns the L1 reconstruction loss of the HR image on masked pixels
        only, averaged over masked pixels and channels. The LR branch only
        provides context; it has no reconstruction loss.
        """
        x_rec, target_img, mask_pix, _, _ = self._encode(x_hr, x_lr)
        loss_recon = F.l1_loss(target_img, x_rec, reduction='none')
        # mask_pix is [B, 1, H, W] and broadcasts over channels, hence the
        # extra division by the channel count.
        loss = (loss_recon * mask_pix).sum() / (mask_pix.sum() + 1e-6) / target_img.shape[1]
        return loss

    @torch.no_grad()
    def forward_outputs(self, x_hr: torch.Tensor, x_lr: torch.Tensor):
        """Forward pass returning all outputs for visualization."""
        x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr = self._encode(x_hr, x_lr)
        return x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr