| """ |
| SAD backbone model – Block-AR variant with vectorized training. |
| |
| Architecture: |
| - Sequence divided into blocks of `block_size` tokens |
| - Within a block: bidirectional attention |
| - Across blocks: autoregressive (each block sees only earlier blocks' clean tokens) |
| - No timestep conditioning; single learnable cond_bias vector |
| - Input: continuous embeddings (leaf embs / prototype embs / mask emb) |
| |
| Vectorized training (forward_vectorized): |
| Concatenate x_full = [x_noisy (0..L-1) | x_clean (L..2L-1)], shape [B, 2L, d]. |
| Apply block-diff attention mask (see build_block_diff_mask), run one forward pass, |
| return logits for the noisy half only [B, L, V]. |
| |
| Block-diff mask rules (positions in x_full = [noisy | clean]): |
| noisy[i] → noisy[j]: allowed iff same block |
| noisy[i] → clean[j]: allowed iff block(j) < block(i) (strictly earlier clean block) |
| clean[i] → clean[j]: allowed iff block(j) <= block(i) (same or earlier clean block) |
| clean[i] → noisy[j]: never |
| |
| Flex attention is used when available for the sparse block-diff mask. |
| """ |
|
|
| import math |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .dit_components import ( |
| DDiTBlockWithMask, |
| DDitFinalLayer, |
| EmbeddingLayer, |
| Rotary, |
| ) |
|
|
| |
| |
| try: |
| from torch.nn.attention.flex_attention import create_block_mask |
| _has_flex_attention = True |
| except ImportError: |
| create_block_mask = None |
| _has_flex_attention = False |
|
|
|
|
| |
| |
| |
|
|
def build_block_diff_mask(seq_len: int, block_size: int, device) -> torch.Tensor:
    """
    Assemble the [2L, 2L] boolean attention mask used for vectorized
    block-diff training (True = query row may attend to key column).

    The sequence layout is [x_noisy (0..L-1) | x_clean (L..2L-1)], so the
    mask is made of four L x L quadrants:

        noisy -> noisy : key in the same block as the query
        noisy -> clean : key in a strictly earlier block
        clean -> noisy : never allowed
        clean -> clean : key in the same or an earlier block

    Args:
        seq_len:    L, the length of one half of the concatenated sequence.
        block_size: number of tokens per block.
        device:     device on which the mask is created.

    Returns:
        Bool tensor of shape [2L, 2L].
    """
    # Block index of every position inside a single half; both halves share it.
    block_of = torch.arange(seq_len, device=device) // block_size
    query_block = block_of[:, None]
    key_block = block_of[None, :]

    # Upper half of the mask: rows are noisy queries.
    noisy_rows = torch.cat(
        [key_block == query_block, key_block < query_block], dim=1
    )

    # Lower half of the mask: rows are clean queries, which never see noisy keys.
    never = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device)
    clean_rows = torch.cat([never, key_block <= query_block], dim=1)

    return torch.cat([noisy_rows, clean_rows], dim=0)
|
|
|
|
def _make_block_diff_mask_mod(seq_len: int, block_size: int):
    """
    Build the `mask_mod` callable that encodes the block-diff mask for
    flex_attention's `create_block_mask`.

    Index convention: positions 0..L-1 belong to the noisy half and
    positions L..2L-1 to the clean half (L = seq_len).

    The returned callable has the signature `mask_mod(b, h, q_idx, k_idx)` and
    evaluates to True where the query may attend to the key:

        noisy -> noisy : same block
        noisy -> clean : strictly earlier block
        clean -> clean : same or earlier block
        clean -> noisy : never

    `b` and `h` are accepted but not used, so the mask is the same for every
    batch element and head.
    """
    half = seq_len

    def mask_mod(b, h, q_idx, k_idx):
        def block_index(idx, in_clean_half):
            # Clean positions are offset by `half`; remove the offset before
            # converting a position into its block number.
            return torch.where(in_clean_half, idx - half, idx) // block_size

        q_clean = q_idx >= half
        k_clean = k_idx >= half
        q_block = block_index(q_idx, q_clean)
        k_block = block_index(k_idx, k_clean)

        both_noisy = ~(q_clean | k_clean)
        noisy_sees_clean = k_clean & ~q_clean
        both_clean = q_clean & k_clean

        return (
            (both_noisy & (k_block == q_block))
            | (noisy_sees_clean & (k_block < q_block))
            | (both_clean & (k_block <= q_block))
        )

    return mask_mod
|
|
|
|
| |
| |
| |
|
|
class SADModel(nn.Module):
    """
    SAD backbone: Block-AR diffusion, no timestep conditioning.

    Args:
        vocab_size:   V – leaf vocabulary size
        hidden_size:  d – transformer hidden dimension
        n_blocks:     number of transformer blocks
        n_heads:      number of attention heads
        cond_dim:     conditioning dimension for adaLN
        max_seq_len:  total sequence length (must be divisible by block_size)
        block_size:   tokens per block
        dropout:      dropout rate
        num_levels:   total hierarchy levels (leaf + intermediate, not counting mask)
        level_sizes:  [V, K1, K2, ...]
        tie_weights:  if True, the output projection weight is the same tensor
                      as the vocabulary embedding matrix
        use_aux_head, embed_input_mode, extended_vocab_size:
                      accepted by the constructor but not read anywhere in
                      this class (`aux_head` is always set to None)
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 768,
        n_blocks: int = 12,
        n_heads: int = 12,
        cond_dim: int = 128,
        max_seq_len: int = 512,
        block_size: int = 8,
        dropout: float = 0.0,
        num_levels: int = 2,
        level_sizes: Optional[list] = None,
        tie_weights: bool = True,
        use_aux_head: bool = False,
        embed_input_mode: str = "embeddings",
        extended_vocab_size: int = 0,
    ):
        super().__init__()
        assert max_seq_len % block_size == 0, \
            f"max_seq_len ({max_seq_len}) must be divisible by block_size ({block_size})"

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.cond_dim = cond_dim
        self.max_seq_len = max_seq_len
        self.block_size = block_size
        self.num_total_blocks = max_seq_len // block_size
        self.num_levels = num_levels
        self.level_sizes = level_sizes
        self.aux_head = None

        # Vocabulary embedding table, with the vocab rounded up to a multiple of 128.
        raw_size = vocab_size
        self.rounded_vocab_size = raw_size + (128 - raw_size % 128) % 128
        self.vocab_embed = EmbeddingLayer(hidden_size, self.rounded_vocab_size)

        # Bias-free projection applied to the incoming continuous embeddings.
        self.input_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        # Single learnable conditioning vector (there is no timestep conditioning).
        self.cond_bias = nn.Parameter(torch.zeros(cond_dim))
        nn.init.normal_(self.cond_bias, std=0.02)

        # Learned embedding of the block index (one row per block of the
        # maximum-length sequence).
        self.block_idx_embed = nn.Embedding(self.num_total_blocks, hidden_size)
        nn.init.normal_(self.block_idx_embed.weight, std=0.02)

        # Learned embedding of the position inside a block.
        self.intra_pos_embed = nn.Embedding(block_size, hidden_size)
        nn.init.normal_(self.intra_pos_embed.weight, std=0.02)

        # Segment embedding: index 0 is used for noisy / regular inputs,
        # index 1 for the clean half of the vectorized training sequence.
        self.segment_embed = nn.Embedding(2, hidden_size)
        nn.init.normal_(self.segment_embed.weight, std=0.02)

        # Rotary embedding sized for 2 * max_seq_len
        # (presumably to cover the [noisy | clean] concatenation — confirm
        # against Rotary's implementation).
        self.rotary_emb = Rotary(
            hidden_size // n_heads,
            max_seq_len=max_seq_len * 2,
        )

        # Transformer trunk.
        self.blocks = nn.ModuleList([
            DDiTBlockWithMask(hidden_size, n_heads, cond_dim, dropout=dropout)
            for _ in range(n_blocks)
        ])

        # Output head over the rounded leaf vocabulary (same rounding as the
        # embedding table above).
        self.output_layer = DDitFinalLayer(hidden_size, self.rounded_vocab_size, cond_dim)
        if tie_weights:
            # Share the embedding matrix with the output projection.
            # NOTE(review): assumes both tensors have the same shape — confirm
            # against EmbeddingLayer / DDitFinalLayer.
            self.output_layer.linear.weight = self.vocab_embed.embedding
        else:
            nn.init.normal_(self.output_layer.linear.weight, std=0.02)

        # Cache of attention masks keyed by (kind, seq_len, block_size, device).
        self._mask_cache: dict = {}

    def get_leaf_embeddings(self) -> torch.Tensor:
        """[V, hidden_size] embedding matrix for leaf tokens."""
        return self.vocab_embed.embedding[:self.vocab_size]

    def _get_block_diff_mask(self, seq_len: int, device) -> torch.Tensor:
        """Cached dense [2L, 2L] bool block-diff mask for (L, block_size, device)."""
        key = (seq_len, self.block_size, str(device))
        if key not in self._mask_cache:
            self._mask_cache[key] = build_block_diff_mask(seq_len, self.block_size, device)
        return self._mask_cache[key]

    def _get_flex_block_mask(self, seq_len: int, device):
        """Cached flex BlockMask for (L, block_size, device)."""
        key = ("flex", seq_len, self.block_size, str(device))
        if key not in self._mask_cache:
            mask_mod = _make_block_diff_mask_mod(seq_len, self.block_size)
            # B=None / H=None: one mask shared by every batch element and head.
            self._mask_cache[key] = create_block_mask(
                mask_mod, B=None, H=None,
                Q_LEN=2 * seq_len, KV_LEN=2 * seq_len,
                device=device,
            )
        return self._mask_cache[key]

    def _build_position_embeddings(
        self, seq_len: int, is_clean_half: bool, device, dtype
    ) -> torch.Tensor:
        """
        Build [1, seq_len, hidden_size] position + segment embeddings.

        The result is the sum of the block-index embedding, the intra-block
        position embedding and the segment embedding (segment 1 when
        `is_clean_half` is True, segment 0 otherwise).

        seq_len = L (one half of the vectorized sequence, or the full
        sequence in the non-vectorized forward passes).

        Raises:
            ValueError: if seq_len is not a multiple of block_size, or spans
                more blocks than the block-index table holds.
        """
        if seq_len % self.block_size != 0:
            raise ValueError(
                f"sequence length ({seq_len}) must be divisible by "
                f"block_size ({self.block_size})"
            )
        num_blocks = seq_len // self.block_size
        if num_blocks > self.num_total_blocks:
            raise ValueError(
                f"sequence length ({seq_len}) exceeds the supported maximum "
                f"({self.num_total_blocks * self.block_size})"
            )

        # [0,0,...,0, 1,1,...,1, ...]: block index of every token.
        block_idx = torch.arange(num_blocks, device=device).repeat_interleave(self.block_size)
        # [0,1,...,bs-1, 0,1,...,bs-1, ...]: position of every token in its block.
        intra_pos = torch.arange(self.block_size, device=device).repeat(num_blocks)

        seg_id = torch.ones(seq_len, dtype=torch.long, device=device) * int(is_clean_half)

        pos_emb = (
            self.block_idx_embed(block_idx)
            + self.intra_pos_embed(intra_pos)
            + self.segment_embed(seg_id)
        ).unsqueeze(0).to(dtype)
        return pos_emb

    def forward_vectorized(
        self,
        noisy_embs: torch.Tensor,
        clean_embs: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Single-forward vectorized block-diff training pass.

        Concatenates [noisy_embs | clean_embs] → x_full [B, 2L, d],
        applies block-diff sparse attention mask, returns logits for noisy half.

        Args:
            noisy_embs:     [B, L, d] corrupted token embeddings
            clean_embs:     [B, L, d] original clean token embeddings
            attention_mask: [B, L] float/bool padding mask (optional).
                            Only applied on the dense (non-flex) path.

        Returns:
            leaf_logits: [B, L, V]

        Raises:
            ValueError: if the two inputs differ in shape, or L is not a
                valid sequence length for this model.
        """
        if clean_embs.shape != noisy_embs.shape:
            raise ValueError(
                f"noisy_embs {tuple(noisy_embs.shape)} and clean_embs "
                f"{tuple(clean_embs.shape)} must have the same shape"
            )

        B, L, d = noisy_embs.shape
        device = noisy_embs.device
        dtype = noisy_embs.dtype

        # Project both halves with the shared input projection.
        x_noisy = self.input_proj(noisy_embs)
        x_clean = self.input_proj(clean_embs)

        # Same block / intra-block positions for both halves; only the
        # segment embedding distinguishes noisy from clean.
        x_noisy = x_noisy + self._build_position_embeddings(L, False, device, dtype)
        x_clean = x_clean + self._build_position_embeddings(L, True, device, dtype)

        # x_full = [noisy | clean], shape [B, 2L, d].
        x = torch.cat([x_noisy, x_clean], dim=1)

        # Broadcast the single conditioning vector over the batch.
        c = self.cond_bias.unsqueeze(0).expand(B, -1).to(dtype)

        if _has_flex_attention:
            # Sparse path: the block-diff structure is carried by the BlockMask.
            # NOTE(review): `attention_mask` (padding) is NOT applied on this
            # path — confirm padding is handled elsewhere (e.g. in the loss).
            flex_block_mask = self._get_flex_block_mask(L, device)
            attn_mask = None
        else:
            # Dense path: bool [2L, 2L] mask, upgraded to an additive float
            # bias [B, 1, 2L, 2L] when a padding mask has to be merged in.
            flex_block_mask = None
            attn_mask = self._get_block_diff_mask(L, device)
            if attention_mask is not None:
                # The same padding pattern applies to the noisy and clean halves.
                pad_mask = torch.cat([attention_mask, attention_mask], dim=1).bool()
                block_diff_float = torch.zeros_like(attn_mask, dtype=dtype).masked_fill_(~attn_mask, -1e9)
                attn_mask = block_diff_float[None, None, :, :].expand(B, 1, 2 * L, 2 * L).contiguous()
                # Mask out padded key columns for every query.
                attn_mask = attn_mask.masked_fill(~pad_mask[:, None, None, :], -1e9)

        # Both halves use rotary positions 0..L-1, so noisy token i and clean
        # token i share the same rotary position.
        position_ids = torch.cat([
            torch.arange(L, device=device),
            torch.arange(L, device=device),
        ])
        rotary_cos_sin = self.rotary_emb(x, position_ids=position_ids)

        for block in self.blocks:
            x = block(
                x, rotary_cos_sin, c,
                attention_mask=attn_mask,
                flex_block_mask=flex_block_mask,
            )

        # Only the noisy half is decoded; padded vocab columns are dropped.
        x_noisy_out = x[:, :L]
        leaf_logits = self.output_layer(x_noisy_out, c)[..., :self.vocab_size]
        return leaf_logits

    def _get_block_causal_mask(self, seq_len: int, device) -> torch.Tensor:
        """
        Block-wise causal attention mask for inference.
        Token i attends to token j iff block(j) <= block(i).
        Within a block: bidirectional. Across blocks: autoregressive.
        Returns [seq_len, seq_len] bool mask (True = can attend).
        """
        key = ("causal", seq_len, self.block_size, str(device))
        if key not in self._mask_cache:
            idx = torch.arange(seq_len, device=device)
            block_i = idx[:, None] // self.block_size
            block_j = idx[None, :] // self.block_size
            self._mask_cache[key] = (block_j <= block_i)
        return self._mask_cache[key]

    def forward_causal(
        self,
        input_embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Block-wise causal forward pass for Block-AR inference.

        Identical to forward() but uses a block-wise causal attention mask so
        that each block only sees itself (bidirectionally) and earlier blocks.
        Used block-by-block during generation so that previously resolved blocks
        provide clean causal context for the current block.

        Args:
            input_embeddings: [B, S, d] (resolved leaf embs + current block state)
            attention_mask:   [B, S] padding mask (optional)

        Returns:
            leaf_logits: [B, S, V]
            h:           [B, S, d]

        Raises:
            ValueError: if S is not a valid sequence length for this model.
        """
        B, S, d = input_embeddings.shape
        device = input_embeddings.device
        dtype = self.vocab_embed.embedding.dtype

        x = self.input_proj(input_embeddings.to(dtype))
        # Segment 0: every token is treated like the noisy / regular half.
        x = x + self._build_position_embeddings(S, False, device, dtype)

        c = self.cond_bias.unsqueeze(0).expand(B, -1).to(dtype)
        rotary_cos_sin = self.rotary_emb(x)

        # Additive bias: 0 where attention is allowed, -1e9 where it is not
        # (same convention as the dense path of forward_vectorized).
        causal_mask = self._get_block_causal_mask(S, device)
        causal_bias = torch.zeros(S, S, device=device, dtype=dtype).masked_fill_(~causal_mask, -1e9)
        # [B, S, S]; NOTE(review): no head dimension here, unlike the
        # [B, 1, ., S] masks built elsewhere — presumably handled inside
        # DDiTBlockWithMask; confirm.
        attn_mask = causal_bias.unsqueeze(0).expand(B, -1, -1)

        if attention_mask is not None:
            # Push padded key columns to -1e9 for every query row.
            pad_bias = torch.zeros(B, S, S, device=device, dtype=dtype)
            pad_bias.masked_fill_(~attention_mask.bool().unsqueeze(1), -1e9)
            attn_mask = attn_mask + pad_bias

        for block in self.blocks:
            x = block(x, rotary_cos_sin, c, attention_mask=attn_mask)

        h = x
        logits_full = self.output_layer(h, c)
        leaf_logits = logits_full[..., :self.vocab_size]
        return leaf_logits, h

    def forward(
        self,
        input_embeddings: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        noisy_embs: Optional[torch.Tensor] = None,
        clean_embs: Optional[torch.Tensor] = None,
        input_ids=None, t=None, level_ids=None,
    ):
        """
        Standard forward (full sequence, all-visible attention).
        Used for validation / evaluation.

        If `clean_embs` is provided (Block-AR training), dispatches to
        `forward_vectorized(noisy_embs, clean_embs, attention_mask)` and
        returns leaf_logits only — this routing is required for DDP to
        register the forward call and synchronize gradients.

        Args:
            input_embeddings: [B, S, d] continuous token embeddings
                (leaf emb / prototype emb / mask emb, same as training)
            attention_mask:   [B, S]
            noisy_embs:       [B, L, d] noisy half for the Block-AR dispatch;
                              falls back to `input_embeddings` when omitted
            clean_embs:       [B, L, d] clean half; its presence selects the
                              Block-AR dispatch
            input_ids, t, level_ids: accepted but ignored by this method

        Returns:
            (leaf_logits, h) in standard mode
            leaf_logits      in block-AR dispatch mode

        Raises:
            ValueError: in standard mode, if `input_embeddings` is missing or
                S is not a valid sequence length for this model.
        """
        if clean_embs is not None:
            noisy = noisy_embs if noisy_embs is not None else input_embeddings
            assert noisy is not None, "forward_vectorized dispatch requires noisy_embs"
            return self.forward_vectorized(noisy, clean_embs, attention_mask=attention_mask)

        if input_embeddings is None:
            raise ValueError(
                "forward requires input_embeddings (or clean_embs for the "
                "Block-AR dispatch)"
            )

        B, S, d = input_embeddings.shape
        device = input_embeddings.device
        dtype = self.vocab_embed.embedding.dtype

        x = self.input_proj(input_embeddings.to(dtype))
        # Segment 0: every token is treated like the noisy / regular half.
        x = x + self._build_position_embeddings(S, False, device, dtype)

        c = self.cond_bias.unsqueeze(0).expand(B, -1).to(dtype)
        rotary_cos_sin = self.rotary_emb(x)

        # A 2-D non-float padding mask [B, S] becomes an additive bias
        # [B, 1, 1, S] (-1e9 on padded keys). Any other mask (floating point,
        # or not 2-D) is forwarded unchanged.
        # NOTE(review): a 2-D *float* padding mask is therefore passed through
        # as-is — confirm DDiTBlockWithMask interprets it as intended.
        attn_bias = attention_mask
        if attn_bias is not None and not attn_bias.is_floating_point() and attn_bias.dim() == 2:
            attn_bias = (~attn_bias.bool())[:, None, None, :].to(dtype) * -1e9

        for block in self.blocks:
            x = block(x, rotary_cos_sin, c, attention_mask=attn_bias)

        h = x
        logits_full = self.output_layer(h, c)
        leaf_logits = logits_full[..., :self.vocab_size]
        return leaf_logits, h
|
|
|
|
|
|