"""DETR decoder (ยง2.4). ``N_dec`` blocks of: cross-attention (GS tokens -> image tokens, with K/V projected once and shared across layers), self-attention among GS tokens (with the dynamic->static causal mask), and a per-token MLP. QK-norm + LayerScale throughout; the final LayerNorm before the head is omitted (TokenGS). """ from __future__ import annotations from typing import Optional import torch import torch.nn as nn from mapgs.model.blocks import DecoderBlock, SharedImageKV class DETRDecoder(nn.Module): def __init__( self, dim: int = 1024, depth: int = 24, n_heads: int = 16, mlp_ratio: float = 4.0, qk_norm: bool = True, layerscale_init: float = 1e-5, ctx_dim: Optional[int] = None, shared_kv: bool = True, ): super().__init__() self.shared_kv = shared_kv ctx_dim = ctx_dim or dim if shared_kv: self.kv = SharedImageKV(dim, n_heads, ctx_dim, qk_norm) else: self.kv = nn.ModuleList( [SharedImageKV(dim, n_heads, ctx_dim, qk_norm) for _ in range(depth)] ) self.blocks = nn.ModuleList( [DecoderBlock(dim, n_heads, mlp_ratio, qk_norm, layerscale_init) for _ in range(depth)] ) self.grad_checkpoint = False def forward( self, tokens: torch.Tensor, # [B, T, C] GS query tokens image_tokens: torch.Tensor, # [B, N_I, C] self_mask: Optional[torch.Tensor] = None, # [B|1, 1|H, T, T] bool, True=keep ) -> torch.Tensor: def run(blk, toks, kv): if self.grad_checkpoint and self.training: return torch.utils.checkpoint.checkpoint(blk, toks, kv, self_mask, use_reentrant=False) return blk(toks, kv, self_mask=self_mask) if self.shared_kv: kv = self.kv(image_tokens) # projected once, reused across layers for blk in self.blocks: tokens = run(blk, tokens, kv) else: for blk, kvmod in zip(self.blocks, self.kv): tokens = run(blk, tokens, kvmod(image_tokens)) return tokens