File size: 6,102 Bytes
1b703d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""Dense cross-attention block used by the DINAC-AE class-token head."""

from __future__ import annotations

from dataclasses import dataclass

from torch import Tensor, nn

from common.norms import RMSNorm
from dit.attention_blocks import CrossAttentionCore
from dit.mlp import build_dit_mlp, reset_module_parameters
from dit.mlp_types import MLPType


@dataclass()
class CrossAttentionConfig:
    """Hyper-parameters consumed by ``CrossAttentionBlock``.

    Every field has a default, so ``CrossAttentionConfig()`` describes a
    16-head block with RMS norms enabled and a GELU-type residual MLP.
    """

    # Attention geometry. When ``head_dim`` is None the block derives the
    # per-head width as ``query_dim // n_heads``.
    n_heads: int = 16
    head_dim: int | None = None
    # Widths of optional side features. Query extras are appended to the
    # (normalized) query before ``q_proj``; context extras are appended to the
    # (normalized) context and their width is handed to the attention core.
    # ``key_extra_dim`` is only forwarded to the attention core — its meaning
    # is defined there (TODO confirm against CrossAttentionCore).
    query_extra_dim: int = 0
    context_extra_dim: int = 0
    key_extra_dim: int = 0
    # Residual MLP hidden budget is ``round(mlp_ratio * query_dim)``.
    mlp_ratio: float = 2.0
    # Forwarded to the attention core; presumably a dropout probability.
    attn_dropout: float = 0.0
    # MLP flavour, activation settings and block index are passed straight
    # through to ``build_dit_mlp``.
    mlp_type: MLPType = MLPType.GELU
    activation_config: object | None = None
    # When True, RMSNorm is applied to the query, the context and the MLP input.
    use_norms: bool = True
    block_index: int = 0
    # Must remain True: the block raises ValueError when it is False.
    use_attn_residual: bool = True


class CrossAttentionBlock(nn.Module):
    """Dense pre-norm cross-attention followed by a residual MLP.

    Computes::

        q      = q_proj(concat(query_norm(query), query_extra))
        kv     = concat(context_norm(context), context_extra)
        tokens = query + attn_core(q, kv, key_extra=..., key_padding_mask=...)
        out    = tokens + mlp(mlp_norm(tokens))

    The norms are skipped when ``cfg.use_norms`` is False, and each concat is
    skipped when the corresponding extra tensor is ``None``. ``kv_proj`` and
    ``out_proj`` are aliases of the attention core's projections: the same
    modules are registered under both attribute names, so their weights also
    appear under both key prefixes in ``state_dict``.
    """

    query_dim: int
    context_dim: int
    query_extra_dim: int
    context_extra_dim: int
    key_extra_dim: int
    n_heads: int
    head_dim: int
    attn_dim: int
    use_norms: bool
    attn_dropout: float
    use_attn_residual: bool
    query_norm: RMSNorm | None
    context_norm: RMSNorm | None
    mlp_norm: RMSNorm | None
    q_proj: nn.Linear
    attn_core: CrossAttentionCore
    kv_proj: nn.Linear
    out_proj: nn.Linear
    mlp: nn.Module

    def __init__(
        self,
        *,
        query_dim: int,
        context_dim: int,
        cfg: CrossAttentionConfig,
    ) -> None:
        """Build the block.

        Args:
            query_dim: Feature width of ``query`` (and of the residual output).
            context_dim: Feature width of ``context`` before any extras.
            cfg: Attention / MLP hyper-parameters.

        Raises:
            ValueError: If ``n_heads`` or the resulting ``head_dim`` is not
                positive, if any ``*_extra_dim`` is negative, if ``head_dim``
                is derived and ``query_dim`` is not divisible by ``n_heads``,
                or if ``cfg.use_attn_residual`` is False.
        """
        super().__init__()
        n_heads = int(cfg.n_heads)
        if n_heads <= 0:
            raise ValueError("n_heads must be positive")
        self.query_dim = int(query_dim)
        self.context_dim = int(context_dim)
        if cfg.head_dim is None:
            if self.query_dim % n_heads != 0:
                raise ValueError("query_dim must be divisible by n_heads")
            head_dim = self.query_dim // n_heads
        else:
            head_dim = int(cfg.head_dim)
        if head_dim <= 0:
            raise ValueError("head_dim must be positive")
        self.query_extra_dim = int(cfg.query_extra_dim)
        self.context_extra_dim = int(cfg.context_extra_dim)
        self.key_extra_dim = int(cfg.key_extra_dim)
        for label, width in (
            ("query_extra_dim", self.query_extra_dim),
            ("context_extra_dim", self.context_extra_dim),
            ("key_extra_dim", self.key_extra_dim),
        ):
            if width < 0:
                raise ValueError(f"{label} must be non-negative")
        self.n_heads = n_heads
        self.head_dim = int(head_dim)
        self.attn_dim = int(self.n_heads * self.head_dim)
        self.use_norms = bool(cfg.use_norms)
        self.attn_dropout = float(cfg.attn_dropout)
        if not cfg.use_attn_residual:
            raise ValueError("DINAC-AE export requires attention residuals")
        self.use_attn_residual = True
        # Submodule creation order is kept stable: it fixes both the RNG
        # consumption of default initializers and the state_dict key order.
        self.query_norm = RMSNorm(self.query_dim) if self.use_norms else None
        self.context_norm = RMSNorm(self.context_dim) if self.use_norms else None
        self.mlp_norm = RMSNorm(self.query_dim) if self.use_norms else None
        # q_proj consumes the normalized query plus optional query extras.
        self.q_proj = nn.Linear(
            self.query_dim + self.query_extra_dim, self.attn_dim, bias=False
        )
        self.attn_core = CrossAttentionCore(
            query_dim=self.query_dim,
            context_dim=self.context_dim,
            context_extra_dim=self.context_extra_dim,
            key_extra_dim=self.key_extra_dim,
            n_heads=self.n_heads,
            head_dim=self.head_dim,
            attn_dropout=self.attn_dropout,
        )
        # Top-level aliases of the core's projections (shared modules, not
        # copies) — presumably kept for export/checkpoint naming; verify.
        self.kv_proj = self.attn_core.kv_proj
        self.out_proj = self.attn_core.out_proj
        hidden = int(round(cfg.mlp_ratio * self.query_dim))
        self.mlp = build_dit_mlp(
            mlp_type=cfg.mlp_type,
            in_features=self.query_dim,
            hidden_budget=hidden,
            activation_config=cfg.activation_config,
            block_index=int(cfg.block_index),
            bias_up=False,
            bias_down=False,
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset projections and MLP parameters.

        Applies Xavier-uniform init to ``q_proj.weight``, then delegates to
        the attention core's own ``reset_parameters`` and to
        ``reset_module_parameters`` for the MLP.
        """

        nn.init.xavier_uniform_(self.q_proj.weight)
        self.attn_core.reset_parameters()
        reset_module_parameters(self.mlp)

    @staticmethod
    def _concat_last_dim(base: Tensor, extra: Tensor | None) -> Tensor:
        """Return ``base`` with ``extra`` appended along the last dimension.

        Returns ``base`` itself when ``extra`` is None. Writes into a buffer
        preallocated with ``base.new_empty`` instead of calling ``torch.cat``.
        Slice assignment broadcasts ``extra`` over ``base``'s leading dims and
        casts it to ``base``'s dtype, and callers may rely on both behaviours.
        """
        if extra is None:
            return base
        width = int(base.shape[-1])
        out = base.new_empty(*base.shape[:-1], width + int(extra.shape[-1]))
        out[..., :width] = base
        out[..., width:] = extra
        return out

    def forward(
        self,
        query: Tensor,
        context: Tensor,
        *,
        query_extra: Tensor | None = None,
        context_extra: Tensor | None = None,
        key_extra: Tensor | None = None,
        key_padding_mask: Tensor | None = None,
    ) -> Tensor:  # type: ignore[override]
        """Run dense cross-attention followed by the residual MLP.

        Args:
            query: Query tokens whose last dim is ``query_dim`` (presumably
                ``(batch, n_query, query_dim)`` — confirm with callers).
            context: Context tokens whose last dim is ``context_dim``.
            query_extra: Optional features appended (last dim) to the
                normalized query before ``q_proj``. Their width must equal
                ``cfg.query_extra_dim`` for ``q_proj`` to accept the input.
            context_extra: Optional features appended (last dim) to the
                normalized context before the attention core.
            key_extra: Forwarded unchanged to the attention core.
            key_padding_mask: Forwarded unchanged to the attention core, which
                defines its convention.

        Returns:
            ``query + attn_out + mlp(...)`` — presumably the same shape as
            ``query``.
        """

        query_tokens = self.query_norm(query) if self.query_norm is not None else query
        q_in = self._concat_last_dim(query_tokens, query_extra)
        context_tokens = (
            self.context_norm(context) if self.context_norm is not None else context
        )
        kv_tokens = self._concat_last_dim(context_tokens, context_extra)
        q_attn_tokens = self.q_proj(q_in)
        # The training flag is passed explicitly (presumably to gate dropout
        # inside the core).
        attn_out = self.attn_core(
            q_attn_tokens,
            kv_tokens,
            training=self.training,
            key_extra=key_extra,
            key_padding_mask=key_padding_mask,
        )
        # Pre-norm residual: the skip path uses the raw, un-normalized query.
        tokens = query + attn_out
        mlp_in = self.mlp_norm(tokens) if self.mlp_norm is not None else tokens
        return tokens + self.mlp(mlp_in)

    def compile_for_training(self, *, fullgraph: bool, dynamic: bool) -> None:
        """No-op hook kept for the token-alignment head API."""

        _ = fullgraph, dynamic

    def compile_for_eval(self, *, fullgraph: bool, dynamic: bool) -> None:
        """No-op hook kept for the token-alignment head API."""

        _ = fullgraph, dynamic


# Public API: the block and its configuration dataclass.
__all__ = [
    "CrossAttentionBlock",
    "CrossAttentionConfig",
]