| """TiDAR: Think in Diffusion, Talk in Autoregression. |
| |
| Reference: Liu et al., arXiv:2511.08923 |
| |
| Training sequence structure (block_size=B, prefix length=T, total=2T): |
| |
| [ x_0, x_1, ..., x_{T-1} | M, M, ..., M ] |
| β clean prefix (AR) β β mask section β |
| |
| Structured attention mask |
| βββββββββββββββββββββββββ |
| β’ Clean prefix [0 : T]: causal (standard lower-triangular) |
| β’ Mask section [T : 2T]: full attention to clean prefix |
| + bidirectional within each B-token block |
| + causal between blocks |
| |
| Loss |
| ββββ |
| β’ AR loss (L_AR): computed externally by train.py on model output[:, :T, :] |
| β’ Diffusion loss (L_Diff): model predicts the original token at each mask position; |
| stored in self.aux_loss during training. |
| β’ Combined: L = (Ξ±Β·L_AR + L_Diff) / (1 + Ξ±) [paper eq., Ξ±=1 default] |
| β train.py adds aux_loss directly to the primary criterion output. |
| |
| Interface |
| βββββββββ |
| forward(x: Tensor[B, T, d_input]) β Tensor[B, T, d_output] (AR logits only) |
| self.aux_loss: scalar Tensor (diffusion CE, populated during training) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
| def _build_tidar_mask(T: int, block_size: int, device: torch.device) -> torch.Tensor: |
| """Return a (2T, 2T) additive float mask: 0.0 where allowed, -inf where blocked.""" |
| S = 2 * T |
| mask = torch.full((S, S), float("-inf"), device=device) |
|
|
| idx = torch.arange(T, device=device) |
|
|
| |
| |
| causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device)) |
| mask[:T, :T] = causal.float().masked_fill(~causal, float("-inf")).masked_fill(causal, 0.0) |
|
|
| |
| |
|
|
| |
| mask[T:, :T] = 0.0 |
|
|
| |
| |
| bi = idx // block_size |
| allowed = bi.unsqueeze(1) >= bi.unsqueeze(0) |
| mask[T:, T:] = allowed.float().masked_fill(~allowed, float("-inf")).masked_fill(allowed, 0.0) |
|
|
| return mask |
|
|
|
|
| |
| |
| |
|
|
| class _MaskedSelfAttention(nn.Module): |
| """Multi-head self-attention with an explicit additive attention mask.""" |
|
|
| def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0): |
| super().__init__() |
| assert d_model % n_heads == 0 |
| self.n_heads = n_heads |
| self.d_head = d_model // n_heads |
| self.qkv = nn.Linear(d_model, 3 * d_model, bias=False) |
| self.proj = nn.Linear(d_model, d_model, bias=False) |
| self.attn_drop_p = dropout |
|
|
| def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: |
| B, T, C = x.shape |
| q, k, v = self.qkv(x).split(C, dim=-1) |
|
|
| def split_heads(t: torch.Tensor) -> torch.Tensor: |
| return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2) |
|
|
| out = F.scaled_dot_product_attention( |
| split_heads(q), split_heads(k), split_heads(v), |
| attn_mask=attn_mask, |
| dropout_p=self.attn_drop_p if self.training else 0.0, |
| ) |
| return self.proj(out.transpose(1, 2).contiguous().view(B, T, C)) |
|
|
|
|
class _TiDARBlock(nn.Module):
    """Pre-norm Transformer layer whose self-attention takes an explicit mask.

    Computes ``x + Attn(LN1(x))`` and then ``x + MLP(LN2(x))``, where the MLP is
    Linear(d, 4d) -> GELU -> Linear(4d, d) -> Dropout.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        hidden = 4 * d_model
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = _MaskedSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        attended = self.attn(self.ln1(x), attn_mask)
        x = x + attended
        mixed = self.mlp(self.ln2(x))
        return x + mixed
|
|
|
|
| |
| |
| |
|
|
class TiDARModel(nn.Module):
    """TiDAR: Think in Diffusion, Talk in Autoregression (arXiv:2511.08923).

    Doubles the sequence internally — the second half is a block of learned
    [MASK] embeddings processed with the TiDAR structured attention mask.
    Returns AR logits for the clean prefix only; diffusion auxiliary loss is
    stored in self.aux_loss during training.

    Args:
        d_input: Feature size of the input ``x``.
        d_model: Hidden width.
        d_output: Size of the output logits; also the class count of the
            diffusion cross-entropy.
        n_layers: Number of Transformer blocks.
        n_heads: Attention heads per block.
        block_size: Diffusion block length; must be >= 1.
        alpha: Loss weight; the diffusion loss is divided by ``1 + alpha``.
        max_len: Maximum prefix length T (the position table has 2 * max_len rows).
        dropout: Dropout used inside the blocks.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: if ``block_size`` < 1.
    """

    # Max attention masks kept per model. Each entry is a (2T, 2T) float32
    # tensor (16 * T**2 bytes, i.e. 256 MiB at T = 4096), so the cache is bounded.
    _MASK_CACHE_SIZE = 4

    def __init__(
        self,
        d_input: int,
        d_model: int,
        d_output: int,
        n_layers: int = 2,
        n_heads: int = 4,
        block_size: int = 8,
        alpha: float = 1.0,
        max_len: int = 4096,
        dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        # Fail at construction rather than on the first forward pass.
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        self.block_size = block_size
        self.alpha = alpha

        self.input_proj = nn.Linear(d_input, d_model)
        # One learned embedding shared by every [MASK] position.
        self.mask_emb = nn.Parameter(torch.empty(1, 1, d_model))
        nn.init.normal_(self.mask_emb, std=0.02)

        # Positions 0..T-1 are used by the clean prefix, T..2T-1 by the mask section.
        self.pos_emb = nn.Embedding(2 * max_len, d_model)
        self.blocks = nn.ModuleList([
            _TiDARBlock(d_model, n_heads, dropout) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, d_output)

        # Scaled diffusion CE from the latest training forward; zero otherwise.
        self.aux_loss: torch.Tensor = torch.tensor(0.0)

        # (T, device type, device index) -> attention mask; least recently used first.
        self._mask_cache: dict[tuple, torch.Tensor] = {}

    def _get_attn_mask(self, T: int, device: torch.device) -> torch.Tensor:
        """Return the cached (2T, 2T) TiDAR mask for ``T`` on ``device``, building it if needed."""
        key = (T, device.type, getattr(device, "index", 0))
        mask = self._mask_cache.pop(key, None)
        if mask is None:
            mask = _build_tidar_mask(T, self.block_size, device)
        # Re-insert so dict order tracks recency, then evict the oldest entries.
        self._mask_cache[key] = mask
        while len(self._mask_cache) > self._MASK_CACHE_SIZE:
            del self._mask_cache[next(iter(self._mask_cache))]
        return mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, T, d_input) → (B, T, d_output)  [AR logits for clean prefix]

        self.aux_loss is set to the diffusion cross-entropy divided by
        (1 + alpha) during training, and to a zero scalar in eval mode.

        Raises:
            ValueError: if T exceeds the ``max_len`` given at construction.
        """
        B, T, _ = x.shape
        max_len = self.pos_emb.num_embeddings // 2
        if T > max_len:
            raise ValueError(f"Sequence length {T} exceeds max_len {max_len}")

        # Clean prefix: projected inputs at positions 0..T-1.
        h_prefix = self.input_proj(x) + self.pos_emb(torch.arange(T, device=x.device))

        # Mask section: the shared [MASK] embedding at positions T..2T-1.
        # NOTE(review): the mask token predicting x_i sits at position id T + i,
        # so its position depends on T — confirm this is the intended setup.
        h_mask = self.mask_emb.expand(B, T, -1) + self.pos_emb(
            torch.arange(T, 2 * T, device=x.device)
        )

        h = torch.cat([h_prefix, h_mask], dim=1)
        attn_mask = self._get_attn_mask(T, x.device)
        for block in self.blocks:
            h = block(h, attn_mask)
        logits = self.head(self.ln_f(h))

        if self.training:
            diff_logits = logits[:, T:, :]
            # NOTE(review): assumes x is one-hot (or logit-like) over the output
            # vocabulary so argmax recovers each token id — confirm with the data pipeline.
            diff_targets = x.argmax(dim=-1)
            diff_loss = F.cross_entropy(
                diff_logits.reshape(-1, diff_logits.size(-1)),
                diff_targets.reshape(-1),
            )
            # NOTE(review): if train.py adds this to an unscaled L_AR, the total is
            # L_AR + L_Diff / (1 + alpha). Matching the (alpha*L_AR + L_Diff) / (1 + alpha)
            # form in the module docstring would need L_Diff / alpha. Confirm against train.py.
            self.aux_loss = diff_loss / (1.0 + self.alpha)
        else:
            self.aux_loss = x.new_zeros(())

        return logits[:, :T, :]

    @staticmethod
    def extra_kwargs(model_cfg) -> dict:
        """Map a model config to this class's extra constructor kwargs (with defaults)."""
        return {
            "n_heads": model_cfg.n_heads,
            "block_size": getattr(model_cfg, "block_size", 8),
            "alpha": getattr(model_cfg, "alpha", 1.0),
        }
|
|