|
|
|
|
|
|
|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.checkpoint import checkpoint |
|
|
from models.diffusion import create_diffusion |
|
|
|
|
|
|
|
|
|
|
|
def modulate(x, shift, scale):
    """AdaLN modulation: scale x (residual around 1), then add the shift."""
    return shift + x * (scale + 1)
|
|
|
|
|
|
|
|
class TimestepEmbedder(nn.Module):
    """Embed scalar diffusion timesteps via sinusoidal features followed by an MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal features [cos | sin] over log-spaced frequencies.

        t: 1-D tensor of timesteps.  Returns a [len(t), dim] float tensor;
        an odd `dim` is zero-padded in the last column.
        """
        half = dim // 2
        exponents = torch.arange(0, half, dtype=torch.float32) / half
        freqs = torch.exp(-math.log(max_period) * exponents).to(t.device)
        angles = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    def forward(self, t):
        freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq)
|
|
|
|
|
|
|
|
class SinPos1D(nn.Module):
    """Classic sinusoidal 1-D positional encoding: sin on even dims, cos on odd."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, L, device, dtype):
        """Return an [L, dim] positional table, computed in fp32 then cast to `dtype`."""
        positions = torch.arange(0, L, device=device, dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
            * (-math.log(10000.0) / self.dim)
        )
        table = torch.zeros(L, self.dim, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        return table.to(dtype)
|
|
|
|
|
|
|
|
|
|
|
class TemporalDiTBlock(nn.Module):
    """
    DiT-style transformer block: AdaLN conditioning with **causal** self-attention
    over the time axis.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        # Sub-module creation order is kept identical to preserve RNG draws
        # and state-dict keys.
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        inner = int(dim * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.Linear(dim, 2 * inner, bias=True),
            nn.SiLU(),
            nn.Linear(2 * inner, dim, bias=True),
        )

        # Produces (shift, scale, gate) x2; zero-init so the block starts as identity.
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
        nn.init.constant_(self.adaLN[-1].weight, 0)
        nn.init.constant_(self.adaLN[-1].bias, 0)

    def forward(self, x, y, causal_mask):
        """
        x: [B, L, D], y: [B, D], causal_mask: [L, L] bool, True = mask (disallow)
        """
        shift_a, scale_a, gate_a, shift_f, scale_f, gate_f = self.adaLN(y).chunk(6, dim=-1)

        # Attention branch: pre-norm, AdaLN-modulated, gated residual.
        attn_in = modulate(self.norm1(x), shift_a.unsqueeze(1), scale_a.unsqueeze(1))
        attn_out, _ = self.attn(attn_in, attn_in, attn_in,
                                attn_mask=causal_mask, need_weights=False)
        x = x + gate_a.unsqueeze(1) * attn_out

        # Feed-forward branch, same modulate-then-gate pattern.
        ffn_in = modulate(self.norm2(x), shift_f.unsqueeze(1), scale_f.unsqueeze(1))
        x = x + gate_f.unsqueeze(1) * self.ffn(ffn_in)
        return x
|
|
|
|
|
|
|
|
class FinalLayer(nn.Module):
    """Zero-initialized AdaLN output head: LayerNorm -> modulate -> Linear."""

    def __init__(self, dim, out_channels):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(dim, out_channels, bias=True)
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        # Zero-init modulation and projection so the head starts by emitting zeros.
        for param in (self.adaLN[-1].weight, self.adaLN[-1].bias,
                      self.linear.weight, self.linear.bias):
            nn.init.constant_(param, 0)

    def forward(self, x, c):
        """x: [B, L, D] tokens, c: [B, D] conditioning; returns [B, L, out_channels]."""
        shift, scale = self.adaLN(c).chunk(2, dim=-1)
        h = modulate(self.norm(x), shift.unsqueeze(1), scale.unsqueeze(1))
        return self.linear(h)
|
|
|
|
|
|
|
|
|
|
|
class TemporalDiTAdaLN(nn.Module):
    """
    DiT-like denoiser that:
    - operates on [B, L, C]
    - uses **causal** attention (each position sees only <= t)
    - accepts (B, L) via set_sequence_layout for flatten↔sequence reshaping
    - returns all positions but we usually **read only the last token** for streaming
    """
    def __init__(self, in_channels, model_channels, out_channels, z_channels, depth, n_heads=8,
                 mlp_ratio=4.0, grad_checkpointing=False):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.z_channels = z_channels
        self.depth = depth
        self.n_heads = n_heads
        self.grad_checkpointing = grad_checkpointing

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.pos = SinPos1D(model_channels)

        self.blocks = nn.ModuleList([
            TemporalDiTBlock(model_channels, n_heads=n_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth)
        ])
        self.final = FinalLayer(model_channels, out_channels)

        # Layout for flatten↔sequence reshaping; set via set_sequence_layout.
        self._seq_B = None
        self._seq_L = None

        self._init_weights()

    def _init_weights(self):
        """Xavier-init all Linear layers, then restore the zero-inits DiT depends on.

        BUG FIX: ``self.apply(_xav)`` visits *every* nn.Linear, which silently
        clobbered the zero-initialization that TemporalDiTBlock and FinalLayer
        performed in their constructors.  DiT relies on zero adaLN modulation
        and a zero output projection so each block starts as the identity map;
        re-zero those layers after the blanket Xavier pass (this mirrors the
        init order of the reference DiT implementation).
        """
        def _xav(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        self.apply(_xav)
        # DiT-style normal init for the timestep-embedding MLP.
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
        # Re-apply the zero-inits that the blanket Xavier pass overwrote.
        for blk in self.blocks:
            nn.init.constant_(blk.adaLN[-1].weight, 0)
            nn.init.constant_(blk.adaLN[-1].bias, 0)
        nn.init.constant_(self.final.adaLN[-1].weight, 0)
        nn.init.constant_(self.final.adaLN[-1].bias, 0)
        nn.init.constant_(self.final.linear.weight, 0)
        nn.init.constant_(self.final.linear.bias, 0)

    def set_sequence_layout(self, B, L):
        """Declare that flattened inputs of length B*L should be viewed as [B, L, ...]."""
        self._seq_B = int(B)
        self._seq_L = int(L)

    def _flatten_to_seq(self, x_flat, c_flat):
        """Reshape flat [B*L, C] tensors into [B, L, C] using the declared layout.

        Falls back to L=1 (one token per batch row) when no layout was set.
        """
        if self._seq_B is None or self._seq_L is None:
            B, L = x_flat.shape[0], 1
        else:
            B, L = self._seq_B, self._seq_L
            assert B * L == x_flat.shape[0], f"set_sequence_layout({B},{L}) mismatch"
        x = x_flat.view(B, L, -1)
        c = c_flat.view(B, L, -1)
        return x, c

    @staticmethod
    def _causal_mask(L, device):
        # True above the diagonal means "disallowed": position i may only
        # attend to positions j <= i.
        return torch.ones(L, L, device=device, dtype=torch.bool).triu(1)

    def forward(self, x_flat, t, c_flat, cfg_scale: float = 1.0):
        """Denoise a flattened window.

        x_flat: [B*L, in_channels] noisy targets; t: [B*L] timesteps;
        c_flat: [B*L, z_channels] conditions.  Returns [B*L, out_channels].
        `cfg_scale` is accepted for API compatibility with forward_with_cfg
        but is unused here (guidance is applied by the caller).
        """
        x, c = self._flatten_to_seq(x_flat, c_flat)
        B, L, _ = x.shape

        x = self.input_proj(x)
        pos = self.pos(L, x.device, x.dtype)
        x = x + pos.unsqueeze(0)

        # One conditioning vector per sequence: timestep and condition
        # embeddings are averaged over the window.
        # NOTE(review): this discards per-position conditioning; presumably all
        # positions share the same t during sampling — confirm against the
        # diffusion sampler before relying on per-step behavior.
        t_emb = self.time_embed(t).view(B, L, -1).mean(dim=1)
        c_emb = self.cond_embed(c).mean(dim=1)
        y = t_emb + c_emb

        causal_mask = self._causal_mask(L, x.device)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for blk in self.blocks:
                # NOTE(review): consider use_reentrant=False once the minimum
                # torch version allows it.
                x = checkpoint(blk, x, y, causal_mask)
        else:
            for blk in self.blocks:
                x = blk(x, y, causal_mask)

        out = self.final(x, y)
        return out.view(B * L, -1)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        """Classifier-free guidance: first half of the batch is conditional, second half unconditional."""
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, c, cfg_scale=cfg_scale)
        # Guide only the epsilon channels; pass extra channels (e.g. learned sigma) through.
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([guided, guided], dim=0)
        return torch.cat([eps, rest], dim=1)
|
|
|
|
|
|
|
|
|
|
|
class DiffLoss(nn.Module):
    """
    Diffusion loss with **causal, streamable** temporal DiT denoiser.
    Training API unchanged; plus:
    - set_sequence_layout(B, L)
    - sample_next_token(z_seq, temperature=1.0, cfg=1.0) -> [B, C] (last token)
    """

    def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps,
                 grad_checkpointing=False, learn_sigma=False, n_heads=8, mlp_ratio=4.0):
        super().__init__()
        self.in_channels = target_channels
        self.learn_sigma = learn_sigma

        self.net = TemporalDiTAdaLN(
            in_channels=target_channels,
            model_channels=width,
            out_channels=target_channels * 2 if learn_sigma else target_channels,
            z_channels=z_channels,
            depth=depth,
            n_heads=n_heads,
            mlp_ratio=mlp_ratio,
            grad_checkpointing=grad_checkpointing
        )

        # Full schedule for training; respaced schedule for generation.
        self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
        self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")

        # Streaming window layout, populated by set_sequence_layout.
        self._B = None
        self._L = None

    def set_sequence_layout(self, B, L):
        """Record the (B, L) streaming layout and forward it to the denoiser."""
        self._B, self._L = int(B), int(L)
        self.net.set_sequence_layout(B, L)

    def forward(self, target, z, mask=None):
        """One training step: sample random timesteps and compute the diffusion loss.

        Optional `mask` gives a weighted mean over unmasked elements.
        Returns (scalar loss, pred_xstart).
        """
        n = target.shape[0]
        t = torch.randint(0, self.train_diffusion.num_timesteps, (n,), device=target.device)
        losses = self.train_diffusion.training_losses(self.net, target, t, dict(c=z))
        loss = losses["loss"]
        if mask is not None:
            loss = (loss * mask).sum() / mask.sum()
        return loss.mean(), losses["pred_xstart"]

    def _diffusion_inputs(self, n_flat, device, z_flat, cfg):
        """Build (noise, model_fn, model_kwargs) for p_sample_loop, CFG-aware.

        With guidance, the batch is assumed to be [cond | uncond] halves that
        must share the same starting noise.
        """
        if cfg != 1.0:
            eps = torch.randn(n_flat // 2, self.in_channels, device=device)
            return (torch.cat([eps, eps], dim=0),
                    self.net.forward_with_cfg,
                    dict(c=z_flat, cfg_scale=cfg))
        return (torch.randn(n_flat, self.in_channels, device=device),
                self.net.forward,
                dict(c=z_flat))

    def sample(self, z, temperature=1.0, cfg=1.0):
        """Draw samples conditioned on z ([N, Cz]); N counts both CFG halves when cfg != 1."""
        noise, model_fn, kwargs = self._diffusion_inputs(z.shape[0], z.device, z, cfg)
        return self.gen_diffusion.p_sample_loop(
            model_fn, noise.shape, noise, clip_denoised=False, model_kwargs=kwargs,
            progress=False, temperature=temperature
        )

    @torch.no_grad()
    def sample_next_token(self, z_seq, temperature=1.0, cfg=1.0):
        """
        z_seq: [B, L, Cz] AR conditions for the current streaming window (history + 1 step).
        Call set_sequence_layout(B, L) first.
        Returns: next_token: [B, C] (the last position’s denoised sample).
        Mechanism: denoise **entire window** with causal attention and read the last index only.
        """
        assert self._B is not None and self._L is not None, "Call set_sequence_layout(B, L) first."
        B, L, Cz = z_seq.shape
        assert B == self._B and L == self._L, "z_seq shape must match set_sequence_layout."

        z_flat = z_seq.reshape(B * L, Cz)
        noise, model_fn, kwargs = self._diffusion_inputs(B * L, z_seq.device, z_flat, cfg)

        window = self.gen_diffusion.p_sample_loop(
            model_fn, noise.shape, noise, clip_denoised=False, model_kwargs=kwargs,
            progress=False, temperature=temperature
        )

        # Only the last position of the causal window is the "next" token.
        return window.view(B, L, self.in_channels)[:, -1, :]
|
|
|