""" SAD backbone model – Block-AR variant with vectorized training. Architecture: - Sequence divided into blocks of `block_size` tokens - Within a block: bidirectional attention - Across blocks: autoregressive (each block sees only earlier blocks' clean tokens) - No timestep conditioning; single learnable cond_bias vector - Input: continuous embeddings (leaf embs / prototype embs / mask emb) Vectorized training (forward_vectorized): Concatenate x_full = [x_noisy (0..L-1) | x_clean (L..2L-1)], shape [B, 2L, d]. Apply block-diff attention mask (see build_block_diff_mask), run one forward pass, return logits for the noisy half only [B, L, V]. Block-diff mask rules (positions in x_full = [noisy | clean]): noisy[i] → noisy[j]: allowed iff same block noisy[i] → clean[j]: allowed iff block(j) < block(i) (strictly earlier clean block) clean[i] → clean[j]: allowed iff block(j) <= block(i) (same or earlier clean block) clean[i] → noisy[j]: never Flex attention is used when available for the sparse block-diff mask. """ import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from .dit_components import ( DDiTBlockWithMask, DDitFinalLayer, EmbeddingLayer, Rotary, ) # Use FlexAttention via create_block_mask + mask_mod for true block-sparse # attention. The compiled kernel skips entirely-masked kv-blocks. try: from torch.nn.attention.flex_attention import create_block_mask _has_flex_attention = True except ImportError: create_block_mask = None _has_flex_attention = False # ────────────────────────────────────────────────────────────────────────────── # Block-diff attention mask # ────────────────────────────────────────────────────────────────────────────── def build_block_diff_mask(seq_len: int, block_size: int, device) -> torch.Tensor: """ Build [2L, 2L] bool attention mask for vectorized block-diff training. x_full layout: [x_noisy (0..L-1) | x_clean (L..2L-1)] Rules: noisy->noisy : same block only (block_q == block_k) noisy->clean : strictly earlier block (block_k < block_q) clean->clean : same or earlier block (block_k <= block_q) clean->noisy : never Returns: mask [2L, 2L], True = allowed to attend """ L = seq_len N = 2 * L idx = torch.arange(N, device=device) q = idx[:, None] k = idx[None, :] is_clean_q = q >= L is_clean_k = k >= L block_q = torch.where(is_clean_q, (q - L) // block_size, q // block_size) block_k = torch.where(is_clean_k, (k - L) // block_size, k // block_size) noisy_to_noisy = (~is_clean_q) & (~is_clean_k) & (block_q == block_k) noisy_to_clean = (~is_clean_q) & is_clean_k & (block_k < block_q) clean_to_clean = is_clean_q & is_clean_k & (block_k <= block_q) return noisy_to_noisy | noisy_to_clean | clean_to_clean def _make_block_diff_mask_mod(seq_len: int, block_size: int): """ Return a mask_mod function for flex_attention's create_block_mask implementing the block-diff mask. Positions 0..L-1 are noisy; L..2L-1 are clean. Unlike score_mod (which runs per-score), mask_mod is used by create_block_mask to build a structured BlockMask that the compiled kernel uses to skip entire kv-blocks — this is where the real sparsity speedup comes from. """ L = seq_len def mask_mod(b, h, q_idx, k_idx): is_clean_q = q_idx >= L is_clean_k = k_idx >= L block_q = torch.where(is_clean_q, (q_idx - L) // block_size, q_idx // block_size) block_k = torch.where(is_clean_k, (k_idx - L) // block_size, k_idx // block_size) n2n = (~is_clean_q) & (~is_clean_k) & (block_q == block_k) n2c = (~is_clean_q) & is_clean_k & (block_k < block_q) c2c = is_clean_q & is_clean_k & (block_k <= block_q) return n2n | n2c | c2c return mask_mod # ────────────────────────────────────────────────────────────────────────────── # SADModel # ────────────────────────────────────────────────────────────────────────────── class SADModel(nn.Module): """ SAD backbone: Block-AR diffusion, no timestep conditioning. Args: vocab_size: V – leaf vocabulary size hidden_size: d – transformer hidden dimension n_blocks: number of transformer blocks n_heads: number of attention heads cond_dim: conditioning dimension for adaLN max_seq_len: total sequence length (must be divisible by block_size) block_size: tokens per block dropout: dropout rate num_levels: total hierarchy levels (leaf + intermediate, not counting mask) level_sizes: [V, K1, K2, ...] """ def __init__( self, vocab_size: int, hidden_size: int = 768, n_blocks: int = 12, n_heads: int = 12, cond_dim: int = 128, max_seq_len: int = 512, block_size: int = 8, dropout: float = 0.0, num_levels: int = 2, level_sizes: Optional[list] = None, tie_weights: bool = True, # kept for API compat, ignored: use_aux_head: bool = False, embed_input_mode: str = "embeddings", extended_vocab_size: int = 0, ): super().__init__() assert max_seq_len % block_size == 0, \ f"max_seq_len ({max_seq_len}) must be divisible by block_size ({block_size})" self.vocab_size = vocab_size self.hidden_size = hidden_size self.n_blocks = n_blocks self.n_heads = n_heads self.cond_dim = cond_dim self.max_seq_len = max_seq_len self.block_size = block_size self.num_total_blocks = max_seq_len // block_size self.num_levels = num_levels self.level_sizes = level_sizes self.aux_head = None # ---- Leaf token embedding table ---- raw_size = vocab_size self.rounded_vocab_size = raw_size + (128 - raw_size % 128) % 128 self.vocab_embed = EmbeddingLayer(hidden_size, self.rounded_vocab_size) # ---- Input projection (hidden_size -> hidden_size) ---- self.input_proj = nn.Linear(hidden_size, hidden_size, bias=False) # ---- Conditioning: single learnable bias, no t ---- self.cond_bias = nn.Parameter(torch.zeros(cond_dim)) nn.init.normal_(self.cond_bias, std=0.02) # ---- Block-index embedding (AR position across blocks) ---- # Used for both halves; block index is the same for noisy and clean halves. self.block_idx_embed = nn.Embedding(self.num_total_blocks, hidden_size) nn.init.normal_(self.block_idx_embed.weight, std=0.02) # ---- Intra-block position embedding ---- self.intra_pos_embed = nn.Embedding(block_size, hidden_size) nn.init.normal_(self.intra_pos_embed.weight, std=0.02) # ---- Noisy/clean segment embedding ---- # Tells the model whether a position is in the noisy half or clean half. self.segment_embed = nn.Embedding(2, hidden_size) # 0=noisy, 1=clean nn.init.normal_(self.segment_embed.weight, std=0.02) # ---- Rotary (intra-block positions only, repeated) ---- self.rotary_emb = Rotary( hidden_size // n_heads, max_seq_len=max_seq_len * 2, # 2L for the full vectorized sequence ) # ---- Transformer blocks ---- self.blocks = nn.ModuleList([ DDiTBlockWithMask(hidden_size, n_heads, cond_dim, dropout=dropout) for _ in range(n_blocks) ]) # ---- Output head ---- rounded_leaf = vocab_size + (128 - vocab_size % 128) % 128 self.output_layer = DDitFinalLayer(hidden_size, rounded_leaf, cond_dim) if tie_weights: # Weight tying: output projection shares the input embedding matrix, # so logits = h @ vocab_embed.embedding.T (same parameter, same gradients). self.output_layer.linear.weight = self.vocab_embed.embedding else: # Decoupled: independent output projection. # DDitFinalLayer zero-initialises its linear layer; re-init for trainability. nn.init.normal_(self.output_layer.linear.weight, std=0.02) # Cache block-diff masks (keyed by (seq_len, block_size)) self._mask_cache: dict = {} def get_leaf_embeddings(self) -> torch.Tensor: """[V, hidden_size] embedding matrix for leaf tokens.""" return self.vocab_embed.embedding[:self.vocab_size] def _get_block_diff_mask(self, seq_len: int, device) -> torch.Tensor: key = (seq_len, self.block_size, str(device)) if key not in self._mask_cache: self._mask_cache[key] = build_block_diff_mask(seq_len, self.block_size, device) return self._mask_cache[key] def _get_flex_block_mask(self, seq_len: int, device): """Cached flex BlockMask for (L, block_size, device).""" key = ("flex", seq_len, self.block_size, str(device)) if key not in self._mask_cache: mask_mod = _make_block_diff_mask_mod(seq_len, self.block_size) self._mask_cache[key] = create_block_mask( mask_mod, B=None, H=None, Q_LEN=2 * seq_len, KV_LEN=2 * seq_len, device=device, ) return self._mask_cache[key] def _build_position_embeddings( self, seq_len: int, is_clean_half: bool, device, dtype ) -> torch.Tensor: """ Build [1, seq_len, hidden_size] position + segment embeddings. seq_len = L (one half of the vectorized sequence). """ num_blocks = seq_len // self.block_size # Block indices: [0,0,...,0, 1,1,...,1, ..., nb-1,...,nb-1] block_idx = torch.arange(num_blocks, device=device).repeat_interleave(self.block_size) # Intra-block positions: [0,1,...,bs-1, 0,1,...,bs-1, ...] intra_pos = torch.arange(self.block_size, device=device).repeat(num_blocks) seg_id = torch.ones(seq_len, dtype=torch.long, device=device) * int(is_clean_half) pos_emb = ( self.block_idx_embed(block_idx) # [L, d] + self.intra_pos_embed(intra_pos) # [L, d] + self.segment_embed(seg_id) # [L, d] ).unsqueeze(0).to(dtype) # [1, L, d] return pos_emb def forward_vectorized( self, noisy_embs: torch.Tensor, clean_embs: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Single-forward vectorized block-diff training pass. Concatenates [noisy_embs | clean_embs] → x_full [B, 2L, d], applies block-diff sparse attention mask, returns logits for noisy half. Args: noisy_embs: [B, L, d] corrupted token embeddings clean_embs: [B, L, d] original clean token embeddings attention_mask: [B, L] float/bool padding mask (optional) Returns: leaf_logits: [B, L, V] """ B, L, d = noisy_embs.shape device = noisy_embs.device dtype = noisy_embs.dtype # ---- Project embeddings to hidden size ---- x_noisy = self.input_proj(noisy_embs) # [B, L, d] x_clean = self.input_proj(clean_embs) # [B, L, d] # ---- Add position + segment embeddings ---- x_noisy = x_noisy + self._build_position_embeddings(L, False, device, dtype) x_clean = x_clean + self._build_position_embeddings(L, True, device, dtype) # ---- Concatenate ---- x = torch.cat([x_noisy, x_clean], dim=1) # [B, 2L, d] # ---- Conditioning ---- c = self.cond_bias.unsqueeze(0).expand(B, -1).to(dtype) # [B, cond_dim] # ---- Attention mask ---- # Prefer flex (block-sparse, compiled) when available. Padding at attention # level is ignored — pad positions carry the eos token embedding (a valid # token) and loss already masks pad positions, so attending to them is # harmless and lets us keep a static BlockMask for compile. # Precondition: pad_token_id == eos_token_id (checked in train script). # TODO: to support multi-doc packing, extend the mask_mod in # _make_block_diff_mask_mod to AND with a same-document check # (doc_ids[q_idx] == doc_ids[k_idx]) — this requires threading doc_ids # into forward_vectorized and rebuilding the flex BlockMask per batch. if _has_flex_attention: flex_block_mask = self._get_flex_block_mask(L, device) attn_mask = None else: flex_block_mask = None attn_mask = self._get_block_diff_mask(L, device) # [2L, 2L] bool if attention_mask is not None: pad_mask = torch.cat([attention_mask, attention_mask], dim=1).bool() block_diff_float = torch.zeros_like(attn_mask, dtype=dtype).masked_fill_(~attn_mask, -1e9) attn_mask = block_diff_float[None, None, :, :].expand(B, 1, 2 * L, 2 * L).contiguous() attn_mask = attn_mask.masked_fill(~pad_mask[:, None, None, :], -1e9) # ---- Rotary ---- # Align RoPE positions: noisy half and clean half both use 0..L-1 # so that cross-attention between corresponding tokens sees the same # absolute position (zero relative offset from RoPE). position_ids = torch.cat([ torch.arange(L, device=device), torch.arange(L, device=device), ]) # [2L] rotary_cos_sin = self.rotary_emb(x, position_ids=position_ids) # ---- Transformer ---- for block in self.blocks: x = block( x, rotary_cos_sin, c, attention_mask=attn_mask, flex_block_mask=flex_block_mask, ) # ---- Output: noisy half only (slice BEFORE output projection to avoid # running the big V×d matmul on the clean half we would discard anyway). x_noisy_out = x[:, :L] # [B, L, d] leaf_logits = self.output_layer(x_noisy_out, c)[..., :self.vocab_size] # [B, L, V] return leaf_logits def _get_block_causal_mask(self, seq_len: int, device) -> torch.Tensor: """ Block-wise causal attention mask for inference. Token i attends to token j iff block(j) <= block(i). Within a block: bidirectional. Across blocks: autoregressive. Returns [seq_len, seq_len] bool mask (True = can attend). """ key = ("causal", seq_len, self.block_size, str(device)) if key not in self._mask_cache: idx = torch.arange(seq_len, device=device) block_i = idx[:, None] // self.block_size # [S, 1] block_j = idx[None, :] // self.block_size # [1, S] self._mask_cache[key] = (block_j <= block_i) # [S, S] return self._mask_cache[key] def forward_causal( self, input_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Block-wise causal forward pass for Block-AR inference. Identical to forward() but uses a block-wise causal attention mask so that each block only sees itself (bidirectionally) and earlier blocks. Used block-by-block during generation so that previously resolved blocks provide clean causal context for the current block. Args: input_embeddings: [B, S, d] (resolved leaf embs + current block state) attention_mask: [B, S] padding mask (optional) Returns: leaf_logits: [B, S, V] h: [B, S, d] """ B, S, d = input_embeddings.shape device = input_embeddings.device dtype = self.vocab_embed.embedding.dtype x = self.input_proj(input_embeddings.to(dtype)) # Position embeddings (same as forward(), noisy segment id=0) num_blocks = S // self.block_size block_idx = torch.arange(num_blocks, device=device).repeat_interleave(self.block_size) intra_pos = torch.arange(self.block_size, device=device).repeat(num_blocks) seg_id = torch.zeros(S, dtype=torch.long, device=device) pos_emb = ( self.block_idx_embed(block_idx) + self.intra_pos_embed(intra_pos) + self.segment_embed(seg_id) ).unsqueeze(0).to(dtype) x = x + pos_emb c = self.cond_bias.unsqueeze(0).expand(B, -1).to(dtype) rotary_cos_sin = self.rotary_emb(x) # Block-wise causal attention mask: [S, S] → broadcast to [B, S, S] causal_mask = self._get_block_causal_mask(S, device) # [S, S] bool dtype_mask = causal_mask.to(dtype).masked_fill(~causal_mask, -1e9) # [S, S] # Expand to [B, S, S] if needed by transformer blocks attn_mask = dtype_mask.unsqueeze(0).expand(B, -1, -1) # [B, S, S] # Merge padding mask if provided if attention_mask is not None: pad_bias = torch.zeros(B, S, S, device=device, dtype=dtype) pad_bias.masked_fill_(~attention_mask.bool().unsqueeze(1), -1e9) attn_mask = attn_mask + pad_bias for block in self.blocks: x = block(x, rotary_cos_sin, c, attention_mask=attn_mask) h = x logits_full = self.output_layer(h, c) leaf_logits = logits_full[..., :self.vocab_size] return leaf_logits, h def forward( self, input_embeddings: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, # Block-AR training path dispatch (must go through .forward() for DDP grad sync): noisy_embs: Optional[torch.Tensor] = None, clean_embs: Optional[torch.Tensor] = None, # API compat, ignored: input_ids=None, t=None, level_ids=None, ): """ Standard forward (full sequence, all-visible attention). Used for validation / evaluation. If `clean_embs` is provided (Block-AR training), dispatches to `forward_vectorized(noisy_embs, clean_embs, attention_mask)` and returns leaf_logits only — this routing is required for DDP to register the forward call and synchronize gradients. Args: input_embeddings: [B, S, d] continuous token embeddings (leaf emb / prototype emb / mask emb, same as training) attention_mask: [B, S] Returns: (leaf_logits, h) in standard mode leaf_logits in block-AR dispatch mode """ if clean_embs is not None: # Block-AR training dispatch — routes through DDP. noisy = noisy_embs if noisy_embs is not None else input_embeddings assert noisy is not None, "forward_vectorized dispatch requires noisy_embs" return self.forward_vectorized(noisy, clean_embs, attention_mask=attention_mask) B, S, d = input_embeddings.shape device = input_embeddings.device dtype = self.vocab_embed.embedding.dtype x = self.input_proj(input_embeddings.to(dtype)) # Position embeddings (treat as single noisy sequence) num_blocks = S // self.block_size block_idx = torch.arange(num_blocks, device=device).repeat_interleave(self.block_size) intra_pos = torch.arange(self.block_size, device=device).repeat(num_blocks) seg_id = torch.zeros(S, dtype=torch.long, device=device) pos_emb = ( self.block_idx_embed(block_idx) + self.intra_pos_embed(intra_pos) + self.segment_embed(seg_id) ).unsqueeze(0).to(dtype) x = x + pos_emb c = self.cond_bias.unsqueeze(0).expand(B, -1).to(dtype) rotary_cos_sin = self.rotary_emb(x) # Pre-build a [B, 1, 1, S] additive float bias for the [B, S] padding # mask. Downstream DDiT blocks' 2D branch is for [S, S] block-diff masks # and can't disambiguate a [B, S] padding mask — we'd get a [1,1,B,S] # broken mask. Passing a float bias routes through the is_floating_point # branch instead, which broadcasts correctly. attn_bias = attention_mask if attn_bias is not None and not attn_bias.is_floating_point() and attn_bias.dim() == 2: attn_bias = (~attn_bias.bool())[:, None, None, :].to(dtype) * -1e9 # [B, 1, 1, S] for block in self.blocks: x = block(x, rotary_cos_sin, c, attention_mask=attn_bias) h = x # [B, S, d] — hidden state before output layer logits_full = self.output_layer(h, c) leaf_logits = logits_full[..., :self.vocab_size] return leaf_logits, h