# Source: modell-name/src/models/tidar.py
# Uploaded by RabidUmarell — commit 8006486 ("Add model checkpoint and source", verified)
"""TiDAR: Think in Diffusion, Talk in Autoregression.
Reference: Liu et al., arXiv:2511.08923
Training sequence structure (block_size = b, prefix length = T, total length = 2T):
    [ x_0, x_1, ..., x_{T-1} | M, M, ..., M ]
      ← clean prefix (AR) →   ← mask section →
Structured attention mask
─────────────────────────
• Clean prefix [0 : T]: causal (standard lower-triangular); never attends to the
  mask section
• Mask section [T : 2T]: attends to the clean prefix (the exact visibility pattern
  is defined in _build_tidar_mask)
  + bidirectional within each block of block_size tokens
  + causal between blocks
Loss
────
• AR loss (L_AR): computed externally by train.py on model output[:, :T, :]
• Diffusion loss (L_Diff): the model predicts the original token at each mask
  position; stored in self.aux_loss during training.
• Combined: L = (α·L_AR + L_Diff) / (1 + α)   [paper equation; α = 1 by default]
  — train.py adds aux_loss directly to the primary criterion output; this module
  itself only applies the 1 / (1 + α) factor to L_Diff.
Interface
─────────
forward(x: Tensor[B, T, d_input]) → Tensor[B, T, d_output]   (AR logits only; B = batch size)
self.aux_loss: scalar Tensor (diffusion CE during training; zero in eval mode)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Attention mask
# ---------------------------------------------------------------------------
def _build_tidar_mask(T: int, block_size: int, device: torch.device) -> torch.Tensor:
"""Return a (2T, 2T) additive float mask: 0.0 where allowed, -inf where blocked."""
S = 2 * T
mask = torch.full((S, S), float("-inf"), device=device)
idx = torch.arange(T, device=device) # (T,)
# ── top-left [0:T, 0:T]: causal self-attention for clean prefix ──────────
# mask[i, j] = 0 iff j ≀ i
causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device))
mask[:T, :T] = causal.float().masked_fill(~causal, float("-inf")).masked_fill(causal, 0.0)
# ── top-right [0:T, T:2T]: prefix never attends to mask tokens ───────────
# (already -inf from initialization)
# ── bottom-left [T:2T, 0:T]: mask tokens see the entire clean prefix ─────
mask[T:, :T] = 0.0
# ── bottom-right [T:2T, T:2T]: block-causal + intra-block bidirectional ──
# mask[T+i, T+j] = 0 iff block(j) ≀ block(i)
bi = idx // block_size # (T,)
allowed = bi.unsqueeze(1) >= bi.unsqueeze(0) # allowed[i, j] = (block_i >= block_j)
mask[T:, T:] = allowed.float().masked_fill(~allowed, float("-inf")).masked_fill(allowed, 0.0)
return mask
# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------
class _MaskedSelfAttention(nn.Module):
"""Multi-head self-attention with an explicit additive attention mask."""
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.proj = nn.Linear(d_model, d_model, bias=False)
self.attn_drop_p = dropout
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
B, T, C = x.shape
q, k, v = self.qkv(x).split(C, dim=-1)
def split_heads(t: torch.Tensor) -> torch.Tensor:
return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
out = F.scaled_dot_product_attention(
split_heads(q), split_heads(k), split_heads(v),
attn_mask=attn_mask,
dropout_p=self.attn_drop_p if self.training else 0.0,
)
return self.proj(out.transpose(1, 2).contiguous().view(B, T, C))
class _TiDARBlock(nn.Module):
    """Pre-LayerNorm Transformer layer with an explicit additive attention mask.

    The layer runs masked self-attention, then a GELU MLP that is 4x wider
    than the model. Each sub-layer is normalised first and wrapped in a
    residual connection.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        hidden = 4 * d_model
        # Attribute names and construction order define the state_dict keys
        # and the parameter-init RNG sequence, so both are kept unchanged.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = _MaskedSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        attended = x + self.attn(self.ln1(x), attn_mask)
        return attended + self.mlp(self.ln2(attended))
# ---------------------------------------------------------------------------
# TiDAR model
# ---------------------------------------------------------------------------
class TiDARModel(nn.Module):
    """TiDAR: Think in Diffusion, Talk in Autoregression (arXiv:2511.08923).

    In training mode the sequence is doubled internally. The second half is a
    run of learned [MASK] embeddings, processed together with the clean prefix
    under the TiDAR structured attention mask. The diffusion cross-entropy on
    that half is stored in ``self.aux_loss``. Only the AR logits for the clean
    prefix are returned.

    In eval mode the mask half is not built. Clean-prefix rows of the
    structured mask block every mask column, and all other ops act per
    position. The returned AR logits are therefore the same (up to
    floating-point kernel differences), without processing the T extra mask
    positions.

    Args:
        d_input: width of each input position. Diffusion targets are taken as
            ``x.argmax(-1)``, so ``x`` is presumably a one-hot token encoding
            with indices < ``d_output``. TODO confirm against the data
            pipeline.
        d_model: hidden width of the Transformer.
        d_output: number of logits produced per position.
        n_layers: number of Transformer blocks.
        n_heads: attention heads per block (must divide ``d_model``).
        block_size: diffusion block length used by the structured mask.
        alpha: AR/diffusion balance. The diffusion loss is divided by
            ``1 + alpha``.
        max_len: longest accepted prefix length T. The position table has
            ``2 * max_len`` rows (prefix plus mask section).
        dropout: dropout probability used in attention and MLP.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        d_input: int,
        d_model: int,
        d_output: int,
        n_layers: int = 2,
        n_heads: int = 4,
        block_size: int = 8,
        alpha: float = 1.0,
        max_len: int = 4096,
        dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.block_size = block_size
        self.alpha = alpha
        self.input_proj = nn.Linear(d_input, d_model)
        # Learned [MASK] embedding shared by all mask positions. No separate
        # bias: positional embeddings are added on top.
        self.mask_emb = nn.Parameter(torch.empty(1, 1, d_model))
        nn.init.normal_(self.mask_emb, std=0.02)
        # Positions [0, T) label the prefix and [T, 2T) the mask section.
        self.pos_emb = nn.Embedding(2 * max_len, d_model)
        self.blocks = nn.ModuleList([
            _TiDARBlock(d_model, n_heads, dropout) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, d_output)
        # Overwritten on every forward call: the diffusion loss in training,
        # a zero scalar in eval.
        self.aux_loss: torch.Tensor = torch.tensor(0.0)
        # Attention masks cached by (T, device). NOTE(review): the cache is
        # unbounded, holding one (2T, 2T) float tensor per distinct key.
        self._mask_cache: dict[tuple, torch.Tensor] = {}

    # ------------------------------------------------------------------
    def _get_attn_mask(self, T: int, device: torch.device) -> torch.Tensor:
        """Return the cached (2T, 2T) structured mask for ``T`` on ``device``."""
        key = (T, device.type, getattr(device, "index", 0))
        if key not in self._mask_cache:
            self._mask_cache[key] = _build_tidar_mask(T, self.block_size, device)
        return self._mask_cache[key]

    # ------------------------------------------------------------------
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, T, d_input) to (B, T, d_output) AR logits for the clean prefix.

        Sets ``self.aux_loss`` to the scaled diffusion cross-entropy in
        training mode, and to a zero scalar in eval mode.

        Raises:
            ValueError: if T exceeds ``max_len``.
        """
        B, T, _ = x.shape
        max_len = self.pos_emb.num_embeddings // 2
        if T > max_len:
            raise ValueError(f"Sequence length {T} exceeds max_len {max_len}")
        device = x.device

        # ── clean prefix embeddings: positions [0, T) ───────────────────────
        h = self.input_proj(x) + self.pos_emb(torch.arange(T, device=device))
        attn_mask = self._get_attn_mask(T, device)  # (2T, 2T)

        if self.training:
            # ── mask tokens: shared embedding + positions [T, 2T) ───────────
            h_mask = self.mask_emb.expand(B, T, -1) + self.pos_emb(
                torch.arange(T, 2 * T, device=device)
            )
            h = torch.cat([h, h_mask], dim=1)  # (B, 2T, d_model)
        else:
            # Prefix rows of the structured mask are -inf on every mask column,
            # so the mask half cannot affect prefix outputs. Run the prefix
            # alone under the top-left (causal) quadrant.
            attn_mask = attn_mask[:T, :T]

        for block in self.blocks:
            h = block(h, attn_mask)
        logits = self.head(self.ln_f(h))

        if not self.training:
            self.aux_loss = x.new_zeros(())
            return logits  # (B, T, d_output)

        # ── diffusion auxiliary loss (training only) ────────────────────────
        diff_logits = logits[:, T:, :]  # (B, T, d_output)
        # Token indices recovered from the input. This assumes x is one-hot;
        # see the class docstring.
        diff_targets = x.argmax(dim=-1)  # (B, T)
        diff_loss = F.cross_entropy(
            diff_logits.reshape(-1, diff_logits.size(-1)),
            diff_targets.reshape(-1),
        )
        # This module contributes L_Diff / (1 + α). The paper's combined loss
        # is (α·L_AR + L_Diff) / (1 + α). NOTE(review): matching it requires
        # train.py to weight L_AR by α / (1 + α). Confirm against train.py,
        # since the module docstring says aux_loss is added to the criterion
        # output as-is.
        self.aux_loss = diff_loss / (1.0 + self.alpha)
        return logits[:, :T, :]  # AR logits only

    # ------------------------------------------------------------------
    @staticmethod
    def extra_kwargs(model_cfg) -> dict:
        """Pick TiDAR-specific constructor kwargs from a model config object.

        ``n_heads`` is required; ``block_size`` and ``alpha`` fall back to
        8 and 1.0.
        """
        return {
            "n_heads": model_cfg.n_heads,
            "block_size": getattr(model_cfg, "block_size", 8),
            "alpha": getattr(model_cfg, "alpha", 1.0),
        }