File size: 2,177 Bytes
b2efbe4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 | """DETR decoder (§2.4).
``N_dec`` blocks of: cross-attention (GS tokens -> image tokens; by default K/V are
projected once and shared across layers, see ``shared_kv``), self-attention among GS tokens (with
the dynamic->static causal mask), and a per-token MLP. QK-norm + LayerScale
throughout; the final LayerNorm before the head is omitted (TokenGS).
"""
from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
import torch.utils.checkpoint

from mapgs.model.blocks import DecoderBlock, SharedImageKV
class DETRDecoder(nn.Module):
    """Stack of ``DecoderBlock`` layers that refines GS query tokens against image tokens.

    Image tokens are turned into keys/values by ``SharedImageKV``. With
    ``shared_kv=True`` a single projection module is applied once per forward
    pass and its output is reused by every block. Otherwise each block has its
    own projection module, and ``self.kv`` is an ``nn.ModuleList`` of length
    ``depth``. Because of this, the state-dict keys differ between the two modes.
    """

    def __init__(
        self,
        dim: int = 1024,
        depth: int = 24,
        n_heads: int = 16,
        mlp_ratio: float = 4.0,
        qk_norm: bool = True,
        layerscale_init: float = 1e-5,
        ctx_dim: Optional[int] = None,
        shared_kv: bool = True,
    ):
        """Build the decoder.

        Args:
            dim: Width of the query tokens, forwarded to every submodule.
            depth: Number of ``DecoderBlock`` layers (and of K/V projections when
                ``shared_kv`` is False).
            n_heads: Number of attention heads.
            mlp_ratio: Forwarded to ``DecoderBlock`` (presumably the MLP
                hidden-width multiplier; confirm in ``mapgs.model.blocks``).
            qk_norm: Forwarded to both ``SharedImageKV`` and ``DecoderBlock``.
            layerscale_init: Forwarded to ``DecoderBlock``.
            ctx_dim: Width of the image tokens; ``None`` means "same as ``dim``".
            shared_kv: If True, use one K/V projection for all blocks.
        """
        super().__init__()
        self.shared_kv = shared_kv
        if ctx_dim is None:
            ctx_dim = dim
        if shared_kv:
            self.kv = SharedImageKV(dim, n_heads, ctx_dim, qk_norm)
        else:
            self.kv = nn.ModuleList(
                [SharedImageKV(dim, n_heads, ctx_dim, qk_norm) for _ in range(depth)]
            )
        self.blocks = nn.ModuleList(
            [DecoderBlock(dim, n_heads, mlp_ratio, qk_norm, layerscale_init) for _ in range(depth)]
        )
        # Toggled externally; when set, blocks recompute activations in backward.
        self.grad_checkpoint = False

    def _run_block(
        self,
        blk: nn.Module,
        tokens: torch.Tensor,
        kv,
        self_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Apply one block, under activation checkpointing when enabled and useful."""
        # Checkpointing only pays off when a backward pass will follow.
        if self.grad_checkpoint and self.training and torch.is_grad_enabled():
            # Non-reentrant checkpointing forwards keyword arguments, so the mask
            # binds exactly as in the eager call below (not positionally).
            return torch.utils.checkpoint.checkpoint(
                blk, tokens, kv, self_mask=self_mask, use_reentrant=False
            )
        return blk(tokens, kv, self_mask=self_mask)

    def forward(
        self,
        tokens: torch.Tensor,  # [B, T, C] GS query tokens
        image_tokens: torch.Tensor,  # [B, N_I, C]
        self_mask: Optional[torch.Tensor] = None,  # [B|1, 1|H, T, T] bool, True=keep
    ) -> torch.Tensor:
        """Run every decoder block over ``tokens`` in order.

        Args:
            tokens: GS query tokens, ``[B, T, C]``.
            image_tokens: Image tokens that are projected to K/V, ``[B, N_I, C]``.
            self_mask: Optional boolean self-attention mask,
                ``[B|1, 1|H, T, T]`` with ``True`` = keep. It is passed unchanged
                to every block.

        Returns:
            The tokens produced by the last block (``tokens`` itself if ``depth == 0``).
        """
        if self.shared_kv:
            kv = self.kv(image_tokens)  # projected once, reused across layers
            for blk in self.blocks:
                tokens = self._run_block(blk, tokens, kv, self_mask)
        else:
            for blk, kv_proj in zip(self.blocks, self.kv):
                tokens = self._run_block(blk, tokens, kv_proj(image_tokens), self_mask)
        return tokens
|