| | """ |
| | CASWiT Self-Supervised Learning (SSL) Module |
| | |
| | Implements SimMIM-based self-supervised pre-training for CASWiT using |
| | masked image modeling with dual-branch HR/LR processing. |
| | """ |
| |
|
| | import math |
| | from typing import Optional, Tuple |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import UperNetForSemanticSegmentation |
| | from transformers.utils import logging as hf_logging |
| |
|
| | hf_logging.set_verbosity_error() |
| | hf_logging.disable_progress_bar() |
| |
|
| |
|
| |
|
def random_masking_with_tokens(x: torch.Tensor, mask_ratio: float = 0.75,
                               mask_token: Optional[torch.Tensor] = None):
    """
    Random masking at token level with a learned mask token.

    Args:
        x: Input tokens [B, N, C]
        mask_ratio: Ratio of tokens to mask
        mask_token: Learnable mask token [1, C]

    Returns:
        x_masked: Masked tokens [B, N, C]
        mask: Binary mask [B, N] where 0=visible, 1=masked
        ids_restore: Indices to restore original order
    """
    B, N, C = x.shape
    len_keep = int(N * (1 - mask_ratio))

    # Per-sample random permutation of token indices (MAE-style shuffling).
    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]

    # Gather the visible tokens.
    x_keep = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, C))

    # Substitute the dropped tokens with the (learned) mask token.
    if mask_token is None:
        mask_token = torch.zeros((1, C), device=x.device)
    m_tok = mask_token.view(1, 1, C).expand(B, N - len_keep, C)

    # Unshuffle so every token returns to its original position.
    x_cat = torch.cat([x_keep, m_tok], dim=1)
    x_masked = torch.gather(x_cat, 1, ids_restore.unsqueeze(-1).expand(-1, -1, C))

    # Binary mask in the original order: 0 = visible, 1 = masked.
    mask = torch.ones(B, N, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_masked, mask, ids_restore
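
# Example (illustrative sizes, not from the training config):
#   x   = torch.randn(2, 196, 128)   # B=2, N=14*14 tokens, C=128
#   tok = torch.zeros(1, 128)        # stand-in for a learned mask token
#   xm, mask, ids = random_masking_with_tokens(x, mask_ratio=0.75, mask_token=tok)
#   xm.shape     -> torch.Size([2, 196, 128])
#   mask.sum(1)  -> 147.0 per sample (196 - int(196 * 0.25) = 196 - 49)
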
def center_masking_with_tokens(x: torch.Tensor, mask_token: Optional[torch.Tensor] = None,
                               mask_ratio: float = 0.5):
    """
    Deterministic centered square mask.

    Args:
        x: Input tokens [B, N, C] (N must be a perfect square)
        mask_token: Learnable mask token [1, C]
        mask_ratio: Ratio of tokens to mask

    Returns:
        x_masked: Masked tokens [B, N, C]
        mask: Binary mask [B, N] where 0=visible, 1=masked
        ids_restore: Indices to restore original order (identity here)
    """
    B, N, C = x.shape
    H = W = int(N ** 0.5)
    assert H * W == N, "N must be a perfect square"
    # Side length of the centered square covering ~mask_ratio of the tokens.
    L = int(round(H * (mask_ratio ** 0.5)))
    start = (H - L) // 2
    end = start + L

    # Boolean mask on the 2D token grid, shared across the batch.
    mask_2d = torch.zeros(H, W, device=x.device, dtype=torch.bool)
    mask_2d[start:end, start:end] = True
    mask = mask_2d.view(1, -1).expand(B, -1)

    if mask_token is None:
        mask_token = torch.zeros(C, device=x.device)

    # Blend: keep visible tokens, substitute the mask token elsewhere.
    x_masked = x * (~mask).unsqueeze(-1) + mask.unsqueeze(-1) * mask_token.view(1, 1, C)
    # No shuffling happens, so the restore indices are the identity.
    ids_restore = torch.arange(N, device=x.device).unsqueeze(0).expand(B, N)
    return x_masked, mask.to(x_masked.dtype), ids_restore
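
# Worked example (illustrative numbers): on a 14x14 token grid with
# mask_ratio=0.5, L = round(14 * sqrt(0.5)) = 10 and start = (14 - 10) // 2 = 2,
# so tokens [2:12, 2:12] are masked: 100/196 ~ 51% of the grid, close to the
# requested ratio.
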
class CrossAttentionBlock(nn.Module):
    """Simplified cross-attention block for SSL: HR tokens query LR tokens."""

    def __init__(self, C_hr, C_lr, num_heads=8, dropout=0.0):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
            dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(C_hr)
        self.mlp = nn.Sequential(
            nn.LayerNorm(C_hr),
            nn.Linear(C_hr, C_hr * 4),
            nn.GELU(),
            nn.Linear(C_hr * 4, C_hr),
        )

    def forward(self, x_hr, x_lr):
        B, C_hr, H_hr, W_hr = x_hr.shape
        _, C_lr, H_lr, W_lr = x_lr.shape
        # Flatten feature maps to token sequences: queries from HR, keys/values from LR.
        q = x_hr.flatten(2).transpose(1, 2)
        kv = x_lr.flatten(2).transpose(1, 2)
        attn_out, _ = self.cross_attn(q, kv, kv)
        # Residual attention with post-norm, followed by a pre-norm MLP.
        y = self.norm(q + attn_out)
        y = y + self.mlp(y)
        return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)
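
# Shape sketch (illustrative sizes): HR queries attend to a coarser LR grid,
# and the output keeps the HR spatial layout.
#   blk = CrossAttentionBlock(128, 128, num_heads=8)
#   blk(torch.randn(2, 128, 56, 56), torch.randn(2, 128, 28, 28)).shape
#   -> torch.Size([2, 128, 56, 56])
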
class CASWiT_SSL(nn.Module):
    """
    CASWiT Self-Supervised Learning model using SimMIM.

    Encoder: dual Swin backbones with cross-attention blocks
    Decoder: Conv1x1 + PixelShuffle for reconstruction
    Masking: HR random masking, LR center masking

    Args:
        model_name: HuggingFace model identifier
        mask_ratio_hr: Masking ratio for the HR branch
        mask_ratio_lr: Masking ratio for the LR branch
        patch_size: Backbone patch size; used to upsample token masks to pixels
        encoder_stride: Total encoder downsampling factor, used by the decoder
        xa_heads: Number of cross-attention heads per stage
    """

    def __init__(self, model_name: str = "openmmlab/upernet-swin-base",
                 mask_ratio_hr: float = 0.75, mask_ratio_lr: float = 0.5,
                 patch_size: int = 4, encoder_stride: int = 32,
                 xa_heads: Tuple[int, int, int, int] = (8, 8, 8, 8)):
        super().__init__()
        self.mask_ratio_hr = mask_ratio_hr
        self.mask_ratio_lr = mask_ratio_lr
        self.patch_size = patch_size
        self.encoder_stride = encoder_stride

        # Load two pretrained UperNet-Swin models; only their backbones are kept.
        model_hr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, ignore_mismatched_sizes=True
        )
        model_lr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, ignore_mismatched_sizes=True
        )

        # HR branch: patch embeddings, Swin stages, and per-stage norms.
        self.embeddings_hr = model_hr.backbone.embeddings
        self.encoder_layers_hr = model_hr.backbone.encoder.layers
        self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms

        # LR branch: the same components from the second backbone.
        self.embeddings_lr = model_lr.backbone.embeddings
        self.encoder_layers_lr = model_lr.backbone.encoder.layers
        self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms

        # Per-stage channel widths for Swin-Base; one cross-attention block per stage.
        dims = [128, 256, 512, 1024]
        self.cross_attn_blocks = nn.ModuleList([
            CrossAttentionBlock(d, d, num_heads=h) for d, h in zip(dims, xa_heads)
        ])

        # Learnable mask tokens, one per branch, at the patch-embedding width.
        self.mask_token_hr = nn.Parameter(torch.zeros(1, dims[0]))
        self.mask_token_lr = nn.Parameter(torch.zeros(1, dims[0]))

        # SimMIM decoder: 1x1 conv to (stride^2 * 3) channels, then PixelShuffle.
        # Built here (rather than lazily in the forward pass) so its parameters
        # are registered before the optimizer is constructed.
        self.decoder_conv = nn.Conv2d(
            dims[-1], (self.encoder_stride ** 2) * 3, kernel_size=1
        )
        self.decoder_shuffle = nn.PixelShuffle(self.encoder_stride)

        # Masks cached from the most recent forward pass (for visualization).
        self.last_mask_hr = None
        self.last_mask_lr = None
    def _encode(self, x_hr: torch.Tensor, x_lr: torch.Tensor):
        """Encode with masking and return reconstruction targets."""
        B, C, H, W = x_hr.shape
        target_img = x_hr
        target_lr = x_lr

        # Patch-embed both inputs to token sequences.
        x_hr_seq, _ = self.embeddings_hr(x_hr)
        x_lr_seq, _ = self.embeddings_lr(x_lr)

        # Mask tokens: random masking on HR, deterministic center masking on LR.
        x_hr_seq, mask_hr, _ = random_masking_with_tokens(
            x_hr_seq, self.mask_ratio_hr, self.mask_token_hr
        )
        x_lr_seq, mask_lr, _ = center_masking_with_tokens(
            x_lr_seq, self.mask_token_lr, mask_ratio=self.mask_ratio_lr
        )

        # Token-grid resolutions for each branch (square grids assumed).
        H_hr = W_hr = int(math.sqrt(x_hr_seq.shape[1]))
        H_lr = W_lr = int(math.sqrt(x_lr_seq.shape[1]))
        dims_hr = (H_hr, W_hr)
        dims_lr = (H_lr, W_lr)

        # Run the two Swin branches stage by stage, fusing after each stage.
        for idx, (stage_hr, stage_lr, ca) in enumerate(zip(
            self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks
        )):
            # HR Swin blocks for this stage.
            for block in stage_hr.blocks:
                x_hr_seq = block(x_hr_seq, dims_hr)
                if isinstance(x_hr_seq, tuple):
                    x_hr_seq = x_hr_seq[0]
            # LR Swin blocks for this stage.
            for block in stage_lr.blocks:
                x_lr_seq = block(x_lr_seq, dims_lr)
                if isinstance(x_lr_seq, tuple):
                    x_lr_seq = x_lr_seq[0]

            # Per-stage LayerNorms from the backbone's hidden_states_norms.
            x_hr_seq = self.hidden_states_norms_hr[f"stage{idx + 1}"](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[f"stage{idx + 1}"](x_lr_seq)

            # Reshape token sequences back to 2D feature maps.
            B_, N_hr_, C_hr_ = x_hr_seq.shape
            B_, N_lr_, C_lr_ = x_lr_seq.shape
            Hh, Wh = dims_hr
            Hl, Wl = dims_lr
            feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B_, C_hr_, Hh, Wh)
            feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B_, C_lr_, Hl, Wl)

            # Cross-attention fusion: HR queries attend to LR keys/values.
            fused_hr = ca(feat_hr, feat_lr)
            x_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            # Patch merging (downsample) between stages, where present.
            if stage_hr.downsample is not None:
                x_hr_seq = stage_hr.downsample(x_hr_seq, dims_hr)
                dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
            if stage_lr.downsample is not None:
                x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
                dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)

        # Final fused HR features as a 2D map.
        Hs, Ws = dims_hr
        C_last = x_hr_seq.shape[-1]
        z = x_hr_seq.transpose(1, 2).contiguous().view(B, C_last, Hs, Ws)

        # SimMIM reconstruction: 1x1 conv (built in __init__), then PixelShuffle
        # back to pixel space.
        x_rec = self.decoder_shuffle(self.decoder_conv(z))

        # Upsample the HR token mask to pixels (one token -> patch_size^2 pixels).
        Mh = int(math.sqrt(mask_hr.shape[1]))
        mask_patch_hr = mask_hr.view(B, Mh, Mh)
        mask_pix_hr = mask_patch_hr.repeat_interleave(
            self.patch_size, 1
        ).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()

        # Same for the LR token mask, at the LR image resolution.
        Ml = int(math.sqrt(mask_lr.shape[1]))
        mask_patch_lr = mask_lr.view(B, Ml, Ml)
        mask_pix_lr = mask_patch_lr.repeat_interleave(
            self.patch_size, 1
        ).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()

        self.last_mask_hr = mask_patch_hr
        self.last_mask_lr = mask_patch_lr

        return x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr
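
    # Shape walkthrough (illustrative, for a 224x224 HR input with patch_size=4
    # and encoder_stride=32): patch embedding yields a 56x56 token grid; three
    # patch-merging steps reduce it to 7x7 at 1024 channels; the 1x1 conv maps
    # to 32*32*3 channels and PixelShuffle(32) restores a 224x224x3 image.
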
    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for SSL training.

        Returns the L1 reconstruction loss computed on masked pixels only,
        normalized per masked pixel and per channel (SimMIM-style).
        """
        x_rec, target_img, mask_pix, _, _ = self._encode(x_hr, x_lr)
        loss_recon = F.l1_loss(target_img, x_rec, reduction='none')
        loss = (loss_recon * mask_pix).sum() / (mask_pix.sum() + 1e-6) / target_img.shape[1]
        return loss

    @torch.no_grad()
    def forward_outputs(self, x_hr: torch.Tensor, x_lr: torch.Tensor):
        """Forward pass returning all outputs for visualization."""
        x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr = self._encode(x_hr, x_lr)
        return x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr
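
# Minimal smoke-test sketch (an illustration, not part of the training
# pipeline). Instantiating CASWiT_SSL downloads pretrained weights on first
# use; the 128x128 tensor sizes are assumptions chosen to keep the check
# light, and the random LR tensor stands in for a low-resolution view.
if __name__ == "__main__":
    model = CASWiT_SSL()
    model.eval()
    x_hr = torch.randn(1, 3, 128, 128)
    x_lr = torch.randn(1, 3, 128, 128)
    loss = model(x_hr, x_lr)
    print(f"masked reconstruction loss: {loss.item():.4f}")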