| """DETR decoder (§2.4). |
| |
| ``N_dec`` blocks of: cross-attention (GS tokens -> image tokens, with K/V |
| projected once and shared across layers), self-attention among GS tokens (with |
| the dynamic->static causal mask), and a per-token MLP. QK-norm + LayerScale |
| throughout; the final LayerNorm before the head is omitted (TokenGS). |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from mapgs.model.blocks import DecoderBlock, SharedImageKV |
|
|
|
|
class DETRDecoder(nn.Module):
    """Stack of ``depth`` ``DecoderBlock`` layers attending to image tokens.

    Image tokens are converted into a key/value context by ``SharedImageKV``.
    With ``shared_kv=True`` one projection module is built, evaluated once per
    forward pass and its result handed to every block; otherwise one
    projection module is built per block and evaluated inside the layer loop.

    Args:
        dim: Forwarded to every ``DecoderBlock`` and ``SharedImageKV``.
        depth: Number of decoder blocks (and of K/V modules when
            ``shared_kv`` is false).
        n_heads: Forwarded to every ``DecoderBlock`` and ``SharedImageKV``.
        mlp_ratio: Forwarded to every ``DecoderBlock``.
        qk_norm: Forwarded to every ``DecoderBlock`` and ``SharedImageKV``.
        layerscale_init: Forwarded to every ``DecoderBlock``.
        ctx_dim: Forwarded to ``SharedImageKV``; ``None`` (or any falsy
            value) falls back to ``dim``.
        shared_kv: Whether a single K/V projection is shared by all blocks.

    Attributes:
        grad_checkpoint: When set to ``True`` (default ``False``) each block
            is run under activation checkpointing, but only while the module
            is in training mode.
    """

    def __init__(
        self,
        dim: int = 1024,
        depth: int = 24,
        n_heads: int = 16,
        mlp_ratio: float = 4.0,
        qk_norm: bool = True,
        layerscale_init: float = 1e-5,
        ctx_dim: Optional[int] = None,
        shared_kv: bool = True,
    ):
        super().__init__()
        self.shared_kv = shared_kv
        # Falsy ctx_dim (None or 0) means "same width as the decoder tokens".
        ctx_dim = ctx_dim or dim
        if shared_kv:
            self.kv = SharedImageKV(dim, n_heads, ctx_dim, qk_norm)
        else:
            self.kv = nn.ModuleList(
                [SharedImageKV(dim, n_heads, ctx_dim, qk_norm) for _ in range(depth)]
            )
        self.blocks = nn.ModuleList(
            [
                DecoderBlock(dim, n_heads, mlp_ratio, qk_norm, layerscale_init)
                for _ in range(depth)
            ]
        )
        self.grad_checkpoint = False

    def forward(
        self,
        tokens: torch.Tensor,
        image_tokens: torch.Tensor,
        self_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run ``tokens`` through every decoder block in order.

        Args:
            tokens: Decoder input passed to the first block; each block's
                output becomes the next block's input.
                NOTE(review): presumably (batch, n_tokens, dim) -- confirm.
            image_tokens: Input to the K/V projection module(s).
            self_mask: Optional mask handed to every block as the
                ``self_mask`` keyword argument.

        Returns:
            The output of the last block (``tokens`` unchanged if there are
            no blocks).
        """
        # Import the submodule explicitly: a bare ``import torch`` is not
        # guaranteed to bind ``torch.utils.checkpoint`` as an attribute.
        from torch.utils.checkpoint import checkpoint

        # Checkpointing only pays off when a backward pass will follow.
        use_checkpoint = self.grad_checkpoint and self.training

        def run(blk, toks, kv):
            if use_checkpoint:
                # Non-reentrant checkpoint forwards keyword arguments, so the
                # mask is passed exactly as in the plain path below.
                return checkpoint(
                    blk, toks, kv, self_mask=self_mask, use_reentrant=False
                )
            return blk(toks, kv, self_mask=self_mask)

        if self.shared_kv:
            # Project the image tokens once and reuse the result in every layer.
            kv = self.kv(image_tokens)
            for blk in self.blocks:
                tokens = run(blk, tokens, kv)
        else:
            # One projection module per layer, evaluated just before its block.
            for blk, kvmod in zip(self.blocks, self.kv):
                tokens = run(blk, tokens, kvmod(image_tokens))
        return tokens
|
|