# aam-diffusion-v1 / diffusion_llm/model/diffusion_transformer.py
"""
AAM Diffusion LLM — Diffusion Transformer (Denoiser)
The core denoising network. Takes noisy text embeddings and graph
conditioning, and predicts the noise (or clean data) at each
diffusion timestep.
Architecture:
Input: Noisy embeddings x_t + timestep t + graph conditioning
Output: Predicted noise epsilon (or x_0 or v)
The transformer uses:
- Self-attention over the text sequence
- Cross-attention to graph conditioning (evidence, anomalies, etc.)
- Timestep embedding (sinusoidal) injected via adaptive layer norm
- Optional flash attention for efficiency
- [v2.0] SwiGLU FFN (proven better in LLaMA/Mistral)
- [v2.0] RoPE via the dedicated rope.py module
- [v2.0] Evoformer integration points for layer recycling
This is the "brainstem" of the body — the core computation that
transforms noisy signals into coherent patterns.
Analogy: Like Jin Soun's muscles responding to signals from the brain,
this model receives the "noise signal" and the "instructions from the
graph", then turns them into coherent movement (sentences).
"""
from __future__ import annotations
import math
from typing import Optional, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusion_llm.config.model_config import ModelConfig, EvoformerConfig, MatryoshkaConfig
from diffusion_llm.model.rope import RotaryPositionEncoding
class SinusoidalTimestepEmbedding(nn.Module):
"""Sinusoidal embedding for diffusion timesteps.
Maps integer timesteps to d_model-dimensional vectors using
sinusoidal position encoding, similar to the positional encoding
in the original Transformer.
This allows the model to know "how noisy" the current input is,
which is essential for the denoising process.
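
    Illustrative example (shapes only):
        >>> emb = SinusoidalTimestepEmbedding(d_model=512)
        >>> t = torch.randint(0, 1000, (8,))
        >>> emb(t).shape
        torch.Size([8, 512])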
"""
def __init__(self, d_model: int, max_period: int = 10000):
super().__init__()
self.d_model = d_model
self.max_period = max_period
# Two-layer MLP to project sinusoidal features
self.mlp = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.GELU(),
nn.Linear(d_model * 4, d_model),
)
def forward(self, t: torch.Tensor) -> torch.Tensor:
"""Embed timesteps.
Args:
t: Timestep indices of shape (batch,).
Returns:
Timestep embeddings of shape (batch, d_model).
"""
device = t.device
half_dim = self.d_model // 2
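        # Geometric ladder of frequencies from 1 down to 1/max_period,
        # as in the original Transformer positional encoding.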
emb = math.log(self.max_period) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device, dtype=torch.float32) * -emb)
emb = t.float().unsqueeze(-1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
if emb.shape[-1] < self.d_model:
emb = F.pad(emb, (0, self.d_model - emb.shape[-1]))
return self.mlp(emb)
class AdaptiveLayerNorm(nn.Module):
"""Adaptive Layer Normalization conditioned on timestep.
Instead of fixed scale/shift parameters, this layer norm
uses the timestep embedding to produce scale and shift:
y = (1 + scale(t)) * norm(x) + shift(t)
This allows the model to behave differently at different
noise levels — more "creative" at high noise, more
"precise" at low noise.
    Analogy: Jin Soun adjusts the intensity of his thinking based on
    how blurry the situation is; the blurrier it is, the more
    "creative" his approach becomes.
"""
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
self.norm = nn.LayerNorm(d_model, elementwise_affine=False, eps=eps)
self.scale_proj = nn.Linear(d_model, d_model)
self.shift_proj = nn.Linear(d_model, d_model)
        # adaLN-Zero-style init: zero both projections so the effective
        # scale is (1 + 0) = 1 and the shift is 0, i.e. the layer starts
        # as a plain LayerNorm and learns its modulation during training.
        nn.init.zeros_(self.scale_proj.weight)
        nn.init.zeros_(self.scale_proj.bias)
        nn.init.zeros_(self.shift_proj.weight)
        nn.init.zeros_(self.shift_proj.bias)
def forward(
self,
x: torch.Tensor,
timestep_emb: torch.Tensor,
) -> torch.Tensor:
"""Apply adaptive layer norm.
Args:
x: Input tensor of shape (batch, seq_len, d_model).
timestep_emb: Timestep embedding of shape (batch, d_model).
Returns:
Normalized and modulated tensor.
"""
normalized = self.norm(x)
scale = (1 + self.scale_proj(timestep_emb)).unsqueeze(1)
shift = self.shift_proj(timestep_emb).unsqueeze(1)
return normalized * scale + shift
class TransformerBlock(nn.Module):
"""Single transformer block with self-attention, cross-attention, and FFN.
The block structure:
1. Adaptive Layer Norm + Self-Attention
2. Adaptive Layer Norm + Cross-Attention (to graph conditioning)
3. Adaptive Layer Norm + Feed-Forward Network
Each sub-layer has a residual connection.
v2.0 Changes:
- SwiGLU FFN replaces GELU FFN (proven better in LLaMA/Mistral)
- Optional Matryoshka elastic inference on the FFN
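
    Illustrative example (shapes only; no graph conditioning):
        >>> block = TransformerBlock(d_model=512, n_heads=8, d_ff=2048)
        >>> x = torch.randn(2, 16, 512)
        >>> t_emb = torch.randn(2, 512)
        >>> block(x, t_emb).shape
        torch.Size([2, 16, 512])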
"""
def __init__(
self,
d_model: int,
n_heads: int,
d_ff: int,
dropout: float = 0.1,
norm_eps: float = 1e-6,
norm_type: str = "rmsnorm",
use_flash_attention: bool = True,
use_swiglu: bool = True,
matryoshka_config: Optional[MatryoshkaConfig] = None,
):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.use_swiglu = use_swiglu
self.matryoshka_config = matryoshka_config
        # All sub-layer norms here are timestep-conditioned AdaptiveLayerNorms;
        # norm_type is accepted for config compatibility (it selects the final
        # norm in DiffusionTransformer, not the per-block norms).
# Self-attention
self.self_attn_norm = AdaptiveLayerNorm(d_model, eps=norm_eps)
self.self_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=n_heads,
dropout=dropout,
batch_first=True,
)
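        # Note: use_flash_attention is accepted but not wired up explicitly;
        # nn.MultiheadAttention may already dispatch to fused/flash SDPA
        # kernels when batch_first=True and need_weights=False.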
self.self_attn_dropout = nn.Dropout(dropout)
# Cross-attention (to graph conditioning)
self.cross_attn_norm = AdaptiveLayerNorm(d_model, eps=norm_eps)
self.cross_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=n_heads,
dropout=dropout,
batch_first=True,
kdim=d_model,
vdim=d_model,
)
self.cross_attn_dropout = nn.Dropout(dropout)
# Feed-forward — SwiGLU or legacy GELU
self.ff_norm = AdaptiveLayerNorm(d_model, eps=norm_eps)
if use_swiglu:
# SwiGLU FFN (proven better in LLaMA/Mistral)
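            # FFN(x) = down_proj(SiLU(gate_proj(x)) * up_proj(x))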
self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
self.up_proj = nn.Linear(d_model, d_ff, bias=False)
self.down_proj = nn.Linear(d_ff, d_model, bias=False)
self.ff_dropout = nn.Dropout(dropout)
else:
# Legacy GELU FFN (backward compatible)
self.ff = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout),
)
# Matryoshka elastic inference (optional)
if matryoshka_config is not None and use_swiglu:
self._matryoshka_d_ff = d_ff
self._matryoshka_factors = sorted(matryoshka_config.granularity_factors)
if matryoshka_config.use_adaptive:
self.size_selector = nn.Sequential(
nn.Linear(d_model, d_model // 8, bias=False),
nn.SiLU(),
nn.Linear(d_model // 8, 1, bias=False),
nn.Sigmoid(),
)
else:
self._matryoshka_d_ff = None
self._matryoshka_factors = None
# Layer scales (optional, helps with deep networks)
self.self_attn_scale = nn.Parameter(torch.ones(1) * 0.1)
self.cross_attn_scale = nn.Parameter(torch.ones(1) * 0.1)
self.ff_scale = nn.Parameter(torch.ones(1) * 0.1)
def _select_matryoshka_factor(self, x: torch.Tensor) -> float:
"""Adaptive factor selection for Matryoshka inference."""
if not hasattr(self, "size_selector"):
return 1.0
score = self.size_selector(x.mean(dim=1, keepdim=False))
score_val = score.mean().item()
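        # Snap the pooled sigmoid score (in [0, 1]) to the nearest
        # configured granularity factor.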
min_dist = float("inf")
best_factor = self._matryoshka_factors[-1]
for f in self._matryoshka_factors:
dist = abs(score_val - f)
if dist < min_dist:
min_dist = dist
best_factor = f
return best_factor
def forward(
self,
x: torch.Tensor,
timestep_emb: torch.Tensor,
graph_keys: Optional[torch.Tensor] = None,
graph_values: Optional[torch.Tensor] = None,
causal_mask: Optional[torch.Tensor] = None,
granularity_factor: Optional[float] = None,
) -> torch.Tensor:
"""Forward pass.
Args:
x: Input sequence of shape (batch, seq_len, d_model).
timestep_emb: Timestep embedding of shape (batch, d_model).
graph_keys: Graph conditioning keys for cross-attention,
shape (batch, n_graph_nodes, d_model).
graph_values: Graph conditioning values for cross-attention,
shape (batch, n_graph_nodes, d_model).
causal_mask: Optional causal mask for self-attention.
granularity_factor: Optional Matryoshka granularity factor
for elastic inference (1.0 = full size).
Returns:
Output sequence of shape (batch, seq_len, d_model).
"""
# 1. Self-attention with adaptive layer norm
normed = self.self_attn_norm(x, timestep_emb)
attn_out, _ = self.self_attn(
normed, normed, normed,
attn_mask=causal_mask,
need_weights=False,
)
x = x + self.self_attn_scale * self.self_attn_dropout(attn_out)
# 2. Cross-attention to graph conditioning (if available)
if graph_keys is not None and graph_values is not None:
normed = self.cross_attn_norm(x, timestep_emb)
cross_out, _ = self.cross_attn(
normed, graph_keys, graph_values,
need_weights=False,
)
x = x + self.cross_attn_scale * self.cross_attn_dropout(cross_out)
# 3. Feed-forward with adaptive layer norm
normed = self.ff_norm(x, timestep_emb)
if self.use_swiglu:
# Determine Matryoshka factor
factor = granularity_factor
if factor is None and self._matryoshka_factors is not None:
factor = self._select_matryoshka_factor(normed)
elif factor is None:
factor = 1.0
# Clamp factor
if self._matryoshka_factors is not None:
factor = min(max(factor, min(self._matryoshka_factors)), 1.0)
else:
factor = 1.0
if factor >= 1.0 or self._matryoshka_d_ff is None:
# Full-size SwiGLU
gate = F.silu(self.gate_proj(normed))
up = self.up_proj(normed)
ff_out = self.down_proj(gate * up)
else:
# Matryoshka partial SwiGLU
d_ff_active = max(1, int(self._matryoshka_d_ff * factor))
gate = F.silu(F.linear(normed, self.gate_proj.weight[:d_ff_active, :]))
up = F.linear(normed, self.up_proj.weight[:d_ff_active, :])
ff_out = F.linear(gate * up, self.down_proj.weight[:, :d_ff_active])
ff_out = self.ff_dropout(ff_out)
else:
# Legacy GELU FFN
ff_out = self.ff(normed)
x = x + self.ff_scale * ff_out
return x
class DiffusionTransformer(nn.Module):
"""Diffusion Transformer — the core denoising network for AAM.
This transformer takes:
- Noisy text embeddings (x_t)
- Diffusion timestep (t)
- Graph conditioning (evidence, anomalies, reasoning chains)
And predicts the noise that was added (or the clean data,
depending on prediction_type).
Architecture Overview:
┌────────────────────────────────────────────────┐
│ Input Embedding: x_t (noisy) → embedding │
│ + Positional Encoding (RoPE or learned) │
│ │
│ N x TransformerBlock: │
│ ├─ AdaLN + Self-Attention │
│ ├─ AdaLN + Cross-Attention (to graph) │
│ └─ AdaLN + SwiGLU FFN (Matryoshka) │
│ │
│ [Evoformer Layer Recycling — optional] │
│ │
│ Output Projection: → predicted noise │
└────────────────────────────────────────────────┘
Key Features:
- Adaptive Layer Norm: timestep-conditioned normalization
- Cross-Attention: graph conditioning guides generation
- Layer Scales: helps training deep networks
- RoPE: better length generalization than learned positions
- [v2.0] SwiGLU FFN: proven better than GELU in LLaMA/Mistral
- [v2.0] Matryoshka: elastic inference at multiple sizes
- [v2.0] Evoformer: layer recycling for iterative refinement
Args:
config: ModelConfig with architecture hyperparameters.
evoformer_config: Optional EvoformerConfig for layer recycling.
matryoshka_config: Optional MatryoshkaConfig for elastic inference.
use_swiglu: Whether to use SwiGLU FFN (default True for v2.0).
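
    Example (illustrative; assumes ModelConfig() is constructible with
    default values for d_model, n_heads, d_ff, n_layers, vocab_size, etc.):
        >>> model = DiffusionTransformer(ModelConfig())
        >>> x_t = torch.randn(2, 16, model.config.d_model)
        >>> t = torch.randint(0, 1000, (2,))
        >>> model(x_t, t).shape == x_t.shape
        True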
"""
def __init__(
self,
config: ModelConfig,
evoformer_config: Optional[EvoformerConfig] = None,
matryoshka_config: Optional[MatryoshkaConfig] = None,
use_swiglu: bool = True,
):
super().__init__()
self.config = config
self.evoformer_config = evoformer_config
self.matryoshka_config = matryoshka_config
self.use_swiglu = use_swiglu
# Input embedding (from token IDs to d_model)
self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
# Timestep embedding
self.timestep_embedding = SinusoidalTimestepEmbedding(config.d_model)
# Positional encoding
if config.pos_encoding_type == "learned":
self.position_embedding = nn.Embedding(
config.max_seq_len, config.d_model
)
self.rope = None
else:
# RoPE is applied inside attention (no separate embedding)
self.position_embedding = None
# v2.0: Create RotaryPositionEncoding module for explicit RoPE
self.rope = RotaryPositionEncoding(
d_model=config.d_model,
max_seq_len=config.max_seq_len,
)
# Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(
d_model=config.d_model,
n_heads=config.n_heads,
d_ff=config.d_ff,
dropout=config.dropout,
norm_eps=config.norm_eps,
norm_type=config.norm_type,
use_flash_attention=config.use_flash_attention,
use_swiglu=use_swiglu,
matryoshka_config=matryoshka_config,
)
for _ in range(config.n_layers)
])
# Final norm
NormClass = nn.RMSNorm if config.norm_type == "rmsnorm" else nn.LayerNorm
self.final_norm = NormClass(config.d_model, eps=config.norm_eps)
# Output projection (predict noise/x0/v)
self.output_proj = nn.Linear(config.d_model, config.d_model)
# Evoformer integration — lazy import to avoid circular deps
self._evoformer_manager = None
if evoformer_config is not None:
self._init_evoformer(evoformer_config)
        # Initialize weights, then re-apply the AdaptiveLayerNorm zero-init,
        # which the blanket Linear init below would otherwise overwrite.
        self.apply(self._init_weights)
        for module in self.modules():
            if isinstance(module, AdaptiveLayerNorm):
                nn.init.zeros_(module.scale_proj.weight)
                nn.init.zeros_(module.scale_proj.bias)
                nn.init.zeros_(module.shift_proj.weight)
                nn.init.zeros_(module.shift_proj.bias)
def _init_evoformer(self, evoformer_config: EvoformerConfig) -> None:
"""Initialize the Evoformer manager for layer recycling."""
from diffusion_llm.model.evoformer import EvoformerManager
self._evoformer_manager = EvoformerManager(evoformer_config)
def _init_weights(self, module: nn.Module) -> None:
"""Initialize weights with Xavier/GPT-2 style."""
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
def forward(
self,
        x_t: Optional[torch.Tensor],
t: torch.Tensor,
token_ids: Optional[torch.Tensor] = None,
graph_keys: Optional[torch.Tensor] = None,
graph_values: Optional[torch.Tensor] = None,
granularity_factor: Optional[float] = None,
return_hidden_states: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward pass: predict noise given noisy input and timestep.
Args:
x_t: Noisy text embeddings of shape (batch, seq_len, d_model).
If None, token_ids must be provided.
t: Timestep indices of shape (batch,).
token_ids: Token IDs of shape (batch, seq_len).
Used to create embeddings if x_t is not provided directly.
In training, x_t comes from the noise scheduler.
graph_keys: Graph conditioning keys for cross-attention,
shape (batch, n_graph_nodes, d_model).
graph_values: Graph conditioning values for cross-attention,
shape (batch, n_graph_nodes, d_model).
granularity_factor: Optional Matryoshka granularity factor
for elastic inference (1.0 = full size).
return_hidden_states: If True, also return per-layer hidden
states for Evoformer layer recycling.
Returns:
Predicted noise of shape (batch, seq_len, d_model).
If return_hidden_states is True, also returns a list of
per-layer hidden state tensors.
"""
# Get input embeddings
if x_t is None and token_ids is not None:
# Create embeddings from token IDs (used for initial x_0)
h = self.token_embedding(token_ids)
elif x_t is not None:
h = x_t
else:
raise ValueError("Either x_t or token_ids must be provided")
# Add positional encoding (learned positions only; RoPE is applied in attention)
if self.position_embedding is not None:
seq_len = h.shape[1]
positions = torch.arange(seq_len, device=h.device).unsqueeze(0)
h = h + self.position_embedding(positions)
# Embed timestep
t_emb = self.timestep_embedding(t)
# Pass through transformer blocks, collecting hidden states for Evoformer
hidden_states: List[torch.Tensor] = []
for block in self.blocks:
h = block(
h,
timestep_emb=t_emb,
graph_keys=graph_keys,
graph_values=graph_values,
granularity_factor=granularity_factor,
)
if return_hidden_states or self._evoformer_manager is not None:
hidden_states.append(h)
# Evoformer layer recycling (if enabled)
if self._evoformer_manager is not None and len(hidden_states) > 1:
hidden_states = self._evoformer_manager.recycle_layers(hidden_states)
# Use the last revised hidden state as the output
h = hidden_states[-1]
# Final norm and projection
h = self.final_norm(h)
output = self.output_proj(h)
if return_hidden_states:
return output, hidden_states
return output
def apply_evoformer_token_update(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Evoformer bidirectional token update (Level 2).
Can be called externally as part of an Evoformer recycling loop.
Args:
x: Hidden state tensor of shape (batch, seq_len, d_model).
Returns:
Revised hidden state tensor.
"""
if self._evoformer_manager is not None:
return self._evoformer_manager.bidirectional_token_update(x)
return x
def apply_evoformer_decoder_feedback(
self,
hidden_state: torch.Tensor,
decoder_output: torch.Tensor,
) -> torch.Tensor:
"""Apply Evoformer decoder-predict feedback (Level 3).
Can be called externally during anchored decoder refinement.
Args:
hidden_state: Hidden state tensor of shape (batch, seq_len, d_model).
decoder_output: Decoder output tensor of shape (batch, seq_len, d_model).
Returns:
Revised hidden state tensor.
"""
if self._evoformer_manager is not None:
return self._evoformer_manager.apply_decoder_feedback(hidden_state, decoder_output)
return hidden_state
def apply_evoformer_prediction_recycling(
self,
hidden_states: torch.Tensor,
prediction_logits: torch.Tensor,
) -> torch.Tensor:
"""Apply Evoformer prediction-context recycling (Level 4).
Can be called externally to refine graph understanding from
predicted output.
Args:
hidden_states: Hidden states of shape (batch, seq_len, d_model).
prediction_logits: Prediction logits of shape (batch, seq_len, d_model).
Returns:
Revised hidden state tensor.
"""
if self._evoformer_manager is not None:
return self._evoformer_manager.apply_prediction_recycling(hidden_states, prediction_logits)
return hidden_states
def get_num_params(self) -> int:
"""Get total number of parameters."""
return sum(p.numel() for p in self.parameters())
def get_num_trainable_params(self) -> int:
"""Get number of trainable parameters."""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
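

if __name__ == "__main__":
    # Minimal smoke test, an illustrative sketch only. It assumes
    # ModelConfig() is default-constructible with d_model, n_heads,
    # d_ff, n_layers, etc. already set; adjust to your config.
    cfg = ModelConfig()
    model = DiffusionTransformer(cfg)
    x_t = torch.randn(2, 16, cfg.d_model)
    t = torch.randint(0, 1000, (2,))
    eps_hat = model(x_t, t)
    print(f"params={model.get_num_params():,} output={tuple(eps_hat.shape)}")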