import math

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor, nn
from torch.nn.utils import parametrize
from transformers import PreTrainedModel
from transformers.models.swinv2.modeling_swinv2 import window_partition, window_reverse
from transformers.utils.backbone_utils import load_backbone

from .configuration import LSPDetrConfig


class MLP(nn.Sequential):
    """Very simple multi-layer perceptron."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        act_layer: type[nn.Module] = nn.GELU,
        dropout: float = 0.0,
    ) -> None:
        assert num_layers > 1

        layers = []
        h = [hidden_dim] * (num_layers - 1)
        for n, k in zip([input_dim, *h], h, strict=False):
            layers.append(nn.Linear(n, k))
            layers.append(act_layer())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))

        layers.append(nn.Linear(hidden_dim, output_dim))
        super().__init__(*layers)


class FeedForward(nn.Module):
    """FeedForward module.

    Taken from https://github.com/meta-llama/llama-models/blob/main/models/llama4/ffn.py
    """

    def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256) -> None:
        """Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
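

# Worked example of the hidden-dimension rounding above (a sanity check, not part
# of the model): with dim=256 and hidden_dim=1024, the effective hidden size is
# int(2 * 1024 / 3) = 682, rounded up to the next multiple of 256, i.e. 768.
#
#   ffn = FeedForward(dim=256, hidden_dim=1024)
#   assert ffn.w1.out_features == 768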


def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
    """Taken from https://github.com/naver-ai/rope-vit/blob/main/self-attn/rope_self_attn.py."""
    freqs_x = []
    freqs_y = []
    freqs = 1 / (theta ** (torch.arange(0, head_dim, 2 * pos_dim).float() / head_dim))
    for _ in range(num_heads):
        # Each head gets a random in-plane rotation of the frequency axes,
        # so the x/y frequencies differ per head.
        angles = torch.rand(1) * 2 * torch.pi
        fx = torch.cat(
            [freqs * torch.cos(angles), freqs * torch.cos(torch.pi / 2 + angles)],
            dim=-1,
        )
        fy = torch.cat(
            [freqs * torch.sin(angles), freqs * torch.sin(torch.pi / 2 + angles)],
            dim=-1,
        )
        freqs_x.append(fx)
        freqs_y.append(fy)
    freqs_x = torch.stack(freqs_x, dim=0)
    freqs_y = torch.stack(freqs_y, dim=0)
    return torch.stack([freqs_x, freqs_y], dim=0)
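

# Shape check (illustrative sizes, not from the config): with head_dim=64,
# num_heads=8 and pos_dim=2, torch.arange(0, 64, 4) yields 16 base frequencies,
# doubled by the cos/sin concatenation to 32 per head, so the stacked result is
# [pos_dim, num_heads, head_dim // 2].
#
#   freqs = init_freqs(head_dim=64, num_heads=8, pos_dim=2, theta=100.0)
#   assert freqs.shape == (2, 8, 32)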


class Skew(nn.Module):
    """Skew-symmetric matrix parameterization."""

    def forward(self, x: Tensor) -> Tensor:
        a = x.triu(1)
        return a - a.transpose(-1, -2)

    def right_inverse(self, x: Tensor) -> Tensor:
        return x.triu(1)
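

# Quick property check (not part of the model): the parameterization maps any
# square matrix to a skew-symmetric one, i.e. S == -S.T.
#
#   s = Skew()(torch.randn(4, 4))
#   assert torch.allclose(s, -s.transpose(-1, -2))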


class CayleySTRING(nn.Module):
    """Implements the Cayley-STRING positional encoding.

    Based on "Learning the RoPEs: Better 2D and 3D Position Encodings with STRING"
    (https://arxiv.org/abs/2502.02562).

    Multiplies the input by a learnable orthogonal matrix P parameterized by the
    Cayley transform, P = (I - S)(I + S)^-1 where S is a learnable skew-symmetric
    matrix, and then applies RoPE.

    Args:
        dim (int): The feature dimension of the input tensor. Must be divisible by num_heads.
        num_heads (int): The number of attention heads.
        pos_dim (int): The dimensionality of the position vectors (e.g., 1 for 1D, 2 for 2D). Defaults to 2.
        theta (float): The base value for the RoPE frequency calculation. Defaults to 100.0.
    """

    def __init__(
        self, dim: int, num_heads: int, pos_dim: int = 2, theta: float = 100.0
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "Dimension must be divisible by num_heads."

        head_dim = dim // num_heads

        self.freqs = nn.Parameter(init_freqs(head_dim, num_heads, pos_dim, theta))

        self.S = nn.Parameter(torch.zeros(head_dim, head_dim))
        parametrize.register_parametrization(self, "S", Skew())

        self.register_buffer("I", torch.eye(head_dim), persistent=False)

        self.init_weights()

    def init_weights(self) -> None:
        # Assigning to self.S routes through the Skew parameterization's right_inverse.
        self.S = nn.init.kaiming_uniform_(self.S, a=math.sqrt(5))

    @parametrize.cached()
    @torch.autocast("cuda", enabled=False)
    def forward(self, x: Tensor, positions: Tensor) -> Tensor:
        """Apply Cayley-STRING positional encoding.

        Args:
            x ([b, h, n, d]): Input tensor.
            positions ([b, n, pos_dim]): Positions tensor.
        """
        # P @ x with P = (I - S)(I + S)^-1, computed via a linear solve instead of
        # materializing the inverse.
        y = torch.linalg.solve(
            self.I + self.S, rearrange(x.float(), "b h n d -> h d (b n)")
        )
        px = torch.matmul(self.I - self.S, y)
        px = rearrange(px, "h d (b n) -> b h n d", b=x.size(0)).contiguous()

        # RoPE applied in the complex plane, with per-head 2D frequencies.
        angles = torch.einsum("bnk,khc->bhnc", positions, self.freqs)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)
        px_ = torch.view_as_complex(rearrange(px, "... (d two) -> ... d two", two=2))
        out = rearrange(torch.view_as_real(px_ * freqs_cis), "... d two -> ... (d two)")

        return out.type_as(x)
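

# Sanity-check sketch (illustrative sizes, not from the config): both the Cayley
# orthogonal map and the RoPE rotation are norm-preserving, so per-token feature
# norms should be unchanged up to numerical error.
#
#   pe = CayleySTRING(dim=64, num_heads=8)   # head_dim = 8
#   x = torch.randn(2, 8, 16, 8)             # [b, heads, n, head_dim]
#   pos = torch.rand(2, 16, 2)               # [b, n, pos_dim]
#   out = pe(x, pos)
#   assert torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-5)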


def maybe_pad(x: Tensor, window_size: int) -> Tensor:
    """Pad a [b, h, w, c] tensor on the right/bottom so h and w are multiples of window_size."""
    h, w = x.shape[1:3]
    pad_right = (window_size - w % window_size) % window_size
    pad_bottom = (window_size - h % window_size) % window_size
    return F.pad(x, (0, 0, 0, pad_right, 0, pad_bottom))
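

# Example (illustrative shapes): a [1, 5, 7, c] map with window_size=4 is padded
# to [1, 8, 8, c]; inputs already divisible by the window size pass through unchanged.
#
#   x = maybe_pad(torch.randn(1, 5, 7, 32), window_size=4)
#   assert x.shape == (1, 8, 8, 32)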


@torch.autocast("cuda", enabled=False)
def relative_to_absolute_pos(pos: Tensor, step_x: float, step_y: float) -> Tensor:
    """Convert per-cell offsets to absolute positions on a regular grid.

    Args:
        pos ([b, h, w, 2]): Unbounded offsets; passed through a sigmoid so each
            point stays within its own grid cell.
        step_x (float): Horizontal grid spacing.
        step_y (float): Vertical grid spacing.
    """
    pos = pos.sigmoid()
    h, w = pos.shape[1:3]

    anchor_x = torch.arange(w, dtype=torch.float32, device=pos.device) * step_x
    anchor_y = torch.arange(h, dtype=torch.float32, device=pos.device) * step_y

    absolute_x = pos[..., 0] * step_x + anchor_x
    absolute_y = pos[..., 1] * step_y + anchor_y.unsqueeze(1)
    return torch.stack((absolute_x, absolute_y), dim=-1)
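

# Worked example: with zero inputs (sigmoid(0) = 0.5) and step_x = step_y = 16,
# each point lands at the centre of its 16x16 cell, e.g. x coordinates
# [8, 24, 40] along a row of three cells.
#
#   pts = relative_to_absolute_pos(torch.zeros(1, 2, 3, 2), step_x=16, step_y=16)
#   # pts[0, 0, :, 0] == tensor([ 8., 24., 40.])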


def get_mask_windows(
    height: int, width: int, window_size: int, shift_size: int, device: torch.device
) -> Tensor:
    """Build per-window region labels for Swin-style shifted-window attention masking."""
    h_idx = torch.zeros(height, dtype=torch.long, device=device)
    h_idx[height - window_size : height - shift_size] = 1
    h_idx[height - shift_size :] = 2

    w_idx = torch.zeros(width, dtype=torch.long, device=device)
    w_idx[width - window_size : width - shift_size] = 1
    w_idx[width - shift_size :] = 2

    # Combine row/column indices into up to 9 distinct region labels; tokens from
    # different regions must not attend to each other after the cyclic shift.
    mask = h_idx.unsqueeze(1) * 3 + w_idx.unsqueeze(0)

    mask_windows = window_partition(mask[None, ..., None], window_size)
    return rearrange(mask_windows, "n w1 w2 1 -> n (w1 w2)")
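

# Shape example (illustrative sizes): an 8x8 grid with window_size=4 and
# shift_size=2 yields 4 windows of 16 tokens each.
#
#   m = get_mask_windows(8, 8, window_size=4, shift_size=2, device=torch.device("cpu"))
#   assert m.shape == (4, 16)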


class WindowCrossAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        src_dim: int,
        tgt_window_size: int,
        src_window_size: int,
        num_heads: int,
        src_shift_size: int = 0,
        tgt_shift_size: int = 0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.num_heads = num_heads
        self.tgt_window_size = tgt_window_size
        self.src_window_size = src_window_size
        self.src_shift_size = src_shift_size
        self.tgt_shift_size = tgt_shift_size
        self.dropout = dropout

        self.pe = CayleySTRING(dim, num_heads)
        self.query = nn.Linear(dim, dim, bias=False)
        self.kv = nn.Linear(src_dim, dim * 2, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def get_attn_mask(
        self,
        height: int,
        width: int,
        key_height: int,
        key_width: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tensor | None:
        if self.tgt_shift_size == 0:
            return None

        query_mask = get_mask_windows(
            height, width, self.tgt_window_size, self.tgt_shift_size, device
        )
        key_mask = get_mask_windows(
            key_height, key_width, self.src_window_size, self.src_shift_size, device
        )

        # Queries and keys may only attend within matching shifted-window regions.
        attn_mask = query_mask.unsqueeze(2) - key_mask.unsqueeze(1)
        return attn_mask.type(dtype).masked_fill(attn_mask != 0, -torch.inf)

    def forward(
        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
    ) -> Tensor:
        """Window-based cross-attention from tgt queries into src features.

        Args:
            tgt ([b, h, w, c]): Query feature map.
            src ([b, src_h, src_w, src_dim]): Key/value feature map.
            tgt_coords ([b, h, w, 2]): Absolute positions of the queries.
            src_coords ([b, src_h, src_w, 2]): Absolute positions of the keys.
        """
        b, h, w, c = tgt.shape

        # Pad both maps to multiples of their window sizes.
        tgt = maybe_pad(tgt, self.tgt_window_size)
        src = maybe_pad(src, self.src_window_size)
        tgt_coords = maybe_pad(tgt_coords, self.tgt_window_size)
        src_coords = maybe_pad(src_coords, self.src_window_size)
        h_pad, w_pad = tgt.shape[1:3]
        src_h, src_w = src.shape[1:3]

        # Cyclic shift (Swin-style) before window partitioning.
        if self.tgt_shift_size > 0:
            tgt = tgt.roll(
                shifts=(-self.tgt_shift_size, -self.tgt_shift_size), dims=(1, 2)
            )
            tgt_coords = tgt_coords.roll(
                shifts=(-self.tgt_shift_size, -self.tgt_shift_size), dims=(1, 2)
            )

        if self.src_shift_size > 0:
            src = src.roll(
                shifts=(-self.src_shift_size, -self.src_shift_size), dims=(1, 2)
            )
            src_coords = src_coords.roll(
                shifts=(-self.src_shift_size, -self.src_shift_size), dims=(1, 2)
            )

        # Partition into windows and flatten each window to a token sequence.
        tgt = window_partition(tgt, self.tgt_window_size).flatten(1, 2)
        src = window_partition(src, self.src_window_size).flatten(1, 2)
        tgt_coords = window_partition(tgt_coords, self.tgt_window_size).flatten(1, 2)
        src_coords = window_partition(src_coords, self.src_window_size).flatten(1, 2)

        attn_mask = self.get_attn_mask(
            h_pad, w_pad, src_h, src_w, tgt.device, tgt.dtype
        )

        if attn_mask is not None:
            attn_mask = repeat(attn_mask, "n l s -> (b n) h l s", b=b, h=self.num_heads)

        # Cross-attention with Cayley-STRING positional encoding on queries and keys.
        q = rearrange(self.query(tgt), "b n (h d) -> b h n d", h=self.num_heads)
        k, v = rearrange(
            self.kv(src), "b n (two h d) -> two b h n d", two=2, h=self.num_heads
        )
        x = F.scaled_dot_product_attention(
            query=self.pe(q, tgt_coords),
            key=self.pe(k, src_coords),
            value=v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        tgt = self.wo(rearrange(x, "b h n d -> b n (h d)"))

        # Merge windows back to a feature map and undo padding/shift.
        tgt = tgt.view(-1, self.tgt_window_size, self.tgt_window_size, c)
        tgt = window_reverse(tgt, self.tgt_window_size, h_pad, w_pad)

        if self.tgt_shift_size > 0:
            tgt = torch.roll(
                tgt, shifts=(self.tgt_shift_size, self.tgt_shift_size), dims=(1, 2)
            )

        return tgt[:, :h, :w, :].contiguous()
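

# Shape sketch (illustrative sizes, and a deduction from the code above rather
# than documented behaviour): window_partition yields (b * num_windows) sequences
# per input, so the query and key/value maps must tile into the same number of
# windows after padding.
#
#   attn = WindowCrossAttention(dim=64, src_dim=96, tgt_window_size=4,
#                               src_window_size=8, num_heads=8)
#   tgt = torch.randn(1, 8, 8, 64)     # 2x2 = 4 windows of size 4
#   src = torch.randn(1, 16, 16, 96)   # 2x2 = 4 windows of size 8
#   out = attn(tgt, src, torch.rand(1, 8, 8, 2), torch.rand(1, 16, 16, 2))
#   assert out.shape == (1, 8, 8, 64)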


class WindowSelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        window_size: int,
        num_heads: int,
        shift_size: int = 0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.dropout = dropout

        self.pe = CayleySTRING(dim, num_heads)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def get_attn_mask(
        self, height: int, width: int, device: torch.device, dtype: torch.dtype
    ) -> Tensor | None:
        if self.shift_size == 0:
            return None

        mask_windows = get_mask_windows(
            height, width, self.window_size, self.shift_size, device
        )

        attn_mask = mask_windows.unsqueeze(2) - mask_windows.unsqueeze(1)
        return attn_mask.type(dtype).masked_fill(attn_mask != 0, -torch.inf)

    def forward(self, x: Tensor, coords: Tensor) -> Tensor:
        """Forward function for Window Self-Attention.

        Args:
            x ([b, h, w, c]): Hidden states.
            coords ([b, h, w, 2]): Absolute positions.
        """
        b, h, w, c = x.shape

        # Pad to a multiple of the window size.
        x = maybe_pad(x, self.window_size)
        coords = maybe_pad(coords, self.window_size)
        h_pad, w_pad = x.shape[1:3]

        # Cyclic shift (Swin-style) before window partitioning.
        if self.shift_size > 0:
            x = x.roll(shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            coords = coords.roll(
                shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )

        # Partition into windows and flatten each window to a token sequence.
        x = window_partition(x, self.window_size).flatten(1, 2)
        coords = window_partition(coords, self.window_size).flatten(1, 2)

        attn_mask = self.get_attn_mask(h_pad, w_pad, x.device, x.dtype)
        if attn_mask is not None:
            attn_mask = repeat(attn_mask, "n l s -> (b n) h l s", b=b, h=self.num_heads)

        # Self-attention with Cayley-STRING positional encoding on queries and keys.
        q, k, v = rearrange(
            self.qkv(x), "b n (three h d) -> three b h n d", three=3, h=self.num_heads
        )
        x = F.scaled_dot_product_attention(
            query=self.pe(q, coords),
            key=self.pe(k, coords),
            value=v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        x = self.wo(rearrange(x, "b h n d -> b n (h d)"))

        # Merge windows back to a feature map and undo padding/shift.
        x = x.view(-1, self.window_size, self.window_size, c)
        x = window_reverse(x, self.window_size, h_pad, w_pad)

        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

        return x[:, :h, :w, :].contiguous()
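

# Minimal shape sketch (illustrative sizes): the output keeps the unpadded input
# resolution even when h or w is not a multiple of the window size.
#
#   attn = WindowSelfAttention(dim=64, window_size=4, num_heads=8)
#   out = attn(torch.randn(1, 6, 10, 64), torch.rand(1, 6, 10, 2))
#   assert out.shape == (1, 6, 10, 64)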


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        src_dim: int,
        num_heads: int,
        window_size: int,
        tgt_window_size: int,
        src_window_size: int,
        shift_size: int = 0,
        tgt_shift_size: int = 0,
        src_shift_size: int = 0,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.cross_attention = WindowCrossAttention(
            dim,
            src_dim,
            num_heads=num_heads,
            tgt_window_size=tgt_window_size,
            src_window_size=src_window_size,
            tgt_shift_size=tgt_shift_size,
            src_shift_size=src_shift_size,
            dropout=dropout,
        )
        self.cross_attention_norm = nn.LayerNorm(dim)
        self.cross_attention_dropout = nn.Dropout(dropout)

        self.self_attention = WindowSelfAttention(
            dim, window_size, num_heads, shift_size, dropout=dropout
        )
        self.self_attention_norm = nn.LayerNorm(dim)
        self.self_attention_dropout = nn.Dropout(dropout)

        self.ffn = FeedForward(dim, dim * 4)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn_dropout = nn.Dropout(dropout)

    def forward(
        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
    ) -> Tensor:
        # Self-attention over the queries, then cross-attention into the source
        # features, then a feed-forward block; each with a post-norm residual.
        x = self.self_attention(tgt, tgt_coords)
        tgt = self.self_attention_norm(tgt + self.self_attention_dropout(x))

        x = self.cross_attention(tgt, src, tgt_coords, src_coords)
        tgt = self.cross_attention_norm(tgt + self.cross_attention_dropout(x))

        return self.ffn_norm(tgt + self.ffn_dropout(self.ffn(tgt)))


class Stage(nn.Module):
    def __init__(
        self,
        dim: int,
        src_dim: int,
        depth: int,
        num_heads: int,
        window_size: int,
        tgt_window_size: int,
        src_window_size: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList()
        for i in range(depth):
            # Alternate between regular and shifted windows (Swin-style).
            block = Block(
                dim=dim,
                src_dim=src_dim,
                num_heads=num_heads,
                window_size=window_size,
                tgt_window_size=tgt_window_size,
                src_window_size=src_window_size,
                shift_size=0 if i % 2 == 0 else window_size // 2,
                tgt_shift_size=0 if i % 2 == 0 else tgt_window_size // 2,
                src_shift_size=0 if i % 2 == 0 else src_window_size // 2,
                dropout=dropout,
            )
            self.blocks.append(block)

    def forward(
        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
    ) -> Tensor:
        for block in self.blocks:
            tgt = block(tgt, src, tgt_coords, src_coords)
        return tgt


class LSPTransformer(nn.Module):
    def __init__(self, config: LSPDetrConfig, feature_channels: list[int]) -> None:
        super().__init__()

        self.query_block_size = config.query_block_size
        self.num_radial_distances = config.num_radial_distances

        self.stages = nn.ModuleList()
        for i, depth in enumerate(config.depths):
            stage = Stage(
                dim=config.dim,
                src_dim=feature_channels[i],
                depth=depth,
                num_heads=config.num_heads,
                window_size=config.window_size,
                tgt_window_size=config.tgt_window_sizes[i],
                src_window_size=config.src_window_sizes[i],
                dropout=config.dropout,
            )
            self.stages.append(stage)

        self.input_norm = nn.ModuleList(nn.LayerNorm(d) for d in feature_channels)

        # Prediction heads, shared across stages.
        self.class_head = nn.Linear(config.dim, config.num_classes + 1, bias=False)
        self.point_head = MLP(config.dim, config.dim, 2, 2)
        self.radial_distances_head = MLP(
            config.dim, config.dim, config.num_radial_distances, 2
        )

        self.init_weights()

    def init_weights(self) -> None:
        # Zero-init the point head so the first stage starts from the initial reference points.
        nn.init.constant_(self.point_head[-1].weight, 0.0)
        nn.init.constant_(self.point_head[-1].bias, 0.0)

    def forward(
        self,
        tgt: Tensor,
        ref_points: Tensor,
        features: list[Tensor],
        height: int,
        width: int,
    ) -> dict[str, Tensor | list[dict[str, Tensor]]]:
        # Normalize the backbone features and attach absolute grid coordinates.
        src = []
        src_coords = []
        for i, feature in enumerate(features):
            b, _, h, w = feature.shape
            coords = torch.zeros(b, h, w, 2, dtype=torch.float32, device=feature.device)
            src.append(self.input_norm[i](rearrange(feature, "b c h w -> b h w c")))
            src_coords.append(
                relative_to_absolute_pos(
                    coords, step_x=math.ceil(width / w), step_y=math.ceil(height / h)
                )
            )

        logits_list: list[Tensor] = []
        ref_points_list: list[Tensor] = []
        radial_distances_list: list[Tensor] = []

        new_ref_points = ref_points.clone()
        for i, stage in enumerate(self.stages):
            tgt = stage(
                tgt=tgt,
                src=src[i],
                tgt_coords=relative_to_absolute_pos(
                    ref_points, self.query_block_size, self.query_block_size
                ),
                src_coords=src_coords[i],
            )

            # Per-stage predictions.
            delta_point = self.point_head(tgt)
            radial_distances = self.radial_distances_head(tgt)
            logits = self.class_head(tgt)

            ref_points_list.append(
                relative_to_absolute_pos(
                    new_ref_points + delta_point,
                    step_x=self.query_block_size / width,
                    step_y=self.query_block_size / height,
                ).flatten(1, 2)
            )
            logits_list.append(logits.flatten(1, 2))
            radial_distances_list.append(radial_distances.flatten(1, 2))

            # Iterative refinement: detach the reference points fed to the next stage.
            new_ref_points = ref_points + delta_point
            ref_points = new_ref_points.detach()

        return {
            "logits": logits_list[-1],
            "points": ref_points_list[-1],
            "radial_distances": radial_distances_list[-1],
            "absolute_points": relative_to_absolute_pos(
                ref_points, self.query_block_size, self.query_block_size
            ).flatten(1, 2),
            "aux_outputs": [
                {
                    "logits": a,
                    "points": b,
                    "radial_distances": c,
                }
                for a, b, c in zip(
                    logits_list[:-1],
                    ref_points_list[:-1],
                    radial_distances_list[:-1],
                    strict=True,
                )
            ],
        }
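

# Note on the returned dictionary (a reading of the code above): the top-level
# "logits"/"points"/"radial_distances" come from the final stage, while
# "aux_outputs" carries one entry per earlier stage for auxiliary losses.
#
#   outputs = decode_head(tgt, ref_points, features, h, w)   # hypothetical call
#   assert len(outputs["aux_outputs"]) == len(config.depths) - 1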


class FeatureSampling(nn.Module):
    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.reduction = nn.Linear(in_dim, out_dim, bias=False)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, points: Tensor, feature: Tensor) -> Tensor:
        # grid_sample expects coordinates in [-1, 1]; `points * 2 - 1` maps
        # [0, 1]-normalized sampling points into that range.
        x = F.grid_sample(feature, points * 2 - 1, align_corners=False)
        return self.norm(self.reduction(rearrange(x, "b c h w -> b h w c")))


class LSPDetrModel(PreTrainedModel):
    config_class = LSPDetrConfig

    def __init__(self, config: LSPDetrConfig) -> None:
        super().__init__(config)
        self.query_block_size = config.query_block_size

        self.backbone = load_backbone(config)
        _, *feature_channels, neck = self.backbone.num_features

        self.feature_sampling = FeatureSampling(neck, config.dim)
        self.decode_head = LSPTransformer(config, feature_channels[::-1])

    def forward(self, pixel_values: Tensor) -> dict[str, Tensor]:
        b, _, h, w = pixel_values.shape

        *features, neck = self.backbone(pixel_values).feature_maps

        # One query per query_block_size x query_block_size block of the image,
        # initialized at the block centre (sigmoid(0) = 0.5).
        ref_points = torch.zeros(
            b,
            math.ceil(h / self.query_block_size),
            math.ceil(w / self.query_block_size),
            2,
            dtype=torch.float32,
            device=neck.device,
        )
        # Initialize the queries by sampling the neck feature map at the reference points.
        tgt = self.feature_sampling(
            relative_to_absolute_pos(
                ref_points, self.query_block_size, self.query_block_size
            ),
            neck,
        )

        # Reverse the feature maps to match the channel order used to build the decode head.
        return self.decode_head(tgt, ref_points, features[::-1], h, w)
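

# Minimal usage sketch (hedged: the concrete LSPDetrConfig fields and backbone
# checkpoint are assumptions, not taken from this file):
#
#   config = LSPDetrConfig()
#   model = LSPDetrModel(config)
#   outputs = model(torch.randn(1, 3, 512, 512))
#   outputs["logits"]            # [b, num_queries, num_classes + 1]
#   outputs["points"]            # [b, num_queries, 2]
#   outputs["radial_distances"]  # [b, num_queries, num_radial_distances]
#   outputs["aux_outputs"]       # per-stage predictions for auxiliary losses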