# LSP-DETR / modeling.py
# Uploaded by matejpekar ("Upload model"), commit f6fbbfa (verified); 22.6 kB.
import math
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor, nn
from torch.nn.utils import parametrize
from transformers import PreTrainedModel
from transformers.models.swinv2.modeling_swinv2 import window_partition, window_reverse
from transformers.utils.backbone_utils import load_backbone
from .configuration import LSPDetrConfig
class MLP(nn.Sequential):
    """Plain multi-layer perceptron built as an ``nn.Sequential``.

    The stack holds ``num_layers`` linear layers. Every linear layer except the
    last is followed by an ``act_layer()`` instance and, when ``dropout > 0``,
    by an ``nn.Dropout``.

    Args:
        input_dim: Size of the input features.
        hidden_dim: Width of every hidden linear layer.
        output_dim: Size of the output features.
        num_layers: Total number of linear layers; must be greater than 1.
        act_layer: Activation class instantiated after each hidden linear layer.
        dropout: Dropout probability after each activation (skipped when 0).
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        act_layer: type[nn.Module] = nn.GELU,
        dropout: float = 0.0,
    ) -> None:
        assert num_layers > 1
        # input sizes of the hidden linear layers: first from the input, the
        # remaining ones hidden -> hidden
        in_sizes = [input_dim] + [hidden_dim] * (num_layers - 2)
        modules: list[nn.Module] = []
        for in_size in in_sizes:
            modules += [nn.Linear(in_size, hidden_dim), act_layer()]
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
        modules.append(nn.Linear(hidden_dim, output_dim))
        super().__init__(*modules)
class FeedForward(nn.Module):
    """SwiGLU feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    Adapted from
    https://github.com/meta-llama/llama-models/blob/main/models/llama4/ffn.py
    """

    def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256) -> None:
        """Create the three bias-free projections.

        Args:
            dim (int): Input and output feature size.
            hidden_dim (int): Nominal hidden size. It is scaled by 2/3 and then
                rounded up to the next multiple of ``multiple_of``.
            multiple_of (int): Granularity the hidden size is rounded up to.
        """
        super().__init__()
        inner = int(2 * hidden_dim / 3)
        # round up to the next multiple of ``multiple_of``
        inner = multiple_of * ((inner + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, inner, bias=False)
        self.w2 = nn.Linear(inner, dim, bias=False)
        self.w3 = nn.Linear(dim, inner, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
    """Build per-head RoPE-Mixed frequency tables with a random orientation.

    Taken from https://github.com/naver-ai/rope-vit/blob/main/self-attn/rope_self_attn.py.

    For every head, a random angle ``phi`` in [0, 2*pi) rotates the base
    frequencies. The x-table holds ``[f*cos(phi), f*cos(phi + pi/2)]`` and the
    y-table holds ``[f*sin(phi), f*sin(phi + pi/2)]``.

    Returns:
        Tensor of shape ``(2, num_heads, 2 * ceil(head_dim / (2 * pos_dim)))``.
        Index 0 along the first axis is the x-table and index 1 is the y-table.
    """
    exponents = torch.arange(0, head_dim, 2 * pos_dim).float() / head_dim
    base = 1 / (theta ** exponents)
    per_head = []
    for _ in range(num_heads):
        phi = torch.rand(1) * 2 * torch.pi
        x_table = torch.cat(
            [base * torch.cos(phi), base * torch.cos(torch.pi / 2 + phi)], dim=-1
        )
        y_table = torch.cat(
            [base * torch.sin(phi), base * torch.sin(torch.pi / 2 + phi)], dim=-1
        )
        per_head.append(torch.stack([x_table, y_table], dim=0))
    # (num_heads entries of [2, c]) -> [2, num_heads, c]
    return torch.stack(per_head, dim=1)
class Skew(nn.Module):
    """Parametrization that turns a square matrix into a skew-symmetric one.

    Only the strictly upper triangle of the stored tensor is used. The lower
    triangle is its negated transpose and the diagonal is zero.
    """

    def forward(self, x: Tensor) -> Tensor:
        upper = torch.triu(x, diagonal=1)
        return upper - upper.transpose(-2, -1)

    def right_inverse(self, x: Tensor) -> Tensor:
        # keep only the strict upper triangle; for skew-symmetric ``x``,
        # forward() of this value reproduces ``x``
        return torch.triu(x, diagonal=1)
class CayleySTRING(nn.Module):
    """Cayley-STRING positional encoding for 2-D token positions.

    Based on "Learning the RoPEs: Better 2D and 3D Position Encodings with STRING"
    (https://arxiv.org/abs/2502.02562).

    Each per-head feature vector is first multiplied by a learnable orthogonal
    matrix ``P = (I - S)(I + S)^-1``. This is the Cayley transform of a
    learnable skew-symmetric matrix ``S`` of size ``head_dim x head_dim``, shared
    by all heads. The result is then rotated with RoPE-Mixed: every head has
    its own learnable x- and y-frequency table (see ``init_freqs``).

    Args:
        dim (int): Total feature dimension over all heads. Must be divisible by
            ``num_heads``. With the default ``pos_dim=2``, the per-head dimension
            ``dim // num_heads`` must also be divisible by 4. Only then does the
            frequency table have one entry per rotated pair.
        num_heads (int): Number of attention heads.
        pos_dim (int): Controls the spacing of the frequency exponents passed to
            ``init_freqs``. Positions themselves are always 2-D. Defaults to 2.
        theta (float): Base of the RoPE frequency schedule. Defaults to 100.0.
    """

    def __init__(
        self, dim: int, num_heads: int, pos_dim: int = 2, theta: float = 100.0
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "Dimension must be divisible by num_heads."
        head_dim = dim // num_heads
        # [2, num_heads, head_dim // 2]: x- and y-frequencies per head
        self.freqs = nn.Parameter(init_freqs(head_dim, num_heads, pos_dim, theta))
        self.S = nn.Parameter(torch.zeros(head_dim, head_dim))
        parametrize.register_parametrization(self, "S", Skew())
        self.register_buffer("I", torch.eye(head_dim), persistent=False)
        self.init_weights()

    def init_weights(self) -> None:
        # Reading ``self.S`` yields the skew-symmetric matrix. It is filled in
        # place with Kaiming-uniform values. Assigning it back goes through
        # ``Skew.right_inverse``, so only its strict upper triangle is stored.
        self.S = nn.init.kaiming_uniform_(self.S, a=math.sqrt(5))

    # evaluate the parametrized ``S`` once per call instead of on every access
    @parametrize.cached()
    @torch.autocast("cuda", enabled=False)
    def forward(self, x: Tensor, positions: Tensor) -> Tensor:
        """Apply Cayley-STRING positional encoding.

        ``x`` is upcast with ``.float()`` and CUDA autocast is disabled. The
        result is cast back to the dtype of ``x``.

        Args:
            x ([b, h, n, d]): Per-head features, ``d = dim // num_heads``.
            positions ([b, n, 2]): 2-D coordinates of each of the ``n`` tokens.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        # y = (I + S)^-1 x, solved for all heads and tokens at once
        y = torch.linalg.solve(
            self.I + self.S, rearrange(x.float(), "b h n d -> h d (b n)")
        )
        # change of basis: Px = (I - S)(I + S)^-1 x
        px = torch.matmul(self.I - self.S, y)
        px = rearrange(px, "h d (b n) -> b h n d", b=x.size(0)).contiguous()
        # RoPE-Mixed: angle[b, h, n, c] = sum_k positions[b, n, k] * freqs[k, h, c]
        angles = torch.einsum("bnk,khc->bhnc", positions, self.freqs)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)
        # rotate consecutive feature pairs as complex numbers
        px_ = torch.view_as_complex(rearrange(px, "... (d two) -> ... d two", two=2))
        out = rearrange(torch.view_as_real(px_ * freqs_cis), "... d two -> ... (d two)")
        return out.type_as(x)
def maybe_pad(x: Tensor, window_size: int) -> Tensor:
    """Zero-pad height and width up to the next multiples of ``window_size``.

    Expects a channels-last tensor ``(b, h, w, c)``. Padding is appended at the
    bottom (height) and on the right (width).
    """
    height, width = x.shape[1], x.shape[2]
    extra_w = -width % window_size
    extra_h = -height % window_size
    # F.pad pairs start from the last dim: (c_lo, c_hi, w_lo, w_hi, h_lo, h_hi)
    return F.pad(x, (0, 0, 0, extra_w, 0, extra_h))
@torch.autocast("cuda", enabled=False)
def relative_to_absolute_pos(pos: Tensor, step_x: float, step_y: float) -> Tensor:
    """Map per-cell offset logits on a regular grid to absolute coordinates.

    ``pos`` has shape ``(b, h, w, 2)``. The sigmoid of ``pos[..., 0]`` /
    ``pos[..., 1]`` is the fractional x / y offset inside cell ``(row, col)``.
    That cell spans ``[col * step_x, (col + 1) * step_x)`` horizontally and
    ``[row * step_y, (row + 1) * step_y)`` vertically. Zero logits therefore
    give cell centres.

    Returns:
        Tensor of shape ``(b, h, w, 2)`` holding ``(x, y)`` coordinates.
    """
    frac = pos.sigmoid()
    rows, cols = pos.shape[1], pos.shape[2]
    col_start = torch.arange(cols, dtype=torch.float32, device=pos.device) * step_x
    row_start = torch.arange(rows, dtype=torch.float32, device=pos.device) * step_y
    xs = frac[..., 0] * step_x + col_start
    ys = frac[..., 1] * step_y + row_start[:, None]
    return torch.stack((xs, ys), dim=-1)
def get_mask_windows(
    height: int, width: int, window_size: int, shift_size: int, device: torch.device
) -> Tensor:
    """Label each position of each window with its shifted-window region id.

    Along each axis of the ``height x width`` grid, positions are put in band 0,
    1 or 2:
    - ``[size - window_size, size - shift_size)`` is band 1;
    - ``[size - shift_size, size)`` is band 2;
    - everything else is band 0.

    The region id is ``3 * row_band + col_band``.

    Returns:
        Long tensor of shape ``(num_windows, window_size * window_size)``, in the
        window order produced by ``window_partition``.
    """

    def bands(size: int) -> Tensor:
        labels = torch.zeros(size, dtype=torch.long, device=device)
        labels[size - window_size : size - shift_size] = 1
        labels[size - shift_size :] = 2
        return labels

    row_band = bands(height)
    col_band = bands(width)
    regions = row_band[:, None] * 3 + col_band[None, :]
    # window_partition expects (batch, h, w, channels)
    windows = window_partition(regions[None, :, :, None], window_size)
    # (num_windows, ws, ws, 1) -> (num_windows, ws * ws)
    return windows.flatten(1)
class WindowCrossAttention(nn.Module):
    """Multi-head cross-attention computed inside (optionally shifted) windows.

    Queries come from the ``tgt`` grid and keys/values from the ``src`` grid.
    Each grid is padded to a multiple of its own window size and partitioned.
    Query window ``i`` attends only to source window ``i``, so both grids are
    expected to yield the same number of windows. Queries and keys are
    position-encoded by one shared ``CayleySTRING`` module using their own
    coordinates.

    Args:
        dim: Feature size of ``tgt`` and of the output.
        src_dim: Feature size of ``src``.
        tgt_window_size: Window side length on the target grid.
        src_window_size: Window side length on the source grid.
        num_heads: Number of attention heads.
        src_shift_size: Cyclic shift of the source grid (0 disables it).
        tgt_shift_size: Cyclic shift of the target grid. When 0, no attention
            mask is built.
        dropout: Attention dropout probability, applied only in training mode.
    """

    def __init__(
        self,
        dim: int,
        src_dim: int,
        tgt_window_size: int,
        src_window_size: int,
        num_heads: int,
        src_shift_size: int = 0,
        tgt_shift_size: int = 0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.tgt_window_size = tgt_window_size
        self.src_window_size = src_window_size
        self.src_shift_size = src_shift_size
        self.tgt_shift_size = tgt_shift_size
        self.dropout = dropout
        self.pe = CayleySTRING(dim, num_heads)
        self.query = nn.Linear(dim, dim, bias=False)
        self.kv = nn.Linear(src_dim, dim * 2, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def get_attn_mask(
        self,
        height: int,
        width: int,
        key_height: int,
        key_width: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tensor | None:
        """Build an additive mask of shape ``(num_windows, L_q, L_k)``.

        Entries are 0 where the query and key share a region id and ``-inf``
        elsewhere. Returns None when the target grid is not shifted.
        """
        if self.tgt_shift_size == 0:
            return None
        q_regions = get_mask_windows(
            height, width, self.tgt_window_size, self.tgt_shift_size, device
        )
        k_regions = get_mask_windows(
            key_height, key_width, self.src_window_size, self.src_shift_size, device
        )
        differs = q_regions[:, :, None] != k_regions[:, None, :]
        blank = torch.zeros(differs.shape, dtype=dtype, device=device)
        return blank.masked_fill(differs, -torch.inf)

    def forward(
        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coord: Tensor
    ) -> Tensor:
        """Cross-attend from ``tgt`` windows to the matching ``src`` windows.

        Args:
            tgt ([b, h, w, dim]): Query grid.
            src ([b, H, W, src_dim]): Key/value grid.
            tgt_coords ([b, h, w, 2]): Coordinates of the query tokens.
            src_coord ([b, H, W, 2]): Coordinates of the source tokens.

        Returns:
            Tensor of shape ``[b, h, w, dim]``.
        """
        b, h, w, c = tgt.shape
        t_ws, s_ws = self.tgt_window_size, self.src_window_size
        t_shift, s_shift = self.tgt_shift_size, self.src_shift_size

        def shift_back(t: Tensor, amount: int) -> Tensor:
            return t.roll(shifts=(-amount, -amount), dims=(1, 2))

        # pad each grid to multiples of its window size
        tgt = maybe_pad(tgt, t_ws)
        tgt_coords = maybe_pad(tgt_coords, t_ws)
        src = maybe_pad(src, s_ws)
        src_coord = maybe_pad(src_coord, s_ws)
        h_pad, w_pad = tgt.shape[1:3]
        src_h, src_w = src.shape[1:3]

        # cyclic shift towards the top-left
        if t_shift > 0:
            tgt = shift_back(tgt, t_shift)
            tgt_coords = shift_back(tgt_coords, t_shift)
        if s_shift > 0:
            src = shift_back(src, s_shift)
            src_coord = shift_back(src_coord, s_shift)

        # partition into windows: [b * num_windows, ws * ws, channels]
        q_tokens = window_partition(tgt, t_ws).flatten(1, 2)
        q_coords = window_partition(tgt_coords, t_ws).flatten(1, 2)
        kv_tokens = window_partition(src, s_ws).flatten(1, 2)
        kv_coords = window_partition(src_coord, s_ws).flatten(1, 2)

        attn_mask = self.get_attn_mask(
            h_pad, w_pad, src_h, src_w, q_tokens.device, q_tokens.dtype
        )
        if attn_mask is not None:
            attn_mask = repeat(attn_mask, "n l s -> (b n) h l s", b=b, h=self.num_heads)

        # W-MCA / SW-MCA
        q = rearrange(self.query(q_tokens), "b n (h d) -> b h n d", h=self.num_heads)
        k, v = rearrange(
            self.kv(kv_tokens), "b n (two h d) -> two b h n d", two=2, h=self.num_heads
        )
        attended = F.scaled_dot_product_attention(
            query=self.pe(q, q_coords),
            key=self.pe(k, kv_coords),
            value=v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        out = self.wo(rearrange(attended, "b h n d -> b n (h d)"))

        # merge windows, undo the shift, then drop the padding
        out = window_reverse(out.view(-1, t_ws, t_ws, c), t_ws, h_pad, w_pad)
        if t_shift > 0:
            out = out.roll(shifts=(t_shift, t_shift), dims=(1, 2))
        return out[:, :h, :w, :].contiguous()
class WindowSelfAttention(nn.Module):
    """Multi-head self-attention inside (optionally shifted) square windows.

    Queries and keys are position-encoded with ``CayleySTRING`` using the token
    coordinates. With ``shift_size > 0`` the grid is cyclically shifted before
    partitioning. An additive mask then stops tokens of different pre-shift
    regions from attending to each other.

    Args:
        dim: Feature size of the input and output.
        window_size: Side length of each attention window.
        num_heads: Number of attention heads.
        shift_size: Cyclic shift in tokens (0 disables shifting and masking).
        dropout: Attention dropout probability, applied only in training mode.
    """

    def __init__(
        self,
        dim: int,
        window_size: int,
        num_heads: int,
        shift_size: int = 0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.dropout = dropout
        self.pe = CayleySTRING(dim, num_heads)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def get_attn_mask(
        self, height: int, width: int, device: torch.device, dtype: torch.dtype
    ) -> Tensor | None:
        """Build an additive mask of shape ``(num_windows, L, L)``.

        Entries are 0 for same-region token pairs and ``-inf`` otherwise.
        Returns None when windows are not shifted.
        """
        if self.shift_size == 0:
            return None
        regions = get_mask_windows(
            height, width, self.window_size, self.shift_size, device
        )
        differs = regions[:, :, None] != regions[:, None, :]
        blank = torch.zeros(differs.shape, dtype=dtype, device=device)
        return blank.masked_fill(differs, -torch.inf)

    def forward(self, x: Tensor, coords: Tensor) -> Tensor:
        """Forward function for Window Self-Attention.

        Args:
            x ([b, h, w, c]): Hidden states.
            coords ([b, h, w, 2]): Absolute positions.

        Returns:
            Tensor of shape ``[b, h, w, c]``.
        """
        b, h, w, c = x.shape
        ws, shift = self.window_size, self.shift_size

        # pad to multiples of the window size
        x = maybe_pad(x, ws)
        coords = maybe_pad(coords, ws)
        h_pad, w_pad = x.shape[1:3]

        # cyclic shift towards the top-left
        if shift > 0:
            x = x.roll(shifts=(-shift, -shift), dims=(1, 2))
            coords = coords.roll(shifts=(-shift, -shift), dims=(1, 2))

        # partition into windows: [b * num_windows, ws * ws, channels]
        tokens = window_partition(x, ws).flatten(1, 2)
        token_coords = window_partition(coords, ws).flatten(1, 2)

        attn_mask = self.get_attn_mask(h_pad, w_pad, tokens.device, tokens.dtype)
        if attn_mask is not None:
            attn_mask = repeat(attn_mask, "n l s -> (b n) h l s", b=b, h=self.num_heads)

        # W-MSA / SW-MSA
        q, k, v = rearrange(
            self.qkv(tokens), "b n (three h d) -> three b h n d", three=3, h=self.num_heads
        )
        attended = F.scaled_dot_product_attention(
            query=self.pe(q, token_coords),
            key=self.pe(k, token_coords),
            value=v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        out = self.wo(rearrange(attended, "b h n d -> b n (h d)"))

        # merge windows, undo the shift, then drop the padding
        out = window_reverse(out.view(-1, ws, ws, c), ws, h_pad, w_pad)
        if shift > 0:
            out = out.roll(shifts=(shift, shift), dims=(1, 2))
        return out[:, :h, :w, :].contiguous()
class Block(nn.Module):
    """Post-norm decoder block: window self-attention, window cross-attention, FFN.

    Each sub-layer's output goes through dropout and is added to its input.
    The sum is then layer-normalized.

    Args:
        dim: Query feature size.
        src_dim: Source feature size.
        num_heads: Number of heads for both attention layers.
        window_size: Window size of the self-attention.
        tgt_window_size: Query-side window size of the cross-attention.
        src_window_size: Source-side window size of the cross-attention.
        shift_size: Cyclic shift of the self-attention windows.
        tgt_shift_size: Cyclic shift of the cross-attention query windows.
        src_shift_size: Cyclic shift of the cross-attention source windows.
        dropout: Dropout used inside both attentions and on every residual branch.
    """

    def __init__(
        self,
        dim: int,
        src_dim: int,
        num_heads: int,
        window_size: int,
        tgt_window_size: int,
        src_window_size: int,
        shift_size: int = 0,
        tgt_shift_size: int = 0,
        src_shift_size: int = 0,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.cross_attention = WindowCrossAttention(
            dim,
            src_dim,
            num_heads=num_heads,
            tgt_window_size=tgt_window_size,
            src_window_size=src_window_size,
            tgt_shift_size=tgt_shift_size,
            src_shift_size=src_shift_size,
            dropout=dropout,
        )
        self.cross_attention_norm = nn.LayerNorm(dim)
        self.cross_attention_dropout = nn.Dropout(dropout)
        self.self_attention = WindowSelfAttention(
            dim, window_size, num_heads, shift_size, dropout=dropout
        )
        self.self_attention_norm = nn.LayerNorm(dim)
        self.self_attention_dropout = nn.Dropout(dropout)
        self.ffn = FeedForward(dim, dim * 4)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn_dropout = nn.Dropout(dropout)

    def forward(
        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
    ) -> Tensor:
        """Refine ``tgt`` with self-attention, cross-attention to ``src`` and an FFN.

        Args:
            tgt ([b, h, w, dim]): Query grid.
            src ([b, H, W, src_dim]): Source feature grid.
            tgt_coords ([b, h, w, 2]): Coordinates of the queries.
            src_coords ([b, H, W, 2]): Coordinates of the source tokens.
        """
        x = self.self_attention(tgt, tgt_coords)
        tgt = self.self_attention_norm(tgt + self.self_attention_dropout(x))
        x = self.cross_attention(tgt, src, tgt_coords, src_coords)
        tgt = self.cross_attention_norm(tgt + self.cross_attention_dropout(x))
        return self.ffn_norm(tgt + self.ffn_dropout(self.ffn(tgt)))
class Stage(nn.Module):
    """A run of ``depth`` decoder blocks that all attend to one source map.

    Even-indexed blocks use regular windows. Odd-indexed blocks shift the
    self-attention windows and both cross-attention windows by half their size.
    """

    def __init__(
        self,
        dim: int,
        src_dim: int,
        depth: int,
        num_heads: int,
        window_size: int,
        tgt_window_size: int,
        src_window_size: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        def half_if_odd(index: int, size: int) -> int:
            return size // 2 if index % 2 else 0

        self.blocks = nn.ModuleList(
            Block(
                dim=dim,
                src_dim=src_dim,
                num_heads=num_heads,
                window_size=window_size,
                tgt_window_size=tgt_window_size,
                src_window_size=src_window_size,
                shift_size=half_if_odd(idx, window_size),
                tgt_shift_size=half_if_odd(idx, tgt_window_size),
                src_shift_size=half_if_odd(idx, src_window_size),
                dropout=dropout,
            )
            for idx in range(depth)
        )

    def forward(
        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
    ) -> Tensor:
        """Pass ``tgt`` through every block in order; coordinates stay fixed."""
        out = tgt
        for blk in self.blocks:
            out = blk(out, src, tgt_coords, src_coords)
        return out
class LSPTransformer(nn.Module):
    """Decoder that refines a grid of queries stage by stage.

    Stage ``i`` cross-attends to ``features[i]``. After each stage, three heads
    predict per query:
    - class logits (``num_classes + 1`` outputs);
    - a 2-D point offset;
    - ``num_radial_distances`` radial distances.

    The offset is added to the reference points, which are pre-sigmoid values
    decoded by ``relative_to_absolute_pos``.
    """

    def __init__(self, config: LSPDetrConfig, feature_channels: list[int]) -> None:
        super().__init__()
        self.query_block_size = config.query_block_size
        self.num_radial_distances = config.num_radial_distances
        self.stages = nn.ModuleList(
            Stage(
                dim=config.dim,
                src_dim=feature_channels[i],
                depth=depth,
                num_heads=config.num_heads,
                window_size=config.window_size,
                tgt_window_size=config.tgt_window_sizes[i],
                src_window_size=config.src_window_sizes[i],
                dropout=config.dropout,
            )
            for i, depth in enumerate(config.depths)
        )
        self.input_norm = nn.ModuleList(nn.LayerNorm(d) for d in feature_channels)
        # output heads
        self.class_head = nn.Linear(config.dim, config.num_classes + 1, bias=False)
        self.point_head = MLP(config.dim, config.dim, 2, 2)
        self.radial_distances_head = MLP(
            config.dim, config.dim, config.num_radial_distances, 2
        )
        self.init_weights()

    def init_weights(self) -> None:
        # zero the last point-head layer so the first predicted offsets are 0
        last = self.point_head[-1]
        nn.init.zeros_(last.weight)
        nn.init.zeros_(last.bias)

    def _prepare_sources(
        self, features: list[Tensor], height: int, width: int
    ) -> tuple[list[Tensor], list[Tensor]]:
        """Normalize each map channels-last and build its pixel-space cell centres."""
        src: list[Tensor] = []
        src_coords: list[Tensor] = []
        for i, feature in enumerate(features):
            b, _, fh, fw = feature.shape
            src.append(self.input_norm[i](rearrange(feature, "b c h w -> b h w c")))
            # zero logits -> centre of each cell; stride approximated by ceil
            centres = torch.zeros(b, fh, fw, 2, dtype=torch.float32, device=feature.device)
            src_coords.append(
                relative_to_absolute_pos(
                    centres, step_x=math.ceil(width / fw), step_y=math.ceil(height / fh)
                )
            )
        return src, src_coords

    def forward(
        self,
        tgt: Tensor,
        ref_points: Tensor,
        features: list[Tensor],
        height: int,
        width: int,
    ) -> dict[str, Tensor | list[dict[str, Tensor]]]:
        """Run all stages and collect per-stage predictions.

        Args:
            tgt ([b, gh, gw, dim]): Initial query embeddings on the query grid.
            ref_points ([b, gh, gw, 2]): Pre-sigmoid reference offsets inside
                each ``query_block_size`` cell.
            features: One ``[b, c_i, h_i, w_i]`` map per stage.
            height, width: Input image size in pixels.

        Returns:
            The final stage's ``logits``, ``points`` (divided by the image size)
            and ``radial_distances``. Also ``absolute_points`` (pixel coordinates
            of the final detached references) and ``aux_outputs`` with the
            predictions of earlier stages.
        """
        src, src_coords = self._prepare_sources(features, height, width)
        block = self.query_block_size

        logits_list: list[Tensor] = []
        ref_points_list: list[Tensor] = []
        radial_distances_list: list[Tensor] = []
        # look forward twice: predictions add the offset to the previous,
        # non-detached reference; the next stage sees a detached copy
        new_ref_points = ref_points.clone()
        for i, stage in enumerate(self.stages):
            tgt = stage(
                tgt=tgt,
                src=src[i],
                tgt_coords=relative_to_absolute_pos(ref_points, block, block),
                src_coords=src_coords[i],
            )
            delta_point = self.point_head(tgt)
            radial_distances = self.radial_distances_head(tgt)
            logits = self.class_head(tgt)

            predicted = relative_to_absolute_pos(
                new_ref_points + delta_point,
                step_x=block / width,
                step_y=block / height,
            )
            ref_points_list.append(predicted.flatten(1, 2))
            logits_list.append(logits.flatten(1, 2))
            radial_distances_list.append(radial_distances.flatten(1, 2))

            new_ref_points = ref_points + delta_point
            ref_points = new_ref_points.detach()

        aux_outputs = [
            {"logits": a, "points": b, "radial_distances": c}
            for a, b, c in zip(
                logits_list[:-1],
                ref_points_list[:-1],
                radial_distances_list[:-1],
                strict=True,
            )
        ]
        return {
            "logits": logits_list[-1],
            "points": ref_points_list[-1],
            "radial_distances": radial_distances_list[-1],
            "absolute_points": relative_to_absolute_pos(
                ref_points, block, block
            ).flatten(1, 2),
            "aux_outputs": aux_outputs,
        }
class FeatureSampling(nn.Module):
    """Sample a feature map at given points, then project and normalize.

    ``points`` is mapped with ``points * 2 - 1`` before ``F.grid_sample``, which
    uses ``align_corners=False`` and the default zero padding. So points are
    expected in normalized ``[0, 1]`` image coordinates, ``(x, y)`` in the last
    dimension. Locations beyond the border read zero padding.

    Args:
        in_dim: Channel count of the sampled feature map.
        out_dim: Output size after the bias-free projection and LayerNorm.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.reduction = nn.Linear(in_dim, out_dim, bias=False)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, points: Tensor, feature: Tensor) -> Tensor:
        """Sample ``feature`` at ``points``.

        Args:
            points ([b, h, w, 2]): Sampling locations in ``[0, 1]``.
            feature ([b, c, H, W]): Feature map to sample from.

        Returns:
            Tensor of shape ``[b, h, w, out_dim]``.
        """
        grid = points * 2 - 1  # [0, 1] -> grid_sample's [-1, 1] convention
        sampled = F.grid_sample(feature, grid, align_corners=False)
        channels_last = rearrange(sampled, "b c h w -> b h w c")
        return self.norm(self.reduction(channels_last))
class LSPDetrModel(PreTrainedModel):
    """LSP-DETR: backbone, query initialization from the neck, windowed decoder.

    One query is placed in every ``query_block_size x query_block_size`` pixel
    cell of the input. Query embeddings come from ``FeatureSampling`` on the
    backbone's last feature map (``neck``). They are then refined by
    ``LSPTransformer`` over the remaining feature maps in reversed order.
    """

    config_class = LSPDetrConfig

    def __init__(self, config: LSPDetrConfig) -> None:
        super().__init__(config)
        self.query_block_size = config.query_block_size
        self.backbone = load_backbone(config)
        # the first entry of num_features is skipped; the last one is the neck
        _, *feature_channels, neck = self.backbone.num_features
        self.feature_sampling = FeatureSampling(neck, config.dim)
        # the decoder consumes the intermediate maps in reverse order
        self.decode_head = LSPTransformer(config, feature_channels[::-1])

    def forward(
        self, pixel_values: Tensor
    ) -> dict[str, Tensor | list[dict[str, Tensor]]]:
        """Predict per-query class logits, points and radial distances.

        Args:
            pixel_values ([b, c, h, w]): Input images.

        Returns:
            The dictionary produced by ``LSPTransformer.forward``. It holds final
            tensors plus an ``aux_outputs`` list of per-stage dictionaries.
        """
        b, _, h, w = pixel_values.shape
        *features, neck = self.backbone(pixel_values).feature_maps
        ref_points = torch.zeros(
            b,
            math.ceil(h / self.query_block_size),
            math.ceil(w / self.query_block_size),
            2,
            dtype=torch.float32,
            device=neck.device,
        )  # zero logits -> cell centre positions
        # NOTE(review): FeatureSampling maps points with ``points * 2 - 1`` for
        # grid_sample, i.e. it expects normalized [0, 1] coordinates. The points
        # passed here are in pixels ((i + 0.5) * query_block_size). For
        # query_block_size >= 3 every location lies beyond the border, so the
        # neck is sampled as zeros. Every query then starts from the same
        # embedding (the LayerNorm bias, since the projection has no bias).
        # Using steps query_block_size / w and query_block_size / h would sample
        # the intended locations. However, that changes the outputs of
        # checkpoints trained with the current behavior, so confirm against the
        # training code before changing it.
        tgt = self.feature_sampling(
            relative_to_absolute_pos(
                ref_points, self.query_block_size, self.query_block_size
            ),
            neck,
        )
        return self.decode_head(tgt, ref_points, features[::-1], h, w)