# sad/src/models/sad_model.py
# Hosting-page residue kept as comments so the module parses:
#   uploader: haochengsama; commit 922bb4b ("Add files using upload-large-folder tool"); 21.5 kB
"""
SAD backbone model – Block-AR variant with vectorized training.
Architecture:
- Sequence divided into blocks of `block_size` tokens
- Within a block: bidirectional attention
- Across blocks: autoregressive (each block sees only earlier blocks' clean tokens)
- No timestep conditioning; single learnable cond_bias vector
- Input: continuous embeddings (leaf embs / prototype embs / mask emb)
Vectorized training (forward_vectorized):
Concatenate x_full = [x_noisy (0..L-1) | x_clean (L..2L-1)], shape [B, 2L, d].
Apply block-diff attention mask (see build_block_diff_mask), run one forward pass,
return logits for the noisy half only [B, L, V].
Block-diff mask rules (positions in x_full = [noisy | clean]):
noisy[i] → noisy[j]: allowed iff same block
noisy[i] → clean[j]: allowed iff block(j) < block(i) (strictly earlier clean block)
clean[i] → clean[j]: allowed iff block(j) <= block(i) (same or earlier clean block)
clean[i] → noisy[j]: never
Flex attention is used when available for the sparse block-diff mask.
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .dit_components import (
DDiTBlockWithMask,
DDitFinalLayer,
EmbeddingLayer,
Rotary,
)
# Use FlexAttention via create_block_mask + mask_mod for true block-sparse
# attention. The compiled kernel skips entirely-masked kv-blocks.
try:
from torch.nn.attention.flex_attention import create_block_mask
_has_flex_attention = True
except ImportError:
create_block_mask = None
_has_flex_attention = False
# ──────────────────────────────────────────────────────────────────────────────
# Block-diff attention mask
# ──────────────────────────────────────────────────────────────────────────────
def build_block_diff_mask(seq_len: int, block_size: int, device) -> torch.Tensor:
    """
    Dense boolean attention mask for the vectorized block-diff training pass.

    The full sequence is laid out as [x_noisy (0..L-1) | x_clean (L..2L-1)].
    Every position is assigned the block index of its offset *within its own
    half*. The allowed query -> key pairs are:
        noisy -> noisy : only inside the same block
        noisy -> clean : only clean blocks strictly before the query's block
        clean -> clean : clean blocks up to and including the query's block
        clean -> noisy : never

    Args:
        seq_len: L, length of one half.
        block_size: tokens per block.
        device: device on which the mask is allocated.

    Returns:
        Bool tensor of shape [2L, 2L]; True means the query may attend to the key.
    """
    positions = torch.arange(2 * seq_len, device=device)
    in_clean = positions >= seq_len
    # Shift clean positions back by L so both halves share block numbering.
    half_offset = positions - in_clean.long() * seq_len
    block_of = half_offset // block_size

    q_clean, k_clean = in_clean[:, None], in_clean[None, :]
    q_block, k_block = block_of[:, None], block_of[None, :]

    same_noisy_block = ~q_clean & ~k_clean & (q_block == k_block)
    earlier_clean_context = ~q_clean & k_clean & (k_block < q_block)
    causal_clean = q_clean & k_clean & (k_block <= q_block)
    return same_noisy_block | earlier_clean_context | causal_clean
def _make_block_diff_mask_mod(seq_len: int, block_size: int):
"""
Return a mask_mod function for flex_attention's create_block_mask implementing
the block-diff mask. Positions 0..L-1 are noisy; L..2L-1 are clean.
Unlike score_mod (which runs per-score), mask_mod is used by create_block_mask
to build a structured BlockMask that the compiled kernel uses to skip
entire kv-blocks — this is where the real sparsity speedup comes from.
"""
L = seq_len
def mask_mod(b, h, q_idx, k_idx):
is_clean_q = q_idx >= L
is_clean_k = k_idx >= L
block_q = torch.where(is_clean_q, (q_idx - L) // block_size, q_idx // block_size)
block_k = torch.where(is_clean_k, (k_idx - L) // block_size, k_idx // block_size)
n2n = (~is_clean_q) & (~is_clean_k) & (block_q == block_k)
n2c = (~is_clean_q) & is_clean_k & (block_k < block_q)
c2c = is_clean_q & is_clean_k & (block_k <= block_q)
return n2n | n2c | c2c
return mask_mod
# ──────────────────────────────────────────────────────────────────────────────
# SADModel
# ──────────────────────────────────────────────────────────────────────────────
class SADModel(nn.Module):
    """
    SAD backbone: Block-AR diffusion, no timestep conditioning.

    Args:
        vocab_size: V – leaf vocabulary size
        hidden_size: d – transformer hidden dimension
        n_blocks: number of transformer blocks
        n_heads: number of attention heads
        cond_dim: conditioning dimension for adaLN
        max_seq_len: total sequence length (must be divisible by block_size)
        block_size: tokens per block
        dropout: dropout rate
        num_levels: total hierarchy levels (leaf + intermediate, not counting mask)
        level_sizes: [V, K1, K2, ...]
        tie_weights: if True, the output projection shares the leaf embedding
            matrix; otherwise it is an independent, normally-initialised matrix.
        use_aux_head, embed_input_mode, extended_vocab_size: accepted for API
            compatibility and ignored.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 768,
        n_blocks: int = 12,
        n_heads: int = 12,
        cond_dim: int = 128,
        max_seq_len: int = 512,
        block_size: int = 8,
        dropout: float = 0.0,
        num_levels: int = 2,
        level_sizes: Optional[list] = None,
        tie_weights: bool = True,
        # kept for API compat, ignored:
        use_aux_head: bool = False,
        embed_input_mode: str = "embeddings",
        extended_vocab_size: int = 0,
    ):
        super().__init__()
        assert max_seq_len % block_size == 0, \
            f"max_seq_len ({max_seq_len}) must be divisible by block_size ({block_size})"
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.cond_dim = cond_dim
        self.max_seq_len = max_seq_len
        self.block_size = block_size
        self.num_total_blocks = max_seq_len // block_size
        self.num_levels = num_levels
        self.level_sizes = level_sizes
        self.aux_head = None
        # ---- Leaf token embedding table ----
        # Row count is rounded up to a multiple of 128 (presumably for kernel
        # efficiency); logits for the padding rows are sliced off in every forward.
        self.rounded_vocab_size = vocab_size + (128 - vocab_size % 128) % 128
        self.vocab_embed = EmbeddingLayer(hidden_size, self.rounded_vocab_size)
        # ---- Input projection (hidden_size -> hidden_size) ----
        self.input_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        # ---- Conditioning: single learnable bias, no t ----
        self.cond_bias = nn.Parameter(torch.zeros(cond_dim))
        nn.init.normal_(self.cond_bias, std=0.02)
        # ---- Block-index embedding (AR position across blocks) ----
        # Used for both halves; block index is the same for noisy and clean halves.
        self.block_idx_embed = nn.Embedding(self.num_total_blocks, hidden_size)
        nn.init.normal_(self.block_idx_embed.weight, std=0.02)
        # ---- Intra-block position embedding ----
        self.intra_pos_embed = nn.Embedding(block_size, hidden_size)
        nn.init.normal_(self.intra_pos_embed.weight, std=0.02)
        # ---- Noisy/clean segment embedding ----
        # Tells the model whether a position is in the noisy half or clean half.
        self.segment_embed = nn.Embedding(2, hidden_size)  # 0=noisy, 1=clean
        nn.init.normal_(self.segment_embed.weight, std=0.02)
        # ---- Rotary ----
        self.rotary_emb = Rotary(
            hidden_size // n_heads,
            max_seq_len=max_seq_len * 2,  # 2L for the full vectorized sequence
        )
        # ---- Transformer blocks ----
        self.blocks = nn.ModuleList([
            DDiTBlockWithMask(hidden_size, n_heads, cond_dim, dropout=dropout)
            for _ in range(n_blocks)
        ])
        # ---- Output head ----
        self.output_layer = DDitFinalLayer(hidden_size, self.rounded_vocab_size, cond_dim)
        if tie_weights:
            # Weight tying: output projection shares the input embedding matrix,
            # so logits = h @ vocab_embed.embedding.T (same parameter, same gradients).
            self.output_layer.linear.weight = self.vocab_embed.embedding
        else:
            # Decoupled: independent output projection.
            # DDitFinalLayer zero-initialises its linear layer; re-init for trainability.
            nn.init.normal_(self.output_layer.linear.weight, std=0.02)
        # Cache of attention masks keyed by (kind, seq_len, block_size, device).
        self._mask_cache: dict = {}

    @staticmethod
    def _mask_fill_value(dtype: torch.dtype) -> float:
        """
        Additive bias used for disallowed attention positions.

        This is -1e9, clamped to the most negative finite value of ``dtype``.
        Without the clamp, ``masked_fill`` with -1e9 raises an overflow error for
        float16. For float32/bfloat16 the value stays exactly -1e9.
        """
        return max(-1e9, torch.finfo(dtype).min)

    def get_leaf_embeddings(self) -> torch.Tensor:
        """[V, hidden_size] embedding matrix for leaf tokens."""
        return self.vocab_embed.embedding[:self.vocab_size]

    def _get_block_diff_mask(self, seq_len: int, device) -> torch.Tensor:
        """Cached dense [2L, 2L] bool block-diff mask (non-flex fallback)."""
        key = (seq_len, self.block_size, str(device))
        if key not in self._mask_cache:
            self._mask_cache[key] = build_block_diff_mask(seq_len, self.block_size, device)
        return self._mask_cache[key]

    def _get_flex_block_mask(self, seq_len: int, device):
        """Cached flex BlockMask for (L, block_size, device)."""
        key = ("flex", seq_len, self.block_size, str(device))
        if key not in self._mask_cache:
            mask_mod = _make_block_diff_mask_mod(seq_len, self.block_size)
            self._mask_cache[key] = create_block_mask(
                mask_mod, B=None, H=None,
                Q_LEN=2 * seq_len, KV_LEN=2 * seq_len,
                device=device,
            )
        return self._mask_cache[key]

    def _build_position_embeddings(
        self, seq_len: int, is_clean_half: bool, device, dtype
    ) -> torch.Tensor:
        """
        Build [1, seq_len, hidden_size] position + segment embeddings.

        Block index = pos // block_size and intra-block position =
        pos % block_size. This also works when seq_len is not a multiple of
        block_size (trailing partial block).

        Raises:
            ValueError: if seq_len exceeds max_seq_len (no block embedding exists).
        """
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )
        pos = torch.arange(seq_len, device=device)
        block_idx = pos // self.block_size  # [0,..,0, 1,..,1, ...]
        intra_pos = pos % self.block_size   # [0,1,..,bs-1, 0,1,..]
        seg_id = torch.full((seq_len,), int(is_clean_half), dtype=torch.long, device=device)
        pos_emb = (
            self.block_idx_embed(block_idx)      # [L, d]
            + self.intra_pos_embed(intra_pos)    # [L, d]
            + self.segment_embed(seg_id)         # [L, d]
        )
        return pos_emb.unsqueeze(0).to(dtype)    # [1, L, d]

    def forward_vectorized(
        self,
        noisy_embs: torch.Tensor,
        clean_embs: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Single-forward vectorized block-diff training pass.

        Concatenates [noisy_embs | clean_embs] → x_full [B, 2L, d],
        applies the block-diff sparse attention mask, and returns logits for
        the noisy half.

        Args:
            noisy_embs: [B, L, d] corrupted token embeddings
            clean_embs: [B, L, d] original clean token embeddings
            attention_mask: [B, L] float/bool padding mask (optional). Only
                applied on the non-flex path; see the comment below.
        Returns:
            leaf_logits: [B, L, V]
        """
        B, L, _ = noisy_embs.shape
        device = noisy_embs.device
        dtype = noisy_embs.dtype
        # ---- Project embeddings to hidden size ----
        x_noisy = self.input_proj(noisy_embs)  # [B, L, d]
        x_clean = self.input_proj(clean_embs)  # [B, L, d]
        # ---- Add position + segment embeddings ----
        x_noisy = x_noisy + self._build_position_embeddings(L, False, device, dtype)
        x_clean = x_clean + self._build_position_embeddings(L, True, device, dtype)
        # ---- Concatenate ----
        x = torch.cat([x_noisy, x_clean], dim=1)  # [B, 2L, d]
        # ---- Conditioning ----
        c = self.cond_bias.unsqueeze(0).expand(B, -1).to(dtype)  # [B, cond_dim]
        # ---- Attention mask ----
        # Prefer flex (block-sparse, compiled) when available. Padding at attention
        # level is ignored — pad positions carry the eos token embedding (a valid
        # token) and loss already masks pad positions, so attending to them is
        # harmless and lets us keep a static BlockMask for compile.
        # Precondition: pad_token_id == eos_token_id (checked in train script).
        # TODO: to support multi-doc packing, extend the mask_mod in
        # _make_block_diff_mask_mod to AND with a same-document check
        # (doc_ids[q_idx] == doc_ids[k_idx]) — this requires threading doc_ids
        # into forward_vectorized and rebuilding the flex BlockMask per batch.
        if _has_flex_attention:
            flex_block_mask = self._get_flex_block_mask(L, device)
            attn_mask = None
        else:
            flex_block_mask = None
            attn_mask = self._get_block_diff_mask(L, device)  # [2L, 2L] bool
            if attention_mask is not None:
                neg = self._mask_fill_value(dtype)
                pad_mask = torch.cat([attention_mask, attention_mask], dim=1).bool()  # [B, 2L]
                block_diff_bias = torch.zeros_like(attn_mask, dtype=dtype).masked_fill_(~attn_mask, neg)
                attn_mask = block_diff_bias[None, None, :, :].expand(B, 1, 2 * L, 2 * L).contiguous()
                # Additionally block padded keys (columns) for every query.
                attn_mask = attn_mask.masked_fill(~pad_mask[:, None, None, :], neg)
        # ---- Rotary ----
        # Align RoPE positions: noisy half and clean half both use 0..L-1
        # so that cross-attention between corresponding tokens sees the same
        # absolute position (zero relative offset from RoPE).
        half_positions = torch.arange(L, device=device)
        position_ids = torch.cat([half_positions, half_positions])  # [2L]
        rotary_cos_sin = self.rotary_emb(x, position_ids=position_ids)
        # ---- Transformer ----
        for block in self.blocks:
            x = block(
                x, rotary_cos_sin, c,
                attention_mask=attn_mask,
                flex_block_mask=flex_block_mask,
            )
        # ---- Output: noisy half only (slice BEFORE output projection to avoid
        # running the big V×d matmul on the clean half we would discard anyway).
        x_noisy_out = x[:, :L]  # [B, L, d]
        return self.output_layer(x_noisy_out, c)[..., :self.vocab_size]  # [B, L, V]

    def _get_block_causal_mask(self, seq_len: int, device) -> torch.Tensor:
        """
        Block-wise causal attention mask for inference.

        Token i attends to token j iff block(j) <= block(i).
        Within a block: bidirectional. Across blocks: autoregressive.
        Returns [seq_len, seq_len] bool mask (True = can attend).
        """
        key = ("causal", seq_len, self.block_size, str(device))
        if key not in self._mask_cache:
            idx = torch.arange(seq_len, device=device)
            block_i = idx[:, None] // self.block_size  # [S, 1]
            block_j = idx[None, :] // self.block_size  # [1, S]
            self._mask_cache[key] = (block_j <= block_i)  # [S, S]
        return self._mask_cache[key]

    def forward_causal(
        self,
        input_embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Block-wise causal forward pass for Block-AR inference.

        Identical to forward(), except that it uses a block-wise causal
        attention mask, so each block only sees itself (bidirectionally) and
        earlier blocks. It is used block-by-block during generation so that
        previously resolved blocks provide clean causal context for the
        current block.

        Args:
            input_embeddings: [B, S, d] (resolved leaf embs + current block state)
            attention_mask: [B, S] padding mask (optional; nonzero/True = keep)
        Returns:
            leaf_logits: [B, S, V]
            h: [B, S, d]
        """
        B, S, _ = input_embeddings.shape
        device = input_embeddings.device
        dtype = self.vocab_embed.embedding.dtype
        x = self.input_proj(input_embeddings.to(dtype))
        # Position embeddings (noisy segment id = 0, as in forward()).
        x = x + self._build_position_embeddings(S, False, device, dtype)
        c = self.cond_bias.unsqueeze(0).expand(B, -1).to(dtype)
        rotary_cos_sin = self.rotary_emb(x)
        # Additive float bias of shape [B, 1, S, S]. It includes the head axis
        # like the masks built in forward_vectorized/forward; a bare [B, S, S]
        # would broadcast as [1, B, S, S]. Allowed entries are 0, blocked
        # entries carry the (dtype-safe) large negative fill value.
        neg = self._mask_fill_value(dtype)
        causal_mask = self._get_block_causal_mask(S, device)  # [S, S] bool
        causal_bias = torch.zeros(S, S, device=device, dtype=dtype).masked_fill_(~causal_mask, neg)
        attn_mask = causal_bias[None, None, :, :].expand(B, 1, S, S)
        # Merge padding mask if provided: padded keys are blocked for all queries.
        if attention_mask is not None:
            key_keep = attention_mask.bool()[:, None, None, :]  # [B, 1, 1, S]
            attn_mask = attn_mask.masked_fill(~key_keep, neg)
        for block in self.blocks:
            x = block(x, rotary_cos_sin, c, attention_mask=attn_mask)
        h = x
        leaf_logits = self.output_layer(h, c)[..., :self.vocab_size]
        return leaf_logits, h

    def forward(
        self,
        input_embeddings: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        # Block-AR training path dispatch (must go through .forward() for DDP grad sync):
        noisy_embs: Optional[torch.Tensor] = None,
        clean_embs: Optional[torch.Tensor] = None,
        # API compat, ignored:
        input_ids=None, t=None, level_ids=None,
    ):
        """
        Standard forward (full sequence, all-visible attention apart from padding).
        Used for validation / evaluation.

        If `clean_embs` is provided (Block-AR training), this dispatches to
        `forward_vectorized(noisy_embs, clean_embs, attention_mask)` and
        returns leaf_logits only. The routing is required for DDP to register
        the forward call and synchronize gradients.

        Args:
            input_embeddings: [B, S, d] continuous token embeddings
                (leaf emb / prototype emb / mask emb, same as training)
            attention_mask: [B, S] padding mask (bool/int/float; nonzero = keep),
                or an already-built additive float bias of higher rank.
        Returns:
            (leaf_logits, h) in standard mode
            leaf_logits in block-AR dispatch mode
        Raises:
            ValueError: if the required embeddings for the chosen mode are missing.
        """
        if clean_embs is not None:
            # Block-AR training dispatch — routes through DDP.
            noisy = noisy_embs if noisy_embs is not None else input_embeddings
            if noisy is None:
                raise ValueError("forward_vectorized dispatch requires noisy_embs")
            return self.forward_vectorized(noisy, clean_embs, attention_mask=attention_mask)
        if input_embeddings is None:
            raise ValueError(
                "forward requires input_embeddings (or noisy_embs + clean_embs)"
            )
        B, S, _ = input_embeddings.shape
        device = input_embeddings.device
        dtype = self.vocab_embed.embedding.dtype
        x = self.input_proj(input_embeddings.to(dtype))
        # Position embeddings (treat as single noisy sequence).
        x = x + self._build_position_embeddings(S, False, device, dtype)
        c = self.cond_bias.unsqueeze(0).expand(B, -1).to(dtype)
        rotary_cos_sin = self.rotary_emb(x)
        # Turn any [B, S] padding mask into a [B, 1, 1, S] additive float bias.
        # The downstream DDiT blocks' 2D branch is meant for [S, S] block-diff
        # masks, and an unconverted [B, S] tensor (bool *or* float) would be
        # treated as [1, 1, B, S]. The float bias takes the is_floating_point
        # branch, which broadcasts correctly. Higher-rank float biases pass through.
        attn_bias = attention_mask
        if attn_bias is not None and attn_bias.dim() == 2:
            key_keep = attn_bias.bool()[:, None, None, :]  # [B, 1, 1, S]
            attn_bias = torch.zeros(
                key_keep.shape, device=attn_bias.device, dtype=dtype
            ).masked_fill_(~key_keep, self._mask_fill_value(dtype))
        for block in self.blocks:
            x = block(x, rotary_cos_sin, c, attention_mask=attn_bias)
        h = x  # [B, S, d] — hidden state before output layer
        leaf_logits = self.output_layer(h, c)[..., :self.vocab_size]
        return leaf_logits, h