"""AAM Diffusion LLM — Flow Matching Decoder.

Alternative to DDPM/DDIM sampling that needs only 2-3 steps, because the
starting point is already meaningful (a graph-conditioned prediction).
Flow matching predicts a velocity field (more stable for text) and does not
need a noise schedule.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class FlowMatchingOutput:
    """Result bundle returned by ``FlowMatchingDecoder.forward``.

    Fields are listed in constructor (positional) order.
    """

    # Vocabulary-space scores obtained by projecting the refined hidden states.
    refined_logits: torch.Tensor
    # Number of Euler integration steps that were run.
    num_steps: int
    # Hidden-state snapshots (initial state plus one per step) when the caller
    # asked for them, otherwise ``None``.
    trajectory: Optional[List[torch.Tensor]]
class FlowStep(nn.Module):
    """Single flow-matching step: predicts a velocity field for hidden states.

    The time ``t`` is encoded with a sinusoidal embedding and passed through a
    small MLP. The result is broadcast over the sequence axis and concatenated
    with the hidden states before the velocity MLP.

    Args:
        d_model: Width of the hidden states and of the predicted velocity.
        time_embed_dim: Width of the sinusoidal time embedding. Falls back to
            ``d_model // 4`` when omitted (or falsy).
    """

    def __init__(self, d_model: int, time_embed_dim: Optional[int] = None) -> None:
        super().__init__()
        self.d_model = d_model
        self.time_embed_dim = time_embed_dim or d_model // 4
        self.time_mlp = nn.Sequential(
            nn.Linear(self.time_embed_dim, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        # Input is [hidden state ; time embedding], hence 2 * d_model features.
        self.velocity_net = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        # NOTE(review): ``layer_norm`` is never used in ``forward``. It is kept so
        # existing state_dicts still load with strict=True. Decide whether it was
        # meant to normalise the input or the predicted velocity.
        self.layer_norm = nn.LayerNorm(d_model)

    @staticmethod
    def sinusoidal_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
        """Return a ``dim``-wide sinusoidal embedding of the times ``t``.

        Args:
            t: 1-D tensor of times (one per row). Tensors with any other rank
               are used as-is and must broadcast against ``(1, dim // 2)``.
            dim: Output width. Odd widths are zero-padded by one column.

        Returns:
            ``[sin(t * f) ; cos(t * f)]`` (plus optional zero padding), where the
            frequencies ``f`` are geometrically spaced starting at 1.
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        half_dim = dim // 2
        # With dim in {2, 3} there is a single frequency and ``half_dim - 1`` is
        # zero, which raised ZeroDivisionError. That lone frequency is exp(0) = 1
        # whatever the scale, so clamping the denominator does not change it.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=t.dtype) * -scale
        )
        args = t * freqs.unsqueeze(0)
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        if dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict the velocity ``v(x, t)``.

        Args:
            x: Hidden states of shape ``(batch, seq_len, d_model)``. The shape
               unpacking below requires exactly three axes.
            t: Per-sample times of shape ``(batch,)``. A 0-dim tensor is
               broadcast to the whole batch.

        Returns:
            Velocity tensor of shape ``(batch, seq_len, d_model)``.
        """
        batch_size, seq_len, _ = x.shape
        if t.dim() == 0:
            # A scalar time applies to every sample. Without this the embedding
            # has batch size 1 and the concatenation below fails for batch > 1.
            t = t.expand(batch_size)
        # Align t with x so that e.g. float64 times do not hit a dtype mismatch
        # in the float32 time MLP.
        t = t.to(device=x.device, dtype=x.dtype)
        t_emb = self.time_mlp(self.sinusoidal_embedding(t, self.time_embed_dim))
        # Repeat the per-sample time embedding along the sequence axis.
        t_emb = t_emb.unsqueeze(1).expand(-1, seq_len, -1)
        return self.velocity_net(torch.cat([x, t_emb], dim=-1))
class FlowMatchingDecoder(nn.Module):
    """Flow Matching Decoder — 2-3 step refinement alternative to DDPM/DDIM.

    Flow matching formulation:

    - x_0 = initial prediction (from graph-conditioned denoising)
    - x_1 = refined prediction (target)
    - dx/dt = v(x, t) — velocity field
    - x_{t+dt} = x_t + v(x_t, t) * dt — Euler step

    Euler step ``i`` uses its own :class:`FlowStep` network. That network is
    responsible for the interval ``[time_schedule[i], time_schedule[i + 1])``
    of an evenly spaced schedule over ``[0, 1]``; inference evaluates it at
    the interval start.

    Args:
        d_model: Hidden-state width.
        d_vocab: Vocabulary size of the output logits.
        num_steps: Number of Euler steps (clamped to at least 1).
    """

    def __init__(self, d_model: int, d_vocab: int, num_steps: int = 3) -> None:
        super().__init__()
        self.d_model = d_model
        self.d_vocab = d_vocab
        self.num_steps = max(1, num_steps)
        # NOTE(review): ``logits_to_hidden`` is not used by any method of this
        # class. It may be used by external callers — verify before removing.
        self.logits_to_hidden = nn.Linear(d_vocab, d_model, bias=False)
        self.hidden_to_logits = nn.Linear(d_model, d_vocab, bias=False)
        self.flow_steps = nn.ModuleList(
            [FlowStep(d_model) for _ in range(self.num_steps)]
        )
        self.input_norm = nn.LayerNorm(d_model)
        self.output_norm = nn.LayerNorm(d_model)
        # num_steps + 1 evenly spaced boundaries: 0, 1/N, ..., 1.
        self.register_buffer(
            "time_schedule",
            torch.linspace(0, 1, self.num_steps + 1),
        )

    def forward(
        self,
        initial_hidden: torch.Tensor,
        return_trajectory: bool = False,
    ) -> FlowMatchingOutput:
        """Integrate the learned velocity field with ``num_steps`` Euler steps.

        Args:
            initial_hidden: Starting hidden states ``x_0`` whose last axis is
                ``d_model``. Assumed ``(batch, seq, d_model)``, since each
                FlowStep unpacks exactly three axes.
            return_trajectory: If true, also return clones of the hidden state
                before the first step and after every step. These snapshots are
                taken before ``output_norm``.

        Returns:
            FlowMatchingOutput holding ``hidden_to_logits(output_norm(x))``.
        """
        x = self.input_norm(initial_hidden)
        trajectory: List[torch.Tensor] = []
        if return_trajectory:
            trajectory.append(x.clone())
        batch_size = x.shape[0]
        for step_idx, step in enumerate(self.flow_steps):
            t_start = self.time_schedule[step_idx]
            dt = self.time_schedule[step_idx + 1] - t_start
            # Every sample in the batch is evaluated at this step's start time.
            t = t_start.expand(batch_size).to(x.device)
            x = x + step(x, t) * dt
            if return_trajectory:
                trajectory.append(x.clone())
        x = self.output_norm(x)
        refined_logits = self.hidden_to_logits(x)
        return FlowMatchingOutput(
            refined_logits=refined_logits,
            num_steps=self.num_steps,
            trajectory=trajectory if return_trajectory else None,
        )

    def compute_loss(
        self,
        initial_hidden: torch.Tensor,
        target_hidden: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the flow-matching loss for one randomly chosen step network.

        A step index ``i`` is drawn uniformly. Per-sample times are then drawn
        uniformly from that step's interval ``[time_schedule[i],
        time_schedule[i + 1])``. ``flow_steps[i]`` is regressed (MSE) onto the
        constant velocity ``x_1 - x_0`` of the straight path
        ``x_t = (1 - t) * x_0 + t * x_1``. Over the random step choice, the
        times remain uniform on ``[0, 1]``.

        Args:
            initial_hidden: Starting hidden states, normalised with
                ``input_norm`` exactly as in ``forward``.
            target_hidden: Target hidden states ``x_1``, used as-is (not
                normalised). Must have the same shape as ``initial_hidden``.

        Returns:
            Scalar MSE loss tensor.

        Raises:
            ValueError: If the two inputs have different shapes.
        """
        if initial_hidden.shape != target_hidden.shape:
            # Silent broadcasting would pair the wrong states and still
            # produce a loss value.
            raise ValueError(
                "initial_hidden and target_hidden must have the same shape, got "
                f"{tuple(initial_hidden.shape)} and {tuple(target_hidden.shape)}"
            )
        batch_size = initial_hidden.shape[0]
        device = initial_hidden.device
        dtype = initial_hidden.dtype
        x_0 = self.input_norm(initial_hidden)
        x_1 = target_hidden
        # Pick the network first, then sample t inside the interval it handles
        # in forward(). Sampling t over all of [0, 1] independently of the
        # network would train each network mostly on times it is never used at.
        step_idx = int(torch.randint(0, self.num_steps, (1,)).item())
        t_start = self.time_schedule[step_idx].to(device=device, dtype=dtype)
        t_end = self.time_schedule[step_idx + 1].to(device=device, dtype=dtype)
        t = t_start + (t_end - t_start) * torch.rand(
            batch_size, device=device, dtype=dtype
        )
        # Broadcast per-sample t over the sequence and feature axes.
        t_expand = t.view(-1, 1, 1)
        x_t = (1 - t_expand) * x_0 + t_expand * x_1
        target_velocity = x_1 - x_0
        predicted_velocity = self.flow_steps[step_idx](x_t, t)
        return F.mse_loss(predicted_velocity, target_velocity)
|