# models/diffloss.py
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from models.diffusion import create_diffusion
# ---------------- utils ----------------
def modulate(x, shift, scale):
    # FiLM-style affine modulation used by AdaLN: `scale` is learned around 0,
    # so zero-initialized AdaLN heads start out as the identity transform.
    return x * (1 + scale) + shift
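
# A minimal shape sketch (illustrative only; `_demo_modulate` is not part of the
# original module): per-batch shift/scale vectors broadcast over the sequence
# dimension, and all-zero shift/scale leaves the input unchanged.
def _demo_modulate():
    x = torch.randn(2, 5, 8)       # [B, L, D]
    shift = torch.zeros(2, 1, 8)   # [B, 1, D], broadcasts over L
    scale = torch.zeros(2, 1, 8)
    assert torch.equal(modulate(x, shift, scale), x)
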
class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half).to(t.device)
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
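
# Illustrative only (`_demo_timestep_embedder` is hypothetical): integer diffusion
# timesteps are lifted to fixed sinusoidal features, then projected by a 2-layer
# MLP to the model width.
def _demo_timestep_embedder():
    emb = TimestepEmbedder(hidden_size=128)
    t = torch.tensor([0, 10, 500, 999])
    assert emb(t).shape == (4, 128)
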
class SinPos1D(nn.Module):
    """Fixed sinusoidal positional table over sequence positions (assumes an even `dim`)."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, L, device, dtype):
        pe = torch.zeros(L, self.dim, device=device, dtype=torch.float32)
        pos = torch.arange(0, L, device=device, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) * (-math.log(10000.0) / self.dim))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        return pe.to(dtype)
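
# Illustrative only (`_demo_sinpos` is hypothetical): the table is rebuilt per
# call, so it adapts to any window length L without registered buffers.
def _demo_sinpos():
    pe = SinPos1D(64)(16, torch.device("cpu"), torch.float32)
    assert pe.shape == (16, 64)
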
# --------------- DiT block (causal) ---------------
class TemporalDiTBlock(nn.Module):
    """
    Transformer block with AdaLN (DiT-style) and **causal** self-attention over time.
    """
    def __init__(self, dim, n_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden = int(dim * mlp_ratio)
        # position-wise FFN; the inner width is 2 * mlp_ratio * dim as configured
        self.ffn = nn.Sequential(
            nn.Linear(dim, 2 * hidden, bias=True),
            nn.SiLU(),
            nn.Linear(2 * hidden, dim, bias=True),
        )
        # AdaLN params: shift/scale/gate for attn and ffn, zero-initialized so
        # each block starts as the identity
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
        nn.init.constant_(self.adaLN[-1].weight, 0)
        nn.init.constant_(self.adaLN[-1].bias, 0)

    def forward(self, x, y, causal_mask):
        """
        x: [B, L, D], y: [B, D], causal_mask: [L, L] bool, True = mask (disallow)
        """
        s1, sc1, g1, s2, sc2, g2 = self.adaLN(y).chunk(6, dim=-1)  # [B, D] each
        # attention (causal)
        h = modulate(self.norm1(x), s1.unsqueeze(1), sc1.unsqueeze(1))
        # MultiheadAttention accepts attn_mask of shape [L, L] or [B*nH, L, L];
        # recent PyTorch supports a bool mask with True = disallowed
        h, _ = self.attn(h, h, h, attn_mask=causal_mask, need_weights=False)
        x = x + g1.unsqueeze(1) * h
        # feed-forward
        h2 = modulate(self.norm2(x), s2.unsqueeze(1), sc2.unsqueeze(1))
        h2 = self.ffn(h2)
        x = x + g2.unsqueeze(1) * h2
        return x
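
# Illustrative only (`_demo_causal_block` is hypothetical): an upper-triangular
# bool mask forbids attention to the future, so position i attends only to j <= i.
def _demo_causal_block():
    blk = TemporalDiTBlock(dim=64, n_heads=4)
    x = torch.randn(2, 10, 64)                            # [B, L, D]
    y = torch.randn(2, 64)                                # per-batch AdaLN conditioning
    mask = torch.ones(10, 10, dtype=torch.bool).triu(1)   # True = disallowed
    assert blk(x, y, mask).shape == x.shape
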
class FinalLayer(nn.Module):
    def __init__(self, dim, out_channels):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(dim, out_channels, bias=True)
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        # zero-init the modulation and the output head so the denoiser starts
        # from a zero prediction
        nn.init.constant_(self.adaLN[-1].weight, 0)
        nn.init.constant_(self.adaLN[-1].bias, 0)
        nn.init.constant_(self.linear.weight, 0)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x, c):
        shift, scale = self.adaLN(c).chunk(2, dim=-1)
        x = modulate(self.norm(x), shift.unsqueeze(1), scale.unsqueeze(1))
        return self.linear(x)

# --------------- Temporal DiT (sequence-aware, causal) ---------------
class TemporalDiTAdaLN(nn.Module):
    """
    DiT-like denoiser that:
      - operates on [B, L, C]
      - uses **causal** attention (each position attends only to positions <= its own)
      - accepts (B, L) via set_sequence_layout for flatten↔sequence reshaping
      - returns all positions; streaming usually reads only the **last token**
    """
    def __init__(self, in_channels, model_channels, out_channels, z_channels, depth, n_heads=8,
                 mlp_ratio=4.0, grad_checkpointing=False):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.z_channels = z_channels
        self.depth = depth
        self.n_heads = n_heads
        self.grad_checkpointing = grad_checkpointing
        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.pos = SinPos1D(model_channels)
        self.blocks = nn.ModuleList([
            TemporalDiTBlock(model_channels, n_heads=n_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth)
        ])
        self.final = FinalLayer(model_channels, out_channels)
        self._seq_B = None
        self._seq_L = None
        self._init_weights()

    def _init_weights(self):
        def _xav(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        self.apply(_xav)
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
        # Re-zero the AdaLN heads and the output projection: the xavier pass
        # above would otherwise clobber the zero-inits from the constructors.
        for blk in self.blocks:
            nn.init.constant_(blk.adaLN[-1].weight, 0)
            nn.init.constant_(blk.adaLN[-1].bias, 0)
        nn.init.constant_(self.final.adaLN[-1].weight, 0)
        nn.init.constant_(self.final.adaLN[-1].bias, 0)
        nn.init.constant_(self.final.linear.weight, 0)
        nn.init.constant_(self.final.linear.bias, 0)

    def set_sequence_layout(self, B, L):
        self._seq_B = int(B)
        self._seq_L = int(L)

    def _flatten_to_seq(self, x_flat, c_flat):
        if self._seq_B is None or self._seq_L is None:
            B, L = x_flat.shape[0], 1  # fall back: each row is its own sequence
        else:
            B, L = self._seq_B, self._seq_L
        assert B * L == x_flat.shape[0], f"set_sequence_layout({B},{L}) mismatch"
        x = x_flat.view(B, L, -1)
        c = c_flat.view(B, L, -1)
        return x, c

    @staticmethod
    def _causal_mask(L, device):
        # Upper-triangular bool mask, True where attention is disallowed.
        # nn.MultiheadAttention takes a float mask (-inf at masked slots) or,
        # in recent PyTorch, a bool mask with True = masked; we pass the bool form.
        return torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
    def forward(self, x_flat, t, c_flat, cfg_scale: float = 1.0):
        # cfg_scale is accepted for signature parity with forward_with_cfg;
        # the plain forward pass does not use it.
        x, c = self._flatten_to_seq(x_flat, c_flat)  # [B, L, C], [B, L, Cz]
        B, L, _ = x.shape
        x = self.input_proj(x)
        pos = self.pos(L, x.device, x.dtype)
        x = x + pos.unsqueeze(0)
        # pool timestep/condition embeddings to one AdaLN vector per batch (as in DiT)
        t_emb = self.time_embed(t).view(B, L, -1).mean(dim=1)  # [B, D]
        c_emb = self.cond_embed(c).mean(dim=1)                 # [B, D]
        y = t_emb + c_emb
        causal_mask = self._causal_mask(L, x.device)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for blk in self.blocks:
                # non-reentrant checkpointing avoids PyTorch's deprecation warning
                x = checkpoint(blk, x, y, causal_mask, use_reentrant=False)
        else:
            for blk in self.blocks:
                x = blk(x, y, causal_mask)
        out = self.final(x, y)  # [B, L, out_channels]
        return out.view(B * L, -1)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        # Classifier-free guidance: the batch is assumed to be [cond | uncond]
        # halves sharing identical x, so one joint forward suffices.
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, c, cfg_scale=cfg_scale)
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([guided, guided], dim=0)
        return torch.cat([eps, rest], dim=1)
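
# Illustrative only (`_demo_denoiser` is hypothetical): the denoiser takes a
# flattened [B*L, C] batch with per-token timesteps and conditions, reshaping
# internally via the layout registered with set_sequence_layout.
def _demo_denoiser():
    B, L, C, Cz = 2, 8, 16, 32
    net = TemporalDiTAdaLN(in_channels=C, model_channels=64, out_channels=C,
                           z_channels=Cz, depth=2, n_heads=4)
    net.set_sequence_layout(B, L)
    x = torch.randn(B * L, C)
    t = torch.randint(0, 1000, (B * L,))
    z = torch.randn(B * L, Cz)
    assert net(x, t, z).shape == (B * L, C)
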
# --------------- Wrapper (same training API) + streaming helpers ---------------
class DiffLoss(nn.Module):
    """
    Diffusion loss with a **causal, streamable** temporal DiT denoiser.
    Training API unchanged; in addition:
      - set_sequence_layout(B, L)
      - sample_next_token(z_seq, temperature=1.0, cfg=1.0) -> [B, C] (last token)
    """
    def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps,
                 grad_checkpointing=False, learn_sigma=False, n_heads=8, mlp_ratio=4.0):
        super().__init__()
        self.in_channels = target_channels
        self.learn_sigma = learn_sigma
        self.net = TemporalDiTAdaLN(
            in_channels=target_channels,
            model_channels=width,
            out_channels=target_channels * 2 if learn_sigma else target_channels,
            z_channels=z_channels,
            depth=depth,
            n_heads=n_heads,
            mlp_ratio=mlp_ratio,
            grad_checkpointing=grad_checkpointing,
        )
        self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
        # num_sampling_steps is a timestep-respacing string (e.g. "100")
        self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")
        # cached (B, L) for flatten<->sequence reshaping
        self._B = None
        self._L = None

    # --- layout for flatten<->sequence ---
    def set_sequence_layout(self, B, L):
        self._B, self._L = int(B), int(L)
        self.net.set_sequence_layout(B, L)

    # --- training ---
    def forward(self, target, z, mask=None):
        t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
        loss_dict = self.train_diffusion.training_losses(self.net, target, t, dict(c=z))
        loss, pred_xstart = loss_dict["loss"], loss_dict["pred_xstart"]
        if mask is not None:
            # masked mean over valid positions (already a scalar; .mean() is a no-op)
            loss = (loss * mask).sum() / mask.sum()
        return loss.mean(), pred_xstart

    # --- full-sequence sampling (kept for compatibility) ---
    def sample(self, z, temperature=1.0, cfg=1.0):
        if cfg != 1.0:
            # CFG convention: z is [cond | uncond] halves, which must share noise
            noise = torch.randn(z.shape[0] // 2, self.in_channels, device=z.device)
            noise = torch.cat([noise, noise], dim=0)
            sample_fn = self.net.forward_with_cfg
            kwargs = dict(c=z, cfg_scale=cfg)
        else:
            noise = torch.randn(z.shape[0], self.in_channels, device=z.device)
            sample_fn = self.net.forward
            kwargs = dict(c=z)
        return self.gen_diffusion.p_sample_loop(
            sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=kwargs,
            progress=False, temperature=temperature
        )

    # --- STREAMING: sample only the **last token** of the current window ---
    @torch.no_grad()
    def sample_next_token(self, z_seq, temperature=1.0, cfg=1.0):
        """
        z_seq: [B, L, Cz] AR conditions for the current streaming window (history + 1 step).
        Call set_sequence_layout(B, L) first.
        Returns: next_token: [B, C] (the last position's denoised sample).
        Mechanism: denoise the **entire window** with causal attention and read the last index only.
        """
        assert self._B is not None and self._L is not None, "Call set_sequence_layout(B, L) first."
        B, L, Cz = z_seq.shape
        assert B == self._B and L == self._L, "z_seq shape must match set_sequence_layout."
        z_flat = z_seq.reshape(B * L, Cz)
        if cfg != 1.0:
            noise = torch.randn((B * L) // 2, self.in_channels, device=z_seq.device)
            noise = torch.cat([noise, noise], dim=0)
            sample_fn = self.net.forward_with_cfg
            kwargs = dict(c=z_flat, cfg_scale=cfg)
        else:
            noise = torch.randn(B * L, self.in_channels, device=z_seq.device)
            sample_fn = self.net.forward
            kwargs = dict(c=z_flat)
        x = self.gen_diffusion.p_sample_loop(
            sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=kwargs,
            progress=False, temperature=temperature
        )  # [B*L, C]
        x_seq = x.view(B, L, self.in_channels)
        return x_seq[:, -1, :]  # last token only
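
# A hedged end-to-end sketch, illustrative only: the sizes, widths, and "10"-step
# respacing below are assumptions, not values from this module, and running it
# depends on models.diffusion providing create_diffusion with the extensions the
# code above already relies on (training_losses returning "pred_xstart",
# p_sample_loop accepting temperature).
if __name__ == "__main__":
    B, L, C, Cz = 1, 8, 16, 32
    diff = DiffLoss(target_channels=C, z_channels=Cz, depth=2, width=64,
                    num_sampling_steps="10")
    diff.set_sequence_layout(B, L)
    # training: one loss evaluation over a flattened [B*L, C] batch
    target = torch.randn(B * L, C)
    z = torch.randn(B * L, Cz)
    loss, _ = diff(target, z)
    # streaming: denoise the current window causally, read only the last token
    z_seq = torch.randn(B, L, Cz)
    next_tok = diff.sample_next_token(z_seq, temperature=1.0, cfg=1.0)
    assert next_tok.shape == (B, C)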