| """AAM Diffusion LLM — Flow Matching Decoder |
| |
| Alternative to DDPM/DDIM — only 2-3 steps because starting point |
| is already meaningful (graph-conditioned prediction). |
| |
Flow matching predicts a velocity field (more stable for text)
and does not require a noise schedule.
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class FlowMatchingOutput:
    """Result of a :class:`FlowMatchingDecoder` forward pass."""

    # Vocabulary logits after flow refinement, shape (batch, seq, d_vocab).
    refined_logits: torch.Tensor
    # Number of Euler integration steps that were applied.
    num_steps: int
    # Hidden states recorded after each step (including the start point)
    # when the caller requested a trajectory; otherwise None.
    trajectory: Optional[List[torch.Tensor]]
|
|
|
|
class FlowStep(nn.Module):
    """Single flow-matching step: predicts a velocity field v(x, t).

    The scalar time ``t`` is embedded sinusoidally, projected through a
    small MLP, broadcast across the sequence, and concatenated with the
    hidden states before the velocity MLP.
    """

    def __init__(self, d_model: int, time_embed_dim: Optional[int] = None) -> None:
        """
        Args:
            d_model: Hidden size of the token representations.
            time_embed_dim: Size of the sinusoidal time embedding.
                Defaults to ``d_model // 4``.
        """
        super().__init__()
        self.d_model = d_model
        self.time_embed_dim = time_embed_dim or d_model // 4

        # Projects the sinusoidal time embedding up to d_model.
        self.time_mlp = nn.Sequential(
            nn.Linear(self.time_embed_dim, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        # Maps [hidden ; time] -> velocity, same width as the hidden states.
        self.velocity_net = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        # NOTE(review): created but never applied in forward(); kept so
        # existing checkpoints still load — confirm whether the velocity
        # output was meant to be normalized.
        self.layer_norm = nn.LayerNorm(d_model)

    @staticmethod
    def sinusoidal_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
        """Transformer-style sinusoidal embedding of scalar times.

        Args:
            t: Times of shape ``(batch,)`` or ``(batch, 1)``.
            dim: Output embedding dimension.

        Returns:
            Tensor of shape ``(batch, dim)``.
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        half_dim = dim // 2
        # Guard the denominator: with dim <= 3 we have half_dim <= 1 and the
        # original ``half_dim - 1`` divides by zero (or goes negative).
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device, dtype=t.dtype) * -emb)
        emb = t * emb.unsqueeze(0)
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        if dim % 2 == 1:
            # Odd target dim: pad one zero column on the right.
            emb = F.pad(emb, (0, 1))
        return emb

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict the velocity at hidden states ``x`` and times ``t``.

        Args:
            x: Hidden states of shape ``(batch, seq, d_model)``.
            t: Per-sample times of shape ``(batch,)``.

        Returns:
            Velocity tensor of shape ``(batch, seq, d_model)``.
        """
        _, seq_len, _ = x.shape
        t_emb = self.time_mlp(self.sinusoidal_embedding(t, self.time_embed_dim))
        # Broadcast the per-sample time embedding over every position.
        t_emb = t_emb.unsqueeze(1).expand(-1, seq_len, -1)
        return self.velocity_net(torch.cat([x, t_emb], dim=-1))
|
|
|
|
class FlowMatchingDecoder(nn.Module):
    """Flow Matching Decoder — 2-3 step refinement alternative to DDPM/DDIM.

    Flow matching formula:
        - x_0 = initial prediction (from graph-conditioned denoising)
        - x_1 = refined prediction (target)
        - dx/dt = v(x, t) — velocity field
        - x_{t+dt} = x_t + v(x_t, t) * dt — Euler step

    Each integration step owns its own :class:`FlowStep`, conditioned on
    times from its segment of a uniform schedule over [0, 1].
    """

    def __init__(self, d_model: int, d_vocab: int, num_steps: int = 3) -> None:
        """
        Args:
            d_model: Hidden size.
            d_vocab: Vocabulary size for the output projection.
            num_steps: Number of Euler steps (clamped to at least 1).
        """
        super().__init__()
        self.d_model = d_model
        self.d_vocab = d_vocab
        self.num_steps = max(1, num_steps)

        # NOTE(review): logits_to_hidden is unused in this file — presumably
        # callers use it to map vocabulary logits back to hidden space; verify.
        self.logits_to_hidden = nn.Linear(d_vocab, d_model, bias=False)
        self.hidden_to_logits = nn.Linear(d_model, d_vocab, bias=False)

        # One velocity network per integration step.
        self.flow_steps = nn.ModuleList(
            [FlowStep(d_model) for _ in range(self.num_steps)]
        )

        self.input_norm = nn.LayerNorm(d_model)
        self.output_norm = nn.LayerNorm(d_model)

        # Uniform time grid: num_steps equal segments covering [0, 1].
        self.register_buffer(
            "time_schedule",
            torch.linspace(0, 1, self.num_steps + 1),
        )

    def forward(
        self,
        initial_hidden: torch.Tensor,
        return_trajectory: bool = False,
    ) -> FlowMatchingOutput:
        """Integrate the learned ODE from the initial hidden states.

        Args:
            initial_hidden: Hidden states of shape (batch, seq, d_model).
            return_trajectory: If True, record the hidden states after every
                Euler step (including the normalized start point).

        Returns:
            FlowMatchingOutput with refined logits of shape
            (batch, seq, d_vocab).
        """
        x = self.input_norm(initial_hidden)
        trajectory: List[torch.Tensor] = []
        if return_trajectory:
            trajectory.append(x.clone())

        batch_size = x.shape[0]

        for step_idx in range(self.num_steps):
            t_start = self.time_schedule[step_idx]
            dt = self.time_schedule[step_idx + 1] - t_start

            # Every sample is conditioned on the segment's start time.
            t = t_start.expand(batch_size).to(x.device)
            velocity = self.flow_steps[step_idx](x, t)
            x = x + velocity * dt  # explicit Euler update

            if return_trajectory:
                trajectory.append(x.clone())

        x = self.output_norm(x)
        refined_logits = self.hidden_to_logits(x)

        return FlowMatchingOutput(
            refined_logits=refined_logits,
            num_steps=self.num_steps,
            trajectory=trajectory if return_trajectory else None,
        )

    def compute_loss(
        self,
        initial_hidden: torch.Tensor,
        target_hidden: torch.Tensor,
    ) -> torch.Tensor:
        """Conditional flow-matching loss on the straight path x_0 -> x_1.

        Args:
            initial_hidden: Source hidden states, shape (batch, seq, d_model).
            target_hidden: Target hidden states, same shape.

        Returns:
            Scalar MSE between predicted and true velocity.
        """
        batch_size = initial_hidden.shape[0]
        device = initial_hidden.device
        dtype = initial_hidden.dtype

        # NOTE(review): the source is normalized but the target is not —
        # confirm target_hidden is already on the input_norm scale.
        x_0 = self.input_norm(initial_hidden)
        x_1 = target_hidden

        # Pick the step module first, then sample t inside that module's own
        # schedule segment. The previous version paired a uniformly random
        # module with a t drawn from all of [0, 1], so each module was
        # trained on times it never receives at inference.
        step_idx = int(torch.randint(0, self.num_steps, (1,)).item())
        t_start = self.time_schedule[step_idx].to(device=device, dtype=dtype)
        t_end = self.time_schedule[step_idx + 1].to(device=device, dtype=dtype)
        t = t_start + (t_end - t_start) * torch.rand(
            batch_size, device=device, dtype=dtype
        )

        # Linear interpolation along the probability path.
        t_expand = t.view(-1, 1, 1)
        x_t = (1 - t_expand) * x_0 + t_expand * x_1

        # For a straight path the true velocity is constant: x_1 - x_0.
        target_velocity = x_1 - x_0

        predicted_velocity = self.flow_steps[step_idx](x_t, t)
        return F.mse_loss(predicted_velocity, target_velocity)
|
|