# aam-diffusion-v1 / diffusion_llm/model/flow_matching.py
# Uploaded via huggingface_hub (Hugging Face page header residue preserved:
# "Wolfvin's picture", commit d158200, verified).
"""AAM Diffusion LLM — Flow Matching Decoder
Alternative to DDPM/DDIM — only 2-3 steps because starting point
is already meaningful (graph-conditioned prediction).
Flow matching = velocity prediction (more stable for text),
doesn't need noise schedule.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class FlowMatchingOutput:
    """Result of a FlowMatchingDecoder forward pass."""

    # Refined vocabulary logits, shape (batch, seq_len, d_vocab).
    refined_logits: torch.Tensor
    # Number of Euler integration steps that produced the result.
    num_steps: int
    # Per-step hidden states (num_steps + 1 entries, including the start
    # state) when the caller requested a trajectory; None otherwise.
    # Defaults to None so callers that don't track trajectories can omit it.
    trajectory: Optional[List[torch.Tensor]] = None
class FlowStep(nn.Module):
    """Single flow-matching step — predicts a velocity field v(x, t)."""

    def __init__(self, d_model: int, time_embed_dim: Optional[int] = None) -> None:
        """
        Args:
            d_model: width of the hidden states being refined.
            time_embed_dim: size of the raw sinusoidal time embedding;
                defaults to d_model // 4.
        """
        super().__init__()
        self.d_model = d_model
        self.time_embed_dim = time_embed_dim or d_model // 4
        # Projects the sinusoidal time embedding up to d_model.
        self.time_mlp = nn.Sequential(
            nn.Linear(self.time_embed_dim, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        # Velocity head: consumes [state, time] concatenated along features.
        self.velocity_net = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        # NOTE(review): layer_norm is never applied in forward(); kept so
        # existing checkpoints (state_dicts) still load — confirm intent
        # before removing.
        self.layer_norm = nn.LayerNorm(d_model)

    @staticmethod
    def sinusoidal_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
        """Transformer-style sinusoidal embedding of scalar times.

        Args:
            t: (batch,) or (batch, 1) times, typically in [0, 1].
            dim: output embedding dimension.

        Returns:
            (batch, dim) tensor of sin features followed by cos features
            (zero-padded on the right when dim is odd).
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        half_dim = dim // 2
        # Geometric frequency ladder from 1 down to 1/10000. Guard the
        # denominator so half_dim == 1 (dim <= 3) doesn't divide by zero.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device, dtype=t.dtype) * -emb)
        emb = t * emb.unsqueeze(0)
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        if dim % 2 == 1:
            # Odd target dims: pad one zero feature on the right.
            emb = F.pad(emb, (0, 1))
        return emb

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict the velocity for state x at time t.

        Args:
            x: (batch, seq_len, d_model) hidden states.
            t: (batch,) times, one per batch element.

        Returns:
            (batch, seq_len, d_model) velocity tensor.
        """
        batch_size, seq_len, _ = x.shape
        t_emb = self.sinusoidal_embedding(t, self.time_embed_dim)
        t_emb = self.time_mlp(t_emb)
        # Share each example's time embedding across all sequence positions.
        t_emb = t_emb.unsqueeze(1).expand(-1, seq_len, -1)
        velocity_input = torch.cat([x, t_emb], dim=-1)
        return self.velocity_net(velocity_input)
class FlowMatchingDecoder(nn.Module):
    """Flow Matching Decoder — 2-3 step refinement alternative to DDPM/DDIM.

    Flow matching formula:
      - x_0 = initial prediction (from graph-conditioned denoising)
      - x_1 = refined prediction (target)
      - dx/dt = v(x, t) — velocity field
      - x_{t+dt} = x_t + v(x_t, t) * dt — Euler step
    """

    def __init__(self, d_model: int, d_vocab: int, num_steps: int = 3) -> None:
        """
        Args:
            d_model: hidden-state width.
            d_vocab: vocabulary size for the output logits.
            num_steps: number of Euler integration steps (clamped to >= 1).
        """
        super().__init__()
        self.d_model = d_model
        self.d_vocab = d_vocab
        # Always take at least one Euler step, even if the caller asks for 0.
        self.num_steps = max(1, num_steps)
        # NOTE(review): logits_to_hidden is never used in this file —
        # presumably intended for warm-starting from vocabulary logits.
        # Kept so existing checkpoints (state_dicts) still load; confirm
        # before removing.
        self.logits_to_hidden = nn.Linear(d_vocab, d_model, bias=False)
        self.hidden_to_logits = nn.Linear(d_model, d_vocab, bias=False)
        # One independent velocity network per Euler step.
        self.flow_steps = nn.ModuleList(
            [FlowStep(d_model) for _ in range(self.num_steps)]
        )
        self.input_norm = nn.LayerNorm(d_model)
        self.output_norm = nn.LayerNorm(d_model)
        # Uniform time grid 0 = t_0 < ... < t_{num_steps} = 1; registered as
        # a buffer so it follows the module across devices and dtypes.
        self.register_buffer(
            "time_schedule",
            torch.linspace(0, 1, self.num_steps + 1),
        )

    def forward(
        self,
        initial_hidden: torch.Tensor,
        return_trajectory: bool = False,
    ) -> FlowMatchingOutput:
        """Integrate the learned velocity field with explicit Euler steps.

        Args:
            initial_hidden: (batch, seq_len, d_model) hidden states to refine.
            return_trajectory: when True, collect the hidden state after
                every step (num_steps + 1 entries, including the start).

        Returns:
            FlowMatchingOutput with refined logits of shape
            (batch, seq_len, d_vocab).
        """
        x = self.input_norm(initial_hidden)
        trajectory: List[torch.Tensor] = []
        if return_trajectory:
            trajectory.append(x.clone())
        batch_size = x.shape[0]
        for step_idx in range(self.num_steps):
            t_start = self.time_schedule[step_idx]
            t_end = self.time_schedule[step_idx + 1]
            dt = t_end - t_start
            # Broadcast the (0-dim) interval start time to one per example.
            t = t_start.expand(batch_size).to(x.device)
            velocity = self.flow_steps[step_idx](x, t)
            # Euler update: x <- x + v(x, t) * dt.
            x = x + velocity * dt
            if return_trajectory:
                trajectory.append(x.clone())
        x = self.output_norm(x)
        refined_logits = self.hidden_to_logits(x)
        return FlowMatchingOutput(
            refined_logits=refined_logits,
            num_steps=self.num_steps,
            trajectory=trajectory if return_trajectory else None,
        )

    def compute_loss(
        self,
        initial_hidden: torch.Tensor,
        target_hidden: torch.Tensor,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """Conditional flow-matching training loss.

        Interpolates linearly between the normalized initial hidden state and
        the target at a random time t, then regresses a randomly chosen
        step's velocity net onto the path's constant velocity x_1 - x_0.

        Args:
            initial_hidden: (batch, seq_len, d_model) starting hidden states.
            target_hidden: (batch, seq_len, d_model) refinement targets.
                NOTE(review): not passed through input_norm, unlike
                initial_hidden — confirm this asymmetry is intended.
            generator: optional RNG for reproducible t / step sampling
                (defaults to the global RNG, preserving prior behavior).

        Returns:
            Scalar MSE loss between predicted and target velocity.
        """
        batch_size = initial_hidden.shape[0]
        device = initial_hidden.device
        dtype = initial_hidden.dtype
        x_0 = self.input_norm(initial_hidden)
        x_1 = target_hidden
        # One interpolation time per batch element, uniform in [0, 1).
        t = torch.rand(batch_size, device=device, dtype=dtype, generator=generator)
        t_expand = t.view(-1, 1, 1)
        # Straight-line path x_t = (1 - t) x_0 + t x_1, whose exact velocity
        # is the constant x_1 - x_0.
        x_t = (1 - t_expand) * x_0 + t_expand * x_1
        target_velocity = x_1 - x_0
        # Train a single randomly chosen step network per call.
        step_idx = torch.randint(0, self.num_steps, (1,), generator=generator).item()
        predicted_velocity = self.flow_steps[step_idx](x_t, t)
        loss = F.mse_loss(predicted_velocity, target_velocity)
        return loss