import math

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor, nn
from torch.nn.utils import parametrize
from transformers import PreTrainedModel
from transformers.models.swinv2.modeling_swinv2 import window_partition, window_reverse
from transformers.utils.backbone_utils import load_backbone

from .configuration import LSPDetrConfig


class MLP(nn.Sequential):
    """Very simple multi-layer perceptron.

    Builds ``num_layers`` linear layers: ``num_layers - 1`` hidden groups of
    ``Linear -> act_layer() [-> Dropout]`` followed by a final
    ``Linear(hidden_dim, output_dim)`` (accessible as ``self[-1]``).

    Args:
        input_dim: Size of the input features.
        hidden_dim: Size of every hidden layer.
        output_dim: Size of the output features.
        num_layers: Total number of linear layers; must be at least 2.
        act_layer: Activation module class instantiated after each hidden layer.
        dropout: Dropout probability after each hidden activation (omitted when 0).

    Raises:
        ValueError: If ``num_layers < 2``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        act_layer: type[nn.Module] = nn.GELU,
        dropout: float = 0.0,
    ) -> None:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if num_layers <= 1:
            raise ValueError(f"num_layers must be greater than 1, got {num_layers}.")
        layers = []
        h = [hidden_dim] * (num_layers - 1)
        # [input_dim, *h] is one longer than h, so zip stops after num_layers - 1 pairs.
        for n, k in zip([input_dim, *h], h, strict=False):
            layers.append(nn.Linear(n, k))
            layers.append(act_layer())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(hidden_dim, output_dim))
        super().__init__(*layers)


class FeedForward(nn.Module):
    """SwiGLU feed-forward module.

    Taken from https://github.com/meta-llama/llama-models/blob/main/models/llama4/ffn.py
    """

    def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256) -> None:
        """Initialize the FeedForward module.

        Args:
            dim (int): Input (and output) dimension.
            hidden_dim (int): Nominal hidden dimension; it is scaled by 2/3 and
                then rounded up to a multiple of ``multiple_of``.
            multiple_of (int): The effective hidden dimension is rounded up to a
                multiple of this value.
        """
        super().__init__()
        # 2/3 scaling keeps the parameter count close to a plain 2-matrix FFN
        # despite the third (gating) projection.
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
    """Build per-head, randomly rotated RoPE-Mixed frequencies.

    Taken from https://github.com/naver-ai/rope-vit/blob/main/self-attn/rope_self_attn.py.

    Each head draws a random angle and projects the base frequencies onto the
    x and y axes using that angle (and the angle + pi/2).

    Returns:
        Tensor of shape ``(2, num_heads, 2 * ceil(head_dim / (2 * pos_dim)))``.
        Index 0 holds the x frequencies and index 1 the y frequencies. For the
        default ``pos_dim=2`` with ``head_dim`` divisible by 4, the last dim is
        ``head_dim // 2``.
    """
    freqs_x = []
    freqs_y = []
    freqs = 1 / (theta ** (torch.arange(0, head_dim, 2 * pos_dim).float() / head_dim))
    for _ in range(num_heads):
        angles = torch.rand(1) * 2 * torch.pi
        fx = torch.cat(
            [freqs * torch.cos(angles), freqs * torch.cos(torch.pi / 2 + angles)],
            dim=-1,
        )
        fy = torch.cat(
            [freqs * torch.sin(angles), freqs * torch.sin(torch.pi / 2 + angles)],
            dim=-1,
        )
        freqs_x.append(fx)
        freqs_y.append(fy)
    freqs_x = torch.stack(freqs_x, dim=0)
    freqs_y = torch.stack(freqs_y, dim=0)
    return torch.stack([freqs_x, freqs_y], dim=0)


class Skew(nn.Module):
    """Skew-symmetric matrix parameterization.

    ``forward`` maps any square matrix to ``U - U^T``, where ``U`` is its strict
    upper triangle. ``right_inverse`` stores only the strict upper triangle.
    """

    def forward(self, x: Tensor) -> Tensor:
        a = x.triu(1)
        return a - a.transpose(-1, -2)

    def right_inverse(self, x: Tensor) -> Tensor:
        return x.triu(1)


class CayleySTRING(nn.Module):
    """Implements the Cayley-STRING positional encoding.

    Based on "Learning the RoPEs: Better 2D and 3D Position Encodings with STRING"
    (https://arxiv.org/abs/2502.02562).

    Multiplies the input by a learnable orthogonal matrix P parameterized by the
    Cayley transform, P = (I - S)(I + S)^-1, where S is a learnable
    skew-symmetric matrix shared across heads. It then applies RoPE-Mixed
    rotations whose per-head frequencies are learnable.

    Args:
        dim (int): Total feature dimension across all heads. Must be divisible
            by ``num_heads``. The per-head dimension should be divisible by
            ``2 * pos_dim`` so the frequency table matches its complex pairs.
        num_heads (int): Number of attention heads.
        pos_dim (int): Dimensionality of the position vectors. Defaults to 2.
        theta (float): Base value for the RoPE frequency calculation.
            Defaults to 100.0.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self, dim: int, num_heads: int, pos_dim: int = 2, theta: float = 100.0
    ) -> None:
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if dim % num_heads != 0:
            raise ValueError("Dimension must be divisible by num_heads.")
        head_dim = dim // num_heads

        self.freqs = nn.Parameter(init_freqs(head_dim, num_heads, pos_dim, theta))

        self.S = nn.Parameter(torch.zeros(head_dim, head_dim))
        parametrize.register_parametrization(self, "S", Skew())
        self.register_buffer("I", torch.eye(head_dim), persistent=False)

        self.init_weights()

    def init_weights(self) -> None:
        # Reading self.S yields the parametrized (skew) tensor. Assigning back goes
        # through Skew.right_inverse, so only the strict upper triangle of the
        # re-initialized values is stored.
        self.S = nn.init.kaiming_uniform_(self.S, a=math.sqrt(5))

    @parametrize.cached()
    @torch.autocast("cuda", enabled=False)
    def forward(self, x: Tensor, positions: Tensor) -> Tensor:
        """Apply Cayley-STRING positional encoding.

        Args:
            x ([b, h, n, d]): Input tensor.
            positions ([b, n, pos_dim]): Positions tensor.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        # Compute (I + S)^-1 @ x
        y = torch.linalg.solve(
            self.I + self.S, rearrange(x.float(), "b h n d -> h d (b n)")
        )
        # change of basis
        px = torch.matmul(self.I - self.S, y)
        px = rearrange(px, "h d (b n) -> b h n d", b=x.size(0)).contiguous()

        # apply RoPE-Mixed
        angles = torch.einsum("bnk,khc->bhnc", positions, self.freqs)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)

        px_ = torch.view_as_complex(rearrange(px, "... (d two) -> ... d two", two=2))
        out = rearrange(
            torch.view_as_real(px_ * freqs_cis), "... d two -> ... (d two)"
        )
        return out.type_as(x)


def maybe_pad(x: Tensor, window_size: int) -> Tensor:
    """Zero-pad dims 1 (height) and 2 (width) of a channels-last ``[b, h, w, c]``
    tensor on the bottom/right up to multiples of ``window_size``."""
    h, w = x.shape[1:3]
    pad_right = (window_size - w % window_size) % window_size
    pad_bottom = (window_size - h % window_size) % window_size
    # F.pad pairs run from the last dim backwards: (c), (w), (h).
    return F.pad(x, (0, 0, 0, pad_right, 0, pad_bottom))


@torch.autocast("cuda", enabled=False)
def relative_to_absolute_pos(pos: Tensor, step_x: float, step_y: float) -> Tensor:
    """Convert per-cell relative offsets into absolute positions.

    ``pos`` is ``[b, h, w, 2]`` logits. After a sigmoid, cell ``(i, j)`` maps to
    ``((j + sx) * step_x, (i + sy) * step_y)``. With ``step`` in pixels the
    result is in pixels; with ``step = cell_size / image_size`` it is normalized.

    Returns:
        Tensor of shape ``[b, h, w, 2]`` holding ``(x, y)``.
    """
    pos = pos.sigmoid()
    h, w = pos.shape[1:3]
    anchor_x = torch.arange(w, dtype=torch.float32, device=pos.device) * step_x
    anchor_y = torch.arange(h, dtype=torch.float32, device=pos.device) * step_y
    absolute_x = pos[..., 0] * step_x + anchor_x
    absolute_y = pos[..., 1] * step_y + anchor_y.unsqueeze(1)
    return torch.stack((absolute_x, absolute_y), dim=-1)


def get_mask_windows(
    height: int, width: int, window_size: int, shift_size: int, device: torch.device
) -> Tensor:
    """Label every pixel with its shifted-window region and partition into windows.

    Rows/columns are split into three bands (0, 1, 2) in the same way Swin builds
    its shifted-window mask. Each pixel gets the label ``row_band * 3 + col_band``.

    Returns:
        Tensor of shape ``[num_windows, window_size * window_size]``.
    """
    # Create indices for height and width regions
    h_idx = torch.zeros(height, dtype=torch.long, device=device)
    h_idx[height - window_size : height - shift_size] = 1
    h_idx[height - shift_size :] = 2

    w_idx = torch.zeros(width, dtype=torch.long, device=device)
    w_idx[width - window_size : width - shift_size] = 1
    w_idx[width - shift_size :] = 2

    # Calculate region index for each pixel using broadcasting
    mask = h_idx.unsqueeze(1) * 3 + w_idx.unsqueeze(0)

    mask_windows = window_partition(mask[None, ..., None], window_size)
    return rearrange(mask_windows, "n w1 w2 1 -> n (w1 w2)")


class WindowCrossAttention(nn.Module):
    """(Shifted-)window cross-attention from target queries to source features.

    Target window ``k`` attends to source window ``k``. The padded target and
    source grids are therefore expected to yield the same number of windows.
    Both queries and keys get Cayley-STRING positional encoding from their
    coordinates.

    NOTE(review): only ``tgt_shift_size`` decides whether a mask is built. Source
    regions come from ``src_shift_size``. If ``src_shift_size`` is 0 while
    ``tgt_shift_size`` is not (e.g. ``src_window_size == 1`` in a shifted Stage
    block), some query rows appear to end up fully masked. Confirm configs
    avoid this.
    """

    def __init__(
        self,
        dim: int,
        src_dim: int,
        tgt_window_size: int,
        src_window_size: int,
        num_heads: int,
        src_shift_size: int = 0,
        tgt_shift_size: int = 0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.tgt_window_size = tgt_window_size
        self.src_window_size = src_window_size
        self.src_shift_size = src_shift_size
        self.tgt_shift_size = tgt_shift_size
        self.dropout = dropout  # attention-probability dropout (training only)

        self.pe = CayleySTRING(dim, num_heads)
        self.query = nn.Linear(dim, dim, bias=False)
        self.kv = nn.Linear(src_dim, dim * 2, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def get_attn_mask(
        self,
        height: int,
        width: int,
        key_height: int,
        key_width: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tensor | None:
        """Additive mask ``[num_windows, Lq, Ls]``: 0 where query and key share a
        region label, ``-inf`` otherwise; ``None`` for unshifted windows."""
        if self.tgt_shift_size == 0:
            return None

        query_mask = get_mask_windows(
            height, width, self.tgt_window_size, self.tgt_shift_size, device
        )
        key_mask = get_mask_windows(
            key_height, key_width, self.src_window_size, self.src_shift_size, device
        )
        attn_mask = query_mask.unsqueeze(2) - key_mask.unsqueeze(1)
        return attn_mask.type(dtype).masked_fill(attn_mask != 0, -torch.inf)

    def forward(
        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coord: Tensor
    ) -> Tensor:
        """Forward function for Window Cross-Attention.

        Args:
            tgt ([b, h, w, c]): Target (query) hidden states.
            src ([b, hs, ws, src_dim]): Source features, channels-last.
            tgt_coords ([b, h, w, 2]): Absolute positions of the targets.
            src_coord ([b, hs, ws, 2]): Absolute positions of the sources.

        Returns:
            Tensor of shape ``[b, h, w, c]``.
        """
        b, h, w, c = tgt.shape

        # pad to multiples of window size
        tgt = maybe_pad(tgt, self.tgt_window_size)
        src = maybe_pad(src, self.src_window_size)
        tgt_coords = maybe_pad(tgt_coords, self.tgt_window_size)
        src_coord = maybe_pad(src_coord, self.src_window_size)
        h_pad, w_pad = tgt.shape[1:3]
        src_h, src_w = src.shape[1:3]

        # cyclic shift
        if self.tgt_shift_size > 0:
            tgt = tgt.roll(
                shifts=(-self.tgt_shift_size, -self.tgt_shift_size), dims=(1, 2)
            )
            tgt_coords = tgt_coords.roll(
                shifts=(-self.tgt_shift_size, -self.tgt_shift_size), dims=(1, 2)
            )

        if self.src_shift_size > 0:
            src = src.roll(
                shifts=(-self.src_shift_size, -self.src_shift_size), dims=(1, 2)
            )
            src_coord = src_coord.roll(
                shifts=(-self.src_shift_size, -self.src_shift_size), dims=(1, 2)
            )

        # partition windows
        tgt = window_partition(tgt, self.tgt_window_size).flatten(1, 2)
        src = window_partition(src, self.src_window_size).flatten(1, 2)
        tgt_coords = window_partition(tgt_coords, self.tgt_window_size).flatten(1, 2)
        src_coord = window_partition(src_coord, self.src_window_size).flatten(1, 2)

        attn_mask = self.get_attn_mask(
            h_pad, w_pad, src_h, src_w, tgt.device, tgt.dtype
        )
        if attn_mask is not None:
            # window_partition orders windows batch-major, matching "(b n)".
            attn_mask = repeat(attn_mask, "n l s -> (b n) h l s", b=b, h=self.num_heads)

        # W-MCA/SW-MCA
        q = rearrange(self.query(tgt), "b n (h d) -> b h n d", h=self.num_heads)
        k, v = rearrange(
            self.kv(src), "b n (two h d) -> two b h n d", two=2, h=self.num_heads
        )

        x = F.scaled_dot_product_attention(
            query=self.pe(q, tgt_coords),
            key=self.pe(k, src_coord),
            value=v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        tgt = self.wo(rearrange(x, "b h n d -> b n (h d)"))

        # merge windows
        tgt = tgt.view(-1, self.tgt_window_size, self.tgt_window_size, c)
        tgt = window_reverse(tgt, self.tgt_window_size, h_pad, w_pad)

        # reverse cyclic shift
        if self.tgt_shift_size > 0:
            tgt = torch.roll(
                tgt, shifts=(self.tgt_shift_size, self.tgt_shift_size), dims=(1, 2)
            )

        return tgt[:, :h, :w, :].contiguous()  # remove padding


class WindowSelfAttention(nn.Module):
    """(Shifted-)window self-attention with Cayley-STRING positional encoding."""

    def __init__(
        self,
        dim: int,
        window_size: int,
        num_heads: int,
        shift_size: int = 0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.dropout = dropout  # attention-probability dropout (training only)

        self.pe = CayleySTRING(dim, num_heads)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def get_attn_mask(
        self, height: int, width: int, device: torch.device, dtype: torch.dtype
    ) -> Tensor | None:
        """Additive mask ``[num_windows, L, L]`` for shifted windows, else ``None``."""
        if self.shift_size == 0:
            return None

        mask_windows = get_mask_windows(
            height, width, self.window_size, self.shift_size, device
        )

        # Calculate the attention mask based on window differences
        attn_mask = mask_windows.unsqueeze(2) - mask_windows.unsqueeze(1)
        return attn_mask.type(dtype).masked_fill(attn_mask != 0, -torch.inf)

    def forward(self, x: Tensor, coords: Tensor) -> Tensor:
        """Forward function for Window Self-Attention.

        Args:
            x ([b, h, w, c]): Hidden states.
            coords ([b, h, w, 2]): Absolute positions.

        Returns:
            Tensor of shape ``[b, h, w, c]``.
        """
        b, h, w, c = x.shape

        # pad to multiples of window size
        x = maybe_pad(x, self.window_size)
        coords = maybe_pad(coords, self.window_size)
        h_pad, w_pad = x.shape[1:3]

        # cyclic shift
        if self.shift_size > 0:
            x = x.roll(shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            coords = coords.roll(
                shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )

        # partition windows
        x = window_partition(x, self.window_size).flatten(1, 2)
        coords = window_partition(coords, self.window_size).flatten(1, 2)

        attn_mask = self.get_attn_mask(h_pad, w_pad, x.device, x.dtype)
        if attn_mask is not None:
            # window_partition orders windows batch-major, matching "(b n)".
            attn_mask = repeat(attn_mask, "n l s -> (b n) h l s", b=b, h=self.num_heads)

        # W-MSA/SW-MSA
        q, k, v = rearrange(
            self.qkv(x), "b n (three h d) -> three b h n d", three=3, h=self.num_heads
        )
        x = F.scaled_dot_product_attention(
            query=self.pe(q, coords),
            key=self.pe(k, coords),
            value=v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        x = self.wo(rearrange(x, "b h n d -> b n (h d)"))

        # merge windows
        x = x.view(-1, self.window_size, self.window_size, c)
        x = window_reverse(x, self.window_size, h_pad, w_pad)

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

        return x[:, :h, :w, :].contiguous()  # remove padding


class Block(nn.Module):
    """Post-norm decoder block: window self-attention, window cross-attention,
    then a SwiGLU feed-forward, each with dropout + residual + LayerNorm."""

    def __init__(
        self,
        dim: int,
        src_dim: int,
        num_heads: int,
        window_size: int,
        tgt_window_size: int,
        src_window_size: int,
        shift_size: int = 0,
        tgt_shift_size: int = 0,
        src_shift_size: int = 0,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.cross_attention = WindowCrossAttention(
            dim,
            src_dim,
            num_heads=num_heads,
            tgt_window_size=tgt_window_size,
            src_window_size=src_window_size,
            tgt_shift_size=tgt_shift_size,
            src_shift_size=src_shift_size,
            dropout=dropout,
        )
        self.cross_attention_norm = nn.LayerNorm(dim)
        self.cross_attention_dropout = nn.Dropout(dropout)

        self.self_attention = WindowSelfAttention(
            dim, window_size, num_heads, shift_size, dropout=dropout
        )
        self.self_attention_norm = nn.LayerNorm(dim)
        self.self_attention_dropout = nn.Dropout(dropout)

        self.ffn = FeedForward(dim, dim * 4)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn_dropout = nn.Dropout(dropout)

    def forward(
        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
    ) -> Tensor:
        x = self.self_attention(tgt, tgt_coords)
        tgt = self.self_attention_norm(tgt + self.self_attention_dropout(x))

        x = self.cross_attention(tgt, src, tgt_coords, src_coords)
        tgt = self.cross_attention_norm(tgt + self.cross_attention_dropout(x))

        return self.ffn_norm(tgt + self.ffn_dropout(self.ffn(tgt)))


class Stage(nn.Module):
    """Stack of ``depth`` Blocks. Odd-indexed blocks use windows shifted by half
    the window size (self-attention, target and source alike)."""

    def __init__(
        self,
        dim: int,
        src_dim: int,
        depth: int,
        num_heads: int,
        window_size: int,
        tgt_window_size: int,
        src_window_size: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=dim,
                src_dim=src_dim,
                num_heads=num_heads,
                window_size=window_size,
                tgt_window_size=tgt_window_size,
                src_window_size=src_window_size,
                shift_size=0 if i % 2 == 0 else window_size // 2,
                tgt_shift_size=0 if i % 2 == 0 else tgt_window_size // 2,
                src_shift_size=0 if i % 2 == 0 else src_window_size // 2,
                dropout=dropout,
            )
            self.blocks.append(block)

    def forward(
        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
    ) -> Tensor:
        for block in self.blocks:
            tgt = block(tgt, src, tgt_coords, src_coords)
        return tgt


class LSPTransformer(nn.Module):
    """Decoder that refines a grid of queries against multi-scale features.

    Stage ``i`` cross-attends to ``features[i]``. After every stage the heads
    predict class logits, a point offset and radial distances. Points are
    refined stage by stage with the "look forward twice" scheme.
    """

    def __init__(self, config: LSPDetrConfig, feature_channels: list[int]) -> None:
        super().__init__()
        self.query_block_size = config.query_block_size
        self.num_radial_distances = config.num_radial_distances

        self.stages = nn.ModuleList()
        for i, depth in enumerate(config.depths):
            stage = Stage(
                dim=config.dim,
                src_dim=feature_channels[i],
                depth=depth,
                num_heads=config.num_heads,
                window_size=config.window_size,
                tgt_window_size=config.tgt_window_sizes[i],
                src_window_size=config.src_window_sizes[i],
                dropout=config.dropout,
            )
            self.stages.append(stage)

        self.input_norm = nn.ModuleList(nn.LayerNorm(d) for d in feature_channels)

        # output heads
        self.class_head = nn.Linear(config.dim, config.num_classes + 1, bias=False)
        self.point_head = MLP(config.dim, config.dim, 2, 2)
        self.radial_distances_head = MLP(
            config.dim, config.dim, config.num_radial_distances, 2
        )

        self.init_weights()

    def init_weights(self) -> None:
        # initialize regression layers: zero offsets so the first prediction
        # sits at the reference point.
        nn.init.constant_(self.point_head[-1].weight, 0.0)
        nn.init.constant_(self.point_head[-1].bias, 0.0)

    def forward(
        self,
        tgt: Tensor,
        ref_points: Tensor,
        features: list[Tensor],
        height: int,
        width: int,
    ) -> dict[str, Tensor | list[dict[str, Tensor]]]:
        """Run all stages and the prediction heads.

        Args:
            tgt ([b, hq, wq, dim]): Initial query features.
            ref_points ([b, hq, wq, 2]): Relative reference-point logits.
            features: Channels-first ``[b, c_i, h_i, w_i]`` maps, one per stage.
            height: Input image height in pixels.
            width: Input image width in pixels.

        Returns:
            Dict with last-stage ``logits``, ``points`` (normalized to the image
            size), ``radial_distances``, ``absolute_points`` (pixels, from the
            detached refined references) and ``aux_outputs`` for earlier stages.
        """
        src = []
        src_coords = []
        for i, feature in enumerate(features):
            b, _, h, w = feature.shape
            # zeros -> sigmoid(0) = 0.5, i.e. the centre of every feature cell
            coords = torch.zeros(b, h, w, 2, dtype=torch.float32, device=feature.device)
            src.append(self.input_norm[i](rearrange(feature, "b c h w -> b h w c")))
            src_coords.append(
                relative_to_absolute_pos(
                    coords, step_x=math.ceil(width / w), step_y=math.ceil(height / h)
                )
            )

        logits_list: list[Tensor] = []
        ref_points_list: list[Tensor] = []
        radial_distances_list: list[Tensor] = []
        new_ref_points = ref_points.clone()  # for look forward twice
        for i, stage in enumerate(self.stages):
            tgt = stage(
                tgt=tgt,
                src=src[i],
                tgt_coords=relative_to_absolute_pos(
                    ref_points, self.query_block_size, self.query_block_size
                ),
                src_coords=src_coords[i],
            )

            # output heads
            delta_point = self.point_head(tgt)
            radial_distances = self.radial_distances_head(tgt)
            logits = self.class_head(tgt)

            # Predictions use the previous stage's undetached references.
            ref_points_list.append(
                relative_to_absolute_pos(
                    new_ref_points + delta_point,
                    step_x=self.query_block_size / width,
                    step_y=self.query_block_size / height,
                ).flatten(1, 2)
            )
            logits_list.append(logits.flatten(1, 2))
            radial_distances_list.append(radial_distances.flatten(1, 2))

            new_ref_points = ref_points + delta_point
            ref_points = new_ref_points.detach()

        return {
            "logits": logits_list[-1],
            "points": ref_points_list[-1],
            "radial_distances": radial_distances_list[-1],
            "absolute_points": relative_to_absolute_pos(
                ref_points, self.query_block_size, self.query_block_size
            ).flatten(1, 2),
            "aux_outputs": [
                {
                    "logits": a,
                    "points": b,
                    "radial_distances": c,
                }
                for a, b, c in zip(
                    logits_list[:-1],
                    ref_points_list[:-1],
                    radial_distances_list[:-1],
                    strict=True,
                )
            ],
        }


class FeatureSampling(nn.Module):
    """Bilinearly sample a channels-first feature map at given points, then
    project and normalize.

    ``points`` are ``[b, hq, wq, 2]`` ``(x, y)`` coordinates normalized to
    ``[0, 1]``. They are mapped to grid_sample's ``[-1, 1]`` range, and points
    outside it read zero padding.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.reduction = nn.Linear(in_dim, out_dim, bias=False)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, points: Tensor, feature: Tensor) -> Tensor:
        x = F.grid_sample(feature, points * 2 - 1, align_corners=False)
        return self.norm(self.reduction(rearrange(x, "b c h w -> b h w c")))


class LSPDetrModel(PreTrainedModel):
    """Backbone + query grid + LSPTransformer decoder.

    One query is created per ``query_block_size x query_block_size`` pixel
    block. Its initial features are sampled from the last backbone feature
    map ("neck") at the block centre.
    """

    config_class = LSPDetrConfig

    def __init__(self, config: LSPDetrConfig) -> None:
        super().__init__(config)
        self.query_block_size = config.query_block_size
        self.backbone = load_backbone(config)
        # NOTE(review): assumes num_features = [stem, *feature_maps[:-1], neck],
        # i.e. the first entry has no matching feature map. Confirm per backbone.
        _, *feature_channels, neck = self.backbone.num_features
        self.feature_sampling = FeatureSampling(neck, config.dim)
        self.decode_head = LSPTransformer(config, feature_channels[::-1])

    def forward(
        self, pixel_values: Tensor
    ) -> dict[str, Tensor | list[dict[str, Tensor]]]:
        b, _, h, w = pixel_values.shape
        *features, neck = self.backbone(pixel_values).feature_maps

        ref_points = torch.zeros(
            b,
            math.ceil(h / self.query_block_size),
            math.ceil(w / self.query_block_size),
            2,
            dtype=torch.float32,
            device=neck.device,
        )

        # Block-centre positions normalized to [0, 1] of the image, as
        # FeatureSampling/grid_sample expects. Pixel-unit steps would put nearly
        # every point outside grid_sample's [-1, 1] range, sampling zero padding.
        tgt = self.feature_sampling(
            relative_to_absolute_pos(
                ref_points,
                step_x=self.query_block_size / w,
                step_y=self.query_block_size / h,
            ),
            neck,
        )

        return self.decode_head(tgt, ref_points, features[::-1], h, w)