"""TiDAR: Think in Diffusion, Talk in Autoregression.

Reference: Liu et al., arXiv:2511.08923

Training sequence structure (block_size=B, prefix length=T, total=2T)::

    [ x_0, x_1, ..., x_{T-1} | M, M, ..., M ]
      <-- clean prefix (AR) --> <-- mask section -->

Structured attention mask
─────────────────────────
• Clean prefix [0 : T]: causal (standard lower-triangular); the prefix never
  attends to the mask section.
• Mask section [T : 2T]: a mask token in block b attends to the clean tokens
  of blocks strictly before b, bidirectionally to the other mask tokens of
  block b, and block-causally to the mask tokens of earlier blocks.
  The clean tokens of block b itself stay hidden: the mask token at offset i
  is trained to predict x_i and could otherwise simply copy it from the prefix.

Loss
────
• AR loss (L_AR): computed externally by train.py on model output[:, :T, :].
• Diffusion loss (L_Diff): the mask token at offset i predicts the original
  token x_i; stored in self.aux_loss during training.
• Combined: L = (α·L_AR + L_Diff) / (1 + α)   [paper eq., α=1 default]
  — train.py adds aux_loss directly to the primary criterion output.

Interface
─────────
forward(x: Tensor[B, T, d_input]) → Tensor[B, T, d_output]   (AR logits only)
self.aux_loss: scalar Tensor (diffusion CE, populated during training)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# Attention mask
# ---------------------------------------------------------------------------

def _build_tidar_mask(T: int, block_size: int, device: torch.device) -> torch.Tensor:
    """Return the (2T, 2T) additive TiDAR attention mask.

    Entries are 0.0 where attention is allowed and -inf where it is blocked
    (float32). Rows/columns [0:T] index the clean prefix and [T:2T] the mask
    section; offset i in either half belongs to block ``i // block_size``.

    Args:
        T: prefix length (the mask section has the same length).
        block_size: number of tokens per diffusion block; must be >= 1.
        device: device on which the mask is allocated.

    Raises:
        ValueError: if ``block_size < 1``.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")

    S = 2 * T
    idx = torch.arange(T, device=device)  # (T,)
    blk = idx // block_size               # (T,) block id of each offset

    allowed = torch.zeros(S, S, dtype=torch.bool, device=device)
    # ── top-left: clean prefix is causal (j <= i) ─────────────────────────
    allowed[:T, :T] = idx.unsqueeze(1) >= idx.unsqueeze(0)
    # ── top-right: prefix never attends to mask tokens (stays False) ──────
    # ── bottom-left: mask tokens see only clean tokens of EARLIER blocks ──
    # Allowing block(j) == block(i) (in particular j == i) would let the mask
    # token at offset i read its own diffusion target x_i from the prefix.
    allowed[T:, :T] = blk.unsqueeze(1) > blk.unsqueeze(0)
    # ── bottom-right: block-causal + bidirectional within a block ────────
    # The diagonal is always allowed here, so no row is fully blocked (a
    # fully -inf row would make softmax produce NaN).
    allowed[T:, T:] = blk.unsqueeze(1) >= blk.unsqueeze(0)

    return torch.zeros(S, S, device=device).masked_fill(~allowed, float("-inf"))


# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------

class _MaskedSelfAttention(nn.Module):
    """Multi-head self-attention with an explicit additive attention mask."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop_p = dropout

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        """(B, T, C) → (B, T, C); ``attn_mask`` is broadcast by SDPA."""
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) → (B, n_heads, T, d_head)
            return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        out = F.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v),
            attn_mask=attn_mask,
            # attention dropout only while training
            dropout_p=self.attn_drop_p if self.training else 0.0,
        )
        return self.proj(out.transpose(1, 2).contiguous().view(B, T, C))


class _TiDARBlock(nn.Module):
    """Pre-LN Transformer block accepting an explicit additive attention mask."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = _MaskedSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        """Residual attention sub-layer followed by residual MLP sub-layer."""
        x = x + self.attn(self.ln1(x), attn_mask)
        x = x + self.mlp(self.ln2(x))
        return x


# ---------------------------------------------------------------------------
# TiDAR model
# ---------------------------------------------------------------------------

class TiDARModel(nn.Module):
    """TiDAR: Think in Diffusion, Talk in Autoregression (arXiv:2511.08923).

    Doubles the sequence internally — the second half is a block of learned
    [MASK] embeddings processed with the TiDAR structured attention mask.
    Returns AR logits for the clean prefix only; diffusion auxiliary loss is
    stored in self.aux_loss during training.

    Raises:
        ValueError: if ``block_size < 1``.
    """

    def __init__(
        self,
        d_input: int,
        d_model: int,
        d_output: int,
        n_layers: int = 2,
        n_heads: int = 4,
        block_size: int = 8,
        alpha: float = 1.0,
        max_len: int = 4096,
        dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        self.block_size = block_size
        self.alpha = alpha

        self.input_proj = nn.Linear(d_input, d_model)

        # Learned [MASK] embedding shared across all mask positions; positions
        # are distinguished only by pos_emb.
        self.mask_emb = nn.Parameter(torch.empty(1, 1, d_model))
        nn.init.normal_(self.mask_emb, std=0.02)

        # One table covers both halves: offsets [0, T) and [T, 2T).
        self.pos_emb = nn.Embedding(2 * max_len, d_model)

        self.blocks = nn.ModuleList([
            _TiDARBlock(d_model, n_heads, dropout) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, d_output)

        # Populated each forward call during training; zero otherwise
        self.aux_loss: torch.Tensor = torch.tensor(0.0)

        # Cache attention masks by (T, device) to avoid recomputing
        self._mask_cache: dict[tuple, torch.Tensor] = {}

    # ------------------------------------------------------------------

    def _get_attn_mask(self, T: int, device: torch.device) -> torch.Tensor:
        """Return the cached (2T, 2T) TiDAR mask for this length and device."""
        key = (T, device.type, getattr(device, "index", 0))
        if key not in self._mask_cache:
            self._mask_cache[key] = _build_tidar_mask(T, self.block_size, device)
        return self._mask_cache[key]

    # ------------------------------------------------------------------

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, T, d_input) → (B, T, d_output)   [AR logits for clean prefix]

        self.aux_loss is set to the diffusion cross-entropy during training.

        Raises:
            ValueError: if T exceeds the ``max_len`` given at construction.
        """
        B, T, _ = x.shape
        max_len = self.pos_emb.num_embeddings // 2
        if T > max_len:
            raise ValueError(f"Sequence length {T} exceeds max_len {max_len}")

        # ── clean prefix embeddings ─────────────────────────────────────────
        h_prefix = self.input_proj(x)
        h_prefix = h_prefix + self.pos_emb(torch.arange(T, device=x.device))

        # ── mask token embeddings ───────────────────────────────────────────
        h_mask = self.mask_emb.expand(B, T, -1)
        h_mask = h_mask + self.pos_emb(torch.arange(T, 2 * T, device=x.device))

        h = torch.cat([h_prefix, h_mask], dim=1)  # (B, 2T, d_model)

        # ── forward through blocks with structured mask ─────────────────────
        attn_mask = self._get_attn_mask(T, x.device)  # (2T, 2T)
        for block in self.blocks:
            h = block(h, attn_mask)

        h = self.ln_f(h)
        logits = self.head(h)  # (B, 2T, d_output)

        # ── diffusion auxiliary loss ────────────────────────────────────────
        if self.training:
            diff_logits = logits[:, T:, :]  # (B, T, d_output)
            # assumes x is one-hot over a vocabulary of size <= d_output so
            # argmax recovers valid class indices — TODO confirm with caller
            diff_targets = x.argmax(dim=-1)  # (B, T)
            diff_loss = F.cross_entropy(
                diff_logits.reshape(-1, diff_logits.size(-1)),
                diff_targets.reshape(-1),
            )
            # Scale: paper balances AR and diffusion with weight 1/(1+α) each.
            # NOTE(review): the module docstring says train.py adds aux_loss
            # directly to its criterion output, while this scaling assumes
            # train.py contributes α/(1+α)·L_AR. Only the latter reproduces the
            # paper's α:1 weighting — confirm against train.py.
            self.aux_loss = diff_loss / (1.0 + self.alpha)
        else:
            self.aux_loss = x.new_zeros(())

        return logits[:, :T, :]  # AR logits only

    # ------------------------------------------------------------------

    @staticmethod
    def extra_kwargs(model_cfg) -> dict:
        """Map a config object to the extra constructor kwargs of this model."""
        return {
            "n_heads": model_cfg.n_heads,
            "block_size": getattr(model_cfg, "block_size", 8),
            "alpha": getattr(model_cfg, "alpha", 1.0),
        }