Image Segmentation
English
CASWiT / model /CASWiT_ssl.py
antoine.carreaud67
Update with new experiments
d43c376
"""
CASWiT Self-Supervised Learning (SSL) Module
Implements SimMIM-based self-supervised pre-training for CASWiT using
masked image modeling with dual-branch HR/LR processing.
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import UperNetForSemanticSegmentation
from transformers.utils import logging as hf_logging
# Import-time side effect: restrict Hugging Face `transformers` logging to
# errors and switch off its progress bars for the whole process.
hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()
def random_masking_with_tokens(x: torch.Tensor, mask_ratio: float = 0.75,
                               mask_token: Optional[torch.Tensor] = None):
    """
    Random masking at token level with a (learned) mask token.

    Each sample draws its own random permutation of the N token positions.
    The first ``int(N * (1 - mask_ratio))`` positions of that permutation stay
    visible; every other position is overwritten with the mask token. Token
    order is preserved in the output.

    Args:
        x: Input tokens [B, N, C]
        mask_ratio: Ratio of tokens to mask, must lie in [0, 1]
        mask_token: Mask token holding exactly C values (e.g. shape [1, C]).
            It is cast to the dtype/device of ``x``. Zeros are used when None.
    Returns:
        x_masked: Masked tokens [B, N, C], same dtype as ``x``
        mask: Binary mask [B, N] where 0=visible, 1=masked
        ids_restore: Indices [B, N] that undo the random shuffle
    Raises:
        ValueError: If ``mask_ratio`` is outside [0, 1].
    """
    # Without this guard a ratio > 1 makes len_keep negative, and the slicing
    # below would silently mask a different fraction than requested.
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

    B, N, C = x.shape
    len_keep = int(N * (1 - mask_ratio))

    # Sorting uniform noise yields an independent random permutation per row.
    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    # argsort of a permutation is its inverse: shuffled order -> original order.
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    ids_keep = ids_shuffle[:, :len_keep]
    x_keep = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, C))

    if mask_token is None:
        mask_token = x.new_zeros(C)
    else:
        # Match x so torch.cat neither promotes the output dtype nor fails on
        # a device mismatch.
        mask_token = mask_token.to(device=x.device, dtype=x.dtype)
    m_tok = mask_token.reshape(1, 1, C).expand(B, N - len_keep, C)

    # [visible tokens | mask tokens] in shuffled order, then un-shuffle.
    x_cat = torch.cat([x_keep, m_tok], dim=1)
    x_masked = torch.gather(x_cat, 1, ids_restore.unsqueeze(-1).expand(-1, -1, C))

    # Build the mask in shuffled order (kept first) and un-shuffle it the same way.
    mask = torch.ones(B, N, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_masked, mask, ids_restore
def center_masking_with_tokens(x: torch.Tensor, mask_token: Optional[torch.Tensor] = None,
                               mask_ratio: float = 0.5):
    """
    Deterministic centered square mask.

    The N tokens are interpreted as a square H x H grid. A centered square of
    side ``round(H * sqrt(mask_ratio))`` is replaced by the mask token; the
    same mask is applied to every sample in the batch.

    Args:
        x: Input tokens [B, N, C], N must be a perfect square
        mask_token: Mask token holding exactly C values (e.g. shape [1, C]).
            It is cast to the dtype/device of ``x``. Zeros are used when None.
        mask_ratio: Ratio of tokens to mask, must lie in [0, 1]
    Returns:
        x_masked: Masked tokens [B, N, C], same dtype as ``x``
        mask: Binary mask [B, N] (1=masked) in the dtype of ``x_masked``
        ids_restore: Identity indices [B, N] (tokens are never reordered)
    Raises:
        ValueError: If ``mask_ratio`` is outside [0, 1] or N is not a
            perfect square.
    """
    # A negative ratio would make `mask_ratio ** 0.5` complex; a ratio > 1
    # would produce a negative slice start and a wrong mask.
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

    B, N, C = x.shape
    H = W = int(round(N ** 0.5))
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if H * W != N:
        raise ValueError("N must be a perfect square")

    # Side length of the masked square, centered in the grid.
    L = int(round(H * (mask_ratio ** 0.5)))
    start = (H - L) // 2
    end = start + L

    mask_2d = torch.zeros(H, W, device=x.device, dtype=torch.bool)
    mask_2d[start:end, start:end] = True
    mask = mask_2d.view(1, -1).expand(B, -1)  # (B,N)

    if mask_token is None:
        token = x.new_zeros(C)
    else:
        token = mask_token.reshape(-1).to(device=x.device, dtype=x.dtype)

    # Select instead of blending with `x * ~mask + mask * token`: multiplying
    # by zero would let NaN/Inf in hidden tokens leak into the output.
    x_masked = torch.where(mask.unsqueeze(-1), token.view(1, 1, C), x)

    ids_restore = torch.arange(N, device=x.device).unsqueeze(0).expand(B, N)
    return x_masked, mask.to(x_masked.dtype), ids_restore
class CrossAttentionBlock(nn.Module):
    """
    Cross-attention fusion of an HR feature map with an LR feature map (SSL variant).

    HR positions act as queries, LR positions as keys and values. The
    attention output is added to the queries and layer-normalised, then a
    4x-expansion MLP is applied with a residual connection.
    """

    def __init__(self, C_hr, C_lr, num_heads=8, dropout=0.0):
        super().__init__()
        # Queries live in C_hr channels; keys/values are projected from C_lr.
        self.cross_attn = nn.MultiheadAttention(
            C_hr,
            num_heads,
            dropout=dropout,
            kdim=C_lr,
            vdim=C_lr,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(C_hr)
        hidden = C_hr * 4
        self.mlp = nn.Sequential(
            nn.LayerNorm(C_hr),
            nn.Linear(C_hr, hidden),
            nn.GELU(),
            nn.Linear(hidden, C_hr),
        )

    def forward(self, x_hr, x_lr):
        """
        Fuse LR context into the HR map.

        Args:
            x_hr: HR feature map [B, C_hr, H_hr, W_hr]
            x_lr: LR feature map [B, C_lr, H_lr, W_lr]
        Returns:
            Fused HR feature map [B, C_hr, H_hr, W_hr]
        """
        batch, channels, height, width = x_hr.shape
        # Unpacking doubles as a check that x_lr is 4-D as well.
        _, _, _, _ = x_lr.shape

        # Feature maps -> token sequences (B, N, C).
        queries = x_hr.flatten(2).permute(0, 2, 1)
        context = x_lr.flatten(2).permute(0, 2, 1)

        attended = self.cross_attn(queries, context, context)[0]
        fused = self.norm(queries + attended)
        fused = fused + self.mlp(fused)

        # Token sequence -> feature map.
        return fused.transpose(1, 2).view(batch, channels, height, width)
class CASWiT_SSL(nn.Module):
    """
    CASWiT Self-Supervised Learning model using SimMIM.

    Encoder: Dual Swin backbones with cross-attention blocks
    Decoder: Conv1x1 + PixelShuffle for reconstruction
    Masking: HR random masking, LR center masking

    The cross-attention and decoder widths are hard-coded to the Swin-Base
    stage dimensions (128, 256, 512, 1024), so ``model_name`` must refer to
    a backbone with those widths.

    Args:
        model_name: HuggingFace model identifier
        mask_ratio_hr: Masking ratio for HR branch
        mask_ratio_lr: Masking ratio for LR branch
        patch_size: Patch size for masking (pixels covered by one token side)
        encoder_stride: Encoder stride for decoder
        xa_heads: Number of cross-attention heads per stage (one per stage)

    Raises:
        ValueError: If ``xa_heads`` does not provide one entry per stage.
    """

    def __init__(self, model_name: str = "openmmlab/upernet-swin-base",
                 mask_ratio_hr: float = 0.75, mask_ratio_lr: float = 0.5,
                 patch_size: int = 4, encoder_stride: int = 32,
                 xa_heads: Tuple[int, int, int, int] = (8, 8, 8, 8)):
        super().__init__()
        self.mask_ratio_hr = mask_ratio_hr
        self.mask_ratio_lr = mask_ratio_lr
        self.patch_size = patch_size
        self.encoder_stride = encoder_stride

        # Explicit Swin-Base stage dims.
        dims = [128, 256, 512, 1024]
        # Validate before loading weights: `zip` would otherwise silently
        # drop encoder stages when fewer head counts are supplied.
        if len(xa_heads) != len(dims):
            raise ValueError(
                f"xa_heads must have {len(dims)} entries, got {len(xa_heads)}"
            )

        # Load two UPerNet (Swin) backbones; only the backbone parts are kept.
        model_hr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, ignore_mismatched_sizes=True
        )
        model_lr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, ignore_mismatched_sizes=True
        )
        self.embeddings_hr = model_hr.backbone.embeddings
        self.encoder_layers_hr = model_hr.backbone.encoder.layers
        self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms
        self.embeddings_lr = model_lr.backbone.embeddings
        self.encoder_layers_lr = model_lr.backbone.encoder.layers
        self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms

        # One cross-attention block per encoder stage.
        self.cross_attn_blocks = nn.ModuleList([
            CrossAttentionBlock(d, d, num_heads=h) for d, h in zip(dims, xa_heads)
        ])

        # Learnable mask tokens
        self.mask_token_hr = nn.Parameter(torch.zeros(1, dims[0]))
        self.mask_token_lr = nn.Parameter(torch.zeros(1, dims[0]))

        # SimMIM decoder: Conv1x1 -> PixelShuffle(stride).
        # Built eagerly (not lazily in the first forward pass) so that its
        # parameters are visible to optimizers, state_dict(), .to()/.half()
        # and DDP wrappers created right after construction. The last stage
        # ends with the dims[-1]-wide cross-attention block.
        self.decoder_conv = nn.Conv2d(
            dims[-1], (self.encoder_stride ** 2) * 3, kernel_size=1
        )
        self.decoder_shuffle = nn.PixelShuffle(self.encoder_stride)

        # Store masks for visualization
        self.last_mask_hr = None
        self.last_mask_lr = None

    @staticmethod
    def _run_blocks(stage, tokens: torch.Tensor, dims: Tuple[int, int]) -> torch.Tensor:
        """Run every block of one encoder stage on a token sequence."""
        for block in stage.blocks:
            tokens = block(tokens, dims)
            # Blocks may return a tuple; the hidden states come first.
            if isinstance(tokens, tuple):
                tokens = tokens[0]
        return tokens

    @staticmethod
    def _tokens_to_map(tokens: torch.Tensor, dims: Tuple[int, int]) -> torch.Tensor:
        """Reshape a token sequence (B, H*W, C) into a feature map (B, C, H, W)."""
        batch, _, channels = tokens.shape
        height, width = dims
        return tokens.transpose(1, 2).contiguous().view(batch, channels, height, width)

    @staticmethod
    def _patch_mask_to_pixels(mask_grid: torch.Tensor, patch_size: int) -> torch.Tensor:
        """Upsample a patch mask (B, Mh, Mw) to a pixel mask (B, 1, Mh*p, Mw*p)."""
        return mask_grid.repeat_interleave(
            patch_size, 1
        ).repeat_interleave(patch_size, 2).unsqueeze(1).contiguous()

    @staticmethod
    def _masked_l1_loss(x_rec: torch.Tensor, target: torch.Tensor,
                        mask_pix: torch.Tensor) -> torch.Tensor:
        """L1 error averaged over masked pixels and channels (SimMIM loss)."""
        per_pixel = F.l1_loss(x_rec, target, reduction='none')
        # 1e-6 guards against division by zero when nothing is masked.
        return (per_pixel * mask_pix).sum() / (mask_pix.sum() + 1e-6) / target.shape[1]

    def _encode(self, x_hr: torch.Tensor, x_lr: torch.Tensor):
        """
        Encode with masking and return reconstruction targets.

        Args:
            x_hr: HR image batch (B, 3, H, W)
            x_lr: LR image batch
        Returns:
            x_rec: Reconstruction produced by the decoder from the HR branch
            target_img: The unmodified ``x_hr``
            mask_pix_hr: HR pixel mask (B, 1, ., .), 1 = masked
            target_lr: The unmodified ``x_lr``
            mask_pix_lr: LR pixel mask (B, 1, ., .), 1 = masked
        """
        batch = x_hr.shape[0]

        # Patch embeddings; the second return value is not used here.
        x_hr_seq, _ = self.embeddings_hr(x_hr)  # (B, N_hr, C1)
        x_lr_seq, _ = self.embeddings_lr(x_lr)  # (B, N_lr, C1)

        # Masking
        x_hr_seq, mask_hr, _ = random_masking_with_tokens(
            x_hr_seq, self.mask_ratio_hr, self.mask_token_hr
        )
        x_lr_seq, mask_lr, _ = center_masking_with_tokens(
            x_lr_seq, self.mask_token_lr, mask_ratio=self.mask_ratio_lr
        )

        # Initial spatial dims: the token grids are treated as square.
        side_hr = int(math.sqrt(x_hr_seq.shape[1]))
        side_lr = int(math.sqrt(x_lr_seq.shape[1]))
        dims_hr = (side_hr, side_hr)
        dims_lr = (side_lr, side_lr)

        # Walk encoder stages with cross attention at each stage
        stages = zip(
            self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks
        )
        for idx, (stage_hr, stage_lr, ca) in enumerate(stages):
            x_hr_seq = self._run_blocks(stage_hr, x_hr_seq, dims_hr)
            x_lr_seq = self._run_blocks(stage_lr, x_lr_seq, dims_lr)

            # Norms
            x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)

            # Cross-fuse HR <- LR on feature maps, then back to tokens.
            # Only the HR branch is updated; LR continues un-fused.
            feat_hr = self._tokens_to_map(x_hr_seq, dims_hr)
            feat_lr = self._tokens_to_map(x_lr_seq, dims_lr)
            fused_hr = ca(feat_hr, feat_lr)
            x_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            # Downsample to next stage (each downsample halves the grid side).
            if stage_hr.downsample is not None:
                x_hr_seq = stage_hr.downsample(x_hr_seq, dims_hr)
                dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
            if stage_lr.downsample is not None:
                x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
                dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)

        # Last-stage feature map z (B, C_last, H/stride, W/stride)
        z = self._tokens_to_map(x_hr_seq, dims_hr)

        # Reconstruction
        x_rec = self.decoder_shuffle(self.decoder_conv(z))  # (B,3,H,W)

        # Convert patch masks to pixel masks
        grid_hr = int(math.sqrt(mask_hr.shape[1]))
        mask_patch_hr = mask_hr.reshape(batch, grid_hr, grid_hr)
        mask_pix_hr = self._patch_mask_to_pixels(mask_patch_hr, self.patch_size)

        grid_lr = int(math.sqrt(mask_lr.shape[1]))
        mask_patch_lr = mask_lr.reshape(batch, grid_lr, grid_lr)
        mask_pix_lr = self._patch_mask_to_pixels(mask_patch_lr, self.patch_size)

        # Keep the patch-level masks of the latest call for visualization.
        self.last_mask_hr = mask_patch_hr
        self.last_mask_lr = mask_patch_lr

        return x_rec, x_hr, mask_pix_hr, x_lr, mask_pix_lr

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for SSL training.

        Returns reconstruction loss on masked pixels only (HR branch); the
        LR target and mask are not part of the loss.
        """
        x_rec, target_img, mask_pix, _, _ = self._encode(x_hr, x_lr)
        return self._masked_l1_loss(x_rec, target_img, mask_pix)

    @torch.no_grad()
    def forward_outputs(self, x_hr: torch.Tensor, x_lr: torch.Tensor):
        """
        Forward pass returning all outputs for visualization.

        Runs without gradient tracking and returns the tuple
        ``(x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr)``.
        """
        return self._encode(x_hr, x_lr)