# diffusion_llm/model/speculative_decoder.py
"""AAM Diffusion LLM — Speculative Decoder
Draft model (graph encoder quick prediction) generates candidates,
main diffusion model verifies and accepts/rejects.
For AAM, the graph encoder can serve as the draft model since
it already produces a quick prediction of the narrative.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class SpeculativeConfig:
    """Hyperparameters for speculative decoding."""

    d_model: int = 768  # hidden size shared with the graph encoder
    d_vocab: int = 32000  # vocabulary size of the main model
    n_draft_tokens: int = 5  # tokens proposed per speculative step
    acceptance_threshold: float = 0.1  # slack allowed below a prob. ratio of 1.0

class SpeculativeDecoder(nn.Module):
    """Speculative decoder for AAM.

    Uses the graph encoder as the draft model and the diffusion model as the
    verifier:

    1. Draft: the graph encoder produces quick token predictions.
    2. Verify: the diffusion model scores each prediction.
    3. Accept/Reject: keep tokens that pass verification; resample the rest.
    """

    def __init__(self, config: Optional[SpeculativeConfig] = None) -> None:
        super().__init__()
        self.config = config or SpeculativeConfig()
        self.d_model = self.config.d_model
        self.d_vocab = self.config.d_vocab
        # Draft head: lightweight projection from graph conditioning to logits.
        self.draft_head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model // 2, bias=False),
            nn.SiLU(),
            nn.Linear(self.d_model // 2, self.d_vocab, bias=False),
        )
        # Verification projection, mapping hidden states to logits
        # (not invoked in this module; available to callers).
        self.verify_proj = nn.Sequential(
            nn.Linear(self.d_model, self.d_model, bias=False),
            nn.SiLU(),
            nn.Linear(self.d_model, self.d_vocab, bias=False),
        )

    def draft(
        self,
        graph_hidden: torch.Tensor,
        n_tokens: int = 5,
        temperature: float = 1.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate draft tokens from graph conditioning.

        Args:
            graph_hidden: Graph encoder output (batch, n_nodes, d_model).
            n_tokens: Number of draft tokens to generate.
            temperature: Sampling temperature.

        Returns:
            Tuple (draft_token_ids, draft_log_probs), each (batch, n_tokens).
        """
        # Use the mean-pooled graph representation as the draft context.
        pooled = graph_hidden.mean(dim=1)  # (batch, d_model)
        # NOTE: the draft head conditions only on the pooled graph state, so
        # every step samples i.i.d. from the same distribution (there is no
        # autoregressive feedback from previously drafted tokens). The logits
        # are therefore computed once, outside the loop.
        logits = self.draft_head(pooled) / temperature
        log_probs = F.log_softmax(logits, dim=-1)
        probs = log_probs.exp()
        all_tokens = []
        all_log_probs = []
        for _ in range(n_tokens):
            token_ids = torch.multinomial(probs, 1).squeeze(-1)
            all_tokens.append(token_ids)
            selected = log_probs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
            all_log_probs.append(selected)
        draft_token_ids = torch.stack(all_tokens, dim=1)  # (batch, n_tokens)
        draft_log_probs = torch.stack(all_log_probs, dim=1)  # (batch, n_tokens)
        return draft_token_ids, draft_log_probs

    def verify(
        self,
        draft_token_ids: torch.Tensor,
        main_logits: torch.Tensor,
        draft_log_probs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Verify draft tokens against main-model logits.

        Args:
            draft_token_ids: Draft token IDs (batch, n_tokens).
            main_logits: Main model logits (batch, n_tokens, d_vocab).
            draft_log_probs: Draft log probs (batch, n_tokens).

        Returns:
            Tuple (accepted_mask, verified_token_ids), each (batch, n_tokens).
        """
        main_log_probs = F.log_softmax(main_logits, dim=-1)
        selected_main = main_log_probs.gather(
            -1, draft_token_ids.unsqueeze(-1)
        ).squeeze(-1)
        # Accept if the main model assigns a probability close to, or higher
        # than, the draft model's (ratio p_main / p_draft near or above 1).
        ratio = torch.exp(selected_main - draft_log_probs)
        accepted = (ratio >= (1.0 - self.config.acceptance_threshold)).float()
        # Where rejected, resample that position from the main model instead.
        rejected_mask = accepted == 0
        if rejected_mask.any():
            main_probs = main_log_probs.exp()
            resampled = torch.multinomial(
                main_probs.view(-1, self.d_vocab), 1
            ).view(draft_token_ids.shape)
            verified = torch.where(rejected_mask, resampled, draft_token_ids)
        else:
            verified = draft_token_ids
        return accepted, verified
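
    # NOTE: this is a simplified scheme. Canonical speculative sampling
    # (e.g. Leviathan et al., 2023) accepts each draft token with probability
    # min(1, p_main / p_draft), resamples rejections from the normalized
    # residual max(0, p_main - p_draft), and discards everything after the
    # first rejection. Here each position is thresholded independently and
    # rejections are resampled directly from the main distribution.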

    def forward(
        self,
        graph_hidden: torch.Tensor,
        main_model_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        n_tokens: Optional[int] = None,
    ) -> Tuple[torch.Tensor, dict]:
        """Full speculative decoding pipeline.

        Args:
            graph_hidden: Graph encoder output (batch, n_nodes, d_model).
            main_model_fn: Callable mapping token IDs (batch, n_tokens) to
                logits (batch, n_tokens, d_vocab).
            n_tokens: Number of draft tokens (defaults to the config value).

        Returns:
            Tuple (verified_token_ids, info_dict).
        """
        n_tokens = n_tokens or self.config.n_draft_tokens
        # Draft.
        draft_ids, draft_log_probs = self.draft(graph_hidden, n_tokens)
        if main_model_fn is None:
            # No verification; just return the draft.
            return draft_ids, {"n_draft": n_tokens, "acceptance_rate": 1.0}
        # Verify with the main model.
        main_logits = main_model_fn(draft_ids)
        accepted, verified = self.verify(draft_ids, main_logits, draft_log_probs)
        info = {
            "n_draft": n_tokens,
            "n_accepted": accepted.sum(dim=-1).mean().item(),
            "acceptance_rate": accepted.mean().item(),
        }
        return verified, info
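

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the AAM pipeline). The real
    # graph encoder is assumed to supply `graph_hidden`; here random tensors
    # and a dummy verifier stand in so the module can be run standalone.
    torch.manual_seed(0)
    cfg = SpeculativeConfig(d_model=64, d_vocab=128, n_draft_tokens=4)
    decoder = SpeculativeDecoder(cfg)
    graph_hidden = torch.randn(2, 10, cfg.d_model)  # (batch, n_nodes, d_model)

    def dummy_main_model_fn(token_ids: torch.Tensor) -> torch.Tensor:
        # Stand-in for the diffusion verifier: random logits of the right shape.
        return torch.randn(token_ids.shape[0], token_ids.shape[1], cfg.d_vocab)

    verified, info = decoder(graph_hidden, main_model_fn=dummy_main_model_fn)
    print("verified:", tuple(verified.shape), "info:", info)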