"""DETR decoder (§2.4).

``N_dec`` (``depth``) blocks, each consisting of: cross-attention from GS tokens
to image tokens (K/V projected once and shared across all layers when
``shared_kv=True``, the default), self-attention among GS tokens (with the
dynamic->static causal mask), and a per-token MLP. QK-norm and LayerScale are
used throughout; as in TokenGS, no final LayerNorm is applied before the head.
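
Example (a hedged sketch, not taken from this codebase): ``self_mask`` is a
bool tensor broadcastable to ``[B, H, T, T]`` with True = keep. Assuming,
hypothetically, that the T GS tokens are ordered as static tokens followed by
dynamic tokens, and that static tokens must not attend to dynamic ones while
dynamic tokens attend to every token, the mask is block-lower-triangular over
the two token groups. ``block_causal_mask`` is a hypothetical helper name, and
the exact convention expected by ``DecoderBlock`` is not shown in this file.

```python
import torch
import torch.nn.functional as F


def block_causal_mask(group_sizes: list[int]) -> torch.Tensor:
    # Hypothetical helper. Returns a bool mask of shape [1, 1, T, T],
    # True = keep, where a token in group g attends to groups 0..g.
    # Group index per token, e.g. sizes [4, 2] -> [0, 0, 0, 0, 1, 1].
    group_id = torch.repeat_interleave(
        torch.arange(len(group_sizes)), torch.tensor(group_sizes)
    )
    # Query i (row) keeps key j (column) iff group(j) <= group(i).
    keep = group_id[None, :] <= group_id[:, None]
    return keep[None, None]  # broadcasts over batch and heads


# Assumed layout: 4 static tokens, then 2 dynamic tokens.
mask = block_causal_mask([4, 2])
q = k = v = torch.randn(1, 8, 6, 16)  # [B, H, T, head_dim]
# For a bool attn_mask, SDPA treats True as "takes part in attention".
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Under this assumed layout the static rows keep only static columns, so the
static outputs do not change when the dynamic tokens change.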
"""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
import torch.utils.checkpoint  # explicit: a bare "import torch" may not load this submodule

from mapgs.model.blocks import DecoderBlock, SharedImageKV


class DETRDecoder(nn.Module):
    def __init__(
        self,
        dim: int = 1024,
        depth: int = 24,
        n_heads: int = 16,
        mlp_ratio: float = 4.0,
        qk_norm: bool = True,
        layerscale_init: float = 1e-5,
        ctx_dim: Optional[int] = None,
        shared_kv: bool = True,
    ):
        super().__init__()
        self.shared_kv = shared_kv
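        ctx_dim = ctx_dim or dim  # image-token width defaults to the decoder width
        if shared_kv:
            # One K/V projection of the image tokens, reused by every block.
            self.kv = SharedImageKV(dim, n_heads, ctx_dim, qk_norm)
        else:
            # An independent K/V projection per block.
            self.kv = nn.ModuleList(
                [SharedImageKV(dim, n_heads, ctx_dim, qk_norm) for _ in range(depth)]
            )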
        self.blocks = nn.ModuleList(
            [DecoderBlock(dim, n_heads, mlp_ratio, qk_norm, layerscale_init) for _ in range(depth)]
        )
        self.grad_checkpoint = False  # set True to enable activation checkpointing during training

    def forward(
        self,
        tokens: torch.Tensor,        # [B, T, C] GS query tokens
        image_tokens: torch.Tensor,  # [B, N_I, C_ctx] image tokens; C_ctx = ctx_dim (defaults to C)
        self_mask: Optional[torch.Tensor] = None,  # [B|1, 1|H, T, T] bool, True=keep
    ) -> torch.Tensor:
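        def run(blk: nn.Module, toks: torch.Tensor, kv) -> torch.Tensor:
            # Checkpointing is only applied in training mode. The non-reentrant
            # variant forwards keyword arguments, so self_mask is passed the
            # same way in both branches.
            if self.grad_checkpoint and self.training:
                return torch.utils.checkpoint.checkpoint(
                    blk, toks, kv, self_mask=self_mask, use_reentrant=False
                )
            return blk(toks, kv, self_mask=self_mask)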

        if self.shared_kv:
            kv = self.kv(image_tokens)  # projected once, reused across layers
            for blk in self.blocks:
                tokens = run(blk, tokens, kv)
        else:
            # Each block projects the image tokens with its own K/V module.
            for blk, kvmod in zip(self.blocks, self.kv):
                tokens = run(blk, tokens, kvmod(image_tokens))
        return tokens