File size: 8,506 Bytes
922bb4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""
ARModel — standard decoder-only (GPT-2 style) Transformer for next-token prediction.

Baseline to compare against SAD / Block-AR diffusion at matched scale:
  - Same hidden_size / n_blocks / n_heads / seq_len as SADModel
  - Same RoPE (reused from dit_components)
  - Standard pre-LN blocks with causal self-attention + GELU MLP
  - Untied token embedding / output head, matching Block-AR parameterization
  - No adaLN / no timestep conditioning / no DiT modulation

Inference:
  forward(input_ids) — full-sequence forward, used by training / eval / the
                       first (prompt) step of generation.
  forward_cached(input_ids, past_kv_list=None) — returns
                       (logits, new_kv_list). Used for left-to-right generation
                       with an incrementally grown KV cache.

  KV cache layout: list of length n_blocks; each entry is (k, v) with shape
    [B, H, S_cache, head_dim].
  Max total length is `max_seq_len` (RoPE is precomputed for that length).
"""

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .dit_components import Rotary, apply_rotary_pos_emb


# One layer's attention cache entry: a (key, value) tensor pair. Per the module
# docstring's "KV cache layout", each tensor is shaped [B, H, S_cache, head_dim].
KVPair = Tuple[torch.Tensor, torch.Tensor]


class ARBlock(nn.Module):
    """Pre-LN Transformer block: causal self-attention + GELU MLP, no conditioning.

    Both sub-layers are residual: ``x + attn(norm1(x))`` followed by
    ``x + mlp(norm2(x))``. Rotary position embeddings are applied to the packed
    QKV tensor via ``apply_rotary_pos_emb`` before attention.

    Args:
        dim:        model width (size of the last dimension of the input).
        n_heads:    number of attention heads; must be positive and divide ``dim``.
        mlp_ratio:  hidden width of the MLP as a multiple of ``dim``.
        dropout:    dropout probability applied to each residual branch in the
                    uncached ``forward`` path (only while ``self.training``).

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim: int, n_heads: int, mlp_ratio: int = 4, dropout: float = 0.0):
        super().__init__()
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # otherwise head_dim would silently floor and fail later inside rearrange.
        if n_heads <= 0 or dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.dropout = dropout

        self.norm1 = nn.LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )

    def _qkv(self, x: torch.Tensor, rotary_cos_sin) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Normalize ``x``, project to packed QKV, apply RoPE, and split per head.

        Args:
            x:               [B, S, d] block input (pre-norm is applied here).
            rotary_cos_sin:  ``(cos, sin)`` pair; both are cast to the QKV dtype
                             before being handed to ``apply_rotary_pos_emb``.
        Returns:
            ``(q, k, v)``, each [B, H, S, D].
        """
        h = self.norm1(x)
        qkv = self.attn_qkv(h)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.n_heads)
        cos, sin = rotary_cos_sin
        qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
        q = qkv[:, :, 0].transpose(1, 2)  # [B, H, S, D]
        k = qkv[:, :, 1].transpose(1, 2)
        v = qkv[:, :, 2].transpose(1, 2)
        return q, k, v

    def forward(self, x: torch.Tensor, rotary_cos_sin) -> torch.Tensor:
        """Uncached path (training / full-sequence eval).

        Args:
            x:               [B, S, d]
            rotary_cos_sin:  ``(cos, sin)`` pair for the S positions of ``x``.
        Returns:
            [B, S, d] block output.
        """
        q, k, v = self._qkv(x, rotary_cos_sin)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = rearrange(attn, "b h s d -> b s (h d)")
        x = x + F.dropout(self.attn_out(attn), p=self.dropout, training=self.training)
        x = x + F.dropout(self.mlp(self.norm2(x)), p=self.dropout, training=self.training)
        return x

    def forward_cached(
        self,
        x: torch.Tensor,
        rotary_cos_sin,
        past_kv: Optional[KVPair] = None,
    ) -> Tuple[torch.Tensor, KVPair]:
        """
        Cached path (generation). No dropout is applied on this path.

        Args:
            x:               [B, S_new, d] hidden states of the new tokens.
            rotary_cos_sin:  ``(cos, sin)`` pair for the new tokens' positions.
            past_kv:         optional (k_cache, v_cache), each [B, H, S_cache, D].
        Returns:
            (out [B, S_new, d], new_kv = (k_all, v_all)) where k_all / v_all are
            the cache concatenated with the new keys / values along the sequence
            dimension.

        Masking:
          * S_cache == 0 (first call): plain causal attention over the new tokens.
          * S_cache > 0, S_new == 1: the single new query is the most recent
            position, so it sees all S_cache + 1 keys and no mask is needed.
          * S_cache > 0, S_new > 1: new query i sees every cached key plus new
            keys 0..i (a causal mask shifted right by S_cache).
        """
        q, k_new, v_new = self._qkv(x, rotary_cos_sin)

        if past_kv is None or past_kv[0].size(2) == 0:
            k = k_new
            v = v_new
            attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            k_cache, v_cache = past_kv
            s_cache = k_cache.size(2)
            s_new = q.size(2)
            k = torch.cat([k_cache, k_new], dim=2)
            v = torch.cat([v_cache, v_new], dim=2)
            if s_new == 1:
                # Single-step append: full visibility over [0 .. S_cache] is correct.
                attn = F.scaled_dot_product_attention(q, k, v, is_causal=False)
            else:
                # Multi-token append. ``is_causal=True`` cannot be used here (it
                # aligns the mask to the top-left corner), so build the shifted
                # lower-triangular mask explicitly: entry (i, j) is True iff
                # j <= S_cache + i. True means "may attend".
                visible = torch.ones(
                    s_new, s_cache + s_new, dtype=torch.bool, device=q.device
                ).tril(diagonal=s_cache)
                attn = F.scaled_dot_product_attention(q, k, v, attn_mask=visible)

        attn = rearrange(attn, "b h s d -> b s (h d)")
        x = x + self.attn_out(attn)
        x = x + self.mlp(self.norm2(x))
        return x, (k, v)


class ARModel(nn.Module):
    """
    GPT-2-style decoder-only Transformer with RoPE.

    Pipeline: token embedding -> ``n_blocks`` x ARBlock -> final LayerNorm ->
    untied linear output head. The embedding table and output head are sized to
    the vocabulary rounded up to a multiple of 128; returned logits are sliced
    back to the true ``vocab_size``.

    Args:
        vocab_size:   V
        hidden_size:  d
        n_blocks:     number of transformer blocks
        n_heads:      number of attention heads
        max_seq_len:  max supported sequence length (for RoPE precompute)
        dropout:      dropout inside blocks (0.0 by default, matching SAD)
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 768,
        n_blocks: int = 12,
        n_heads: int = 12,
        max_seq_len: int = 512,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len

        # Round vocab to multiple of 128 (same trick as SADModel for tensor-core friendliness)
        self.rounded_vocab_size = vocab_size + (128 - vocab_size % 128) % 128

        # Input embedding
        self.tok_embed = nn.Embedding(self.rounded_vocab_size, hidden_size)
        nn.init.normal_(self.tok_embed.weight, std=0.02)

        # Rotary tables are built per attention head (head_dim = hidden_size // n_heads).
        self.rotary_emb = Rotary(hidden_size // n_heads, max_seq_len=max_seq_len)

        self.blocks = nn.ModuleList([
            ARBlock(hidden_size, n_heads, dropout=dropout) for _ in range(n_blocks)
        ])
        self.norm_final = nn.LayerNorm(hidden_size)

        # Decoupled output head to match the Block-AR baseline parameterization.
        self.lm_head = nn.Linear(hidden_size, self.rounded_vocab_size, bias=False)
        nn.init.normal_(self.lm_head.weight, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,  # accepted for API symmetry; unused
    ) -> torch.Tensor:
        """
        Full-sequence forward. Training / full-sequence eval path.

        Args:
            input_ids:       [B, S] int64
            attention_mask:  ignored; only causal masking inside the blocks is
                             applied, so padding positions are not masked out.
        Returns:
            logits:     [B, S, V]  (sliced to true vocab, not rounded)
        """
        x = self.tok_embed(input_ids)  # [B, S, d]
        rotary_cos_sin = self.rotary_emb(x)

        for block in self.blocks:
            x = block(x, rotary_cos_sin)

        x = self.norm_final(x)
        logits = self.lm_head(x)[..., :self.vocab_size]
        return logits

    def forward_cached(
        self,
        input_ids: torch.Tensor,
        past_kv_list: Optional[List[KVPair]] = None,
    ) -> Tuple[torch.Tensor, List[KVPair]]:
        """
        Cached forward for generation.

        Args:
            input_ids:     [B, S_new]
            past_kv_list:  None or an empty list (first call), or a list of
                           length n_blocks; each entry is (k, v) of shape
                           [B, H, S_cache, D].
        Returns:
            logits:        [B, S_new, V]
            new_kv_list:   list of length n_blocks with updated (k, v) of
                           shape [B, H, S_cache + S_new, D].
        Raises:
            ValueError: if a non-empty ``past_kv_list`` does not have exactly
                        n_blocks entries, or if S_cache + S_new exceeds
                        ``max_seq_len``.

        Any further constraint on S_new when the cache is non-empty is enforced
        by ``ARBlock.forward_cached``.
        """
        _, s_new = input_ids.shape

        # Treat an empty list like None: indexing past_kv_list[0] would raise.
        if not past_kv_list:
            past_kv_list = None
            s_cache = 0
        else:
            if len(past_kv_list) != self.n_blocks:
                raise ValueError(
                    f"past_kv_list has {len(past_kv_list)} entries, "
                    f"expected n_blocks={self.n_blocks}"
                )
            # Every layer's cache has the same length; read it from layer 0's keys.
            s_cache = past_kv_list[0][0].size(2)

        total_len = s_cache + s_new
        # Explicit raise (not ``assert``) so the bound still holds under ``python -O``;
        # RoPE is only precomputed for max_seq_len positions.
        if total_len > self.max_seq_len:
            raise ValueError(
                f"cache+new ({s_cache}+{s_new}={total_len}) exceeds "
                f"max_seq_len={self.max_seq_len}"
            )

        x = self.tok_embed(input_ids)  # [B, S_new, d]

        # RoPE positions for the new tokens: [S_cache .. S_cache + S_new - 1]
        position_ids = torch.arange(s_cache, total_len, device=input_ids.device)
        rotary_cos_sin = self.rotary_emb(x, position_ids=position_ids)

        new_kv_list: List[KVPair] = []
        for i, block in enumerate(self.blocks):
            layer_cache = None if past_kv_list is None else past_kv_list[i]
            x, layer_kv = block.forward_cached(x, rotary_cos_sin, past_kv=layer_cache)
            new_kv_list.append(layer_kv)

        x = self.norm_final(x)
        logits = self.lm_head(x)[..., :self.vocab_size]
        return logits, new_kv_list