# models/diffloss.py
"""Diffusion loss head built on a causal, sequence-aware DiT-style denoiser.

The module provides:

- small utilities (``modulate``, ``TimestepEmbedder``, ``SinPos1D``),
- ``TemporalDiTBlock`` / ``FinalLayer``: AdaLN-modulated transformer pieces,
- ``TemporalDiTAdaLN``: the denoiser, which takes flattened ``[B * L, C]``
  inputs and reshapes them to ``[B, L, C]`` internally,
- ``DiffLoss``: the training / sampling wrapper around the denoiser.
"""
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from models.diffusion import create_diffusion


# ---------------- utils ----------------
def modulate(x, shift, scale):
    """Apply AdaLN modulation: ``x * (1 + scale) + shift`` (broadcasting)."""
    return x * (1 + scale) + shift


class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps: sinusoidal features followed by a 2-layer MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return sinusoidal features of shape ``[len(t), dim]`` for 1-D ``t``.

        The first ``dim // 2`` columns are cosines and the next ``dim // 2``
        are sines; when ``dim`` is odd a single zero column is appended.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half
        ).to(t.device)
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad with one zero column so the output width is exactly `dim`.
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    def forward(self, t):
        """Map timesteps ``t`` of shape ``[N]`` to embeddings ``[N, hidden_size]``."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))


class SinPos1D(nn.Module):
    """Fixed (non-learned) 1-D sinusoidal positional encoding."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, L, device, dtype):
        """Return a ``[L, dim]`` table: sines in even columns, cosines in odd.

        The table is computed in float32 and cast to ``dtype`` at the end.
        """
        pe = torch.zeros(L, self.dim, device=device, dtype=torch.float32)
        pos = torch.arange(0, L, device=device, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(
            torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
            * (-math.log(10000.0) / self.dim)
        )
        angles = pos * div  # [L, ceil(dim / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For odd `dim` there is one fewer odd column than even columns, so
        # drop the surplus frequency instead of failing on a shape mismatch.
        # For even `dim` the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : self.dim // 2])
        return pe.to(dtype)


# --------------- DiT block (causal) ---------------
class TemporalDiTBlock(nn.Module):
    """
    Transformer block with AdaLN (DiT-style), **causal** self-attention over time.

    The AdaLN projection is zero-initialised here, so a freshly constructed
    block is an identity mapping (both residual gates are zero).
    """

    def __init__(self, dim, n_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)

        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden = int(dim * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.Linear(dim, 2 * hidden, bias=True),
            nn.SiLU(),
            nn.Linear(2 * hidden, dim, bias=True),
        )

        # AdaLN params: shift/scale/gate for attn and ffn
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
        nn.init.constant_(self.adaLN[-1].weight, 0)
        nn.init.constant_(self.adaLN[-1].bias, 0)

    def forward(self, x, y, causal_mask):
        """
        x: [B, L, D], y: [B, D], causal_mask: [L, L] bool, True = mask (disallow)

        Returns the updated sequence, same shape as ``x``.
        """
        s1, sc1, g1, s2, sc2, g2 = self.adaLN(y).chunk(6, dim=-1)  # [B, D] each

        # attn (causal)
        h = modulate(self.norm1(x), s1.unsqueeze(1), sc1.unsqueeze(1))
        # A boolean attn_mask of shape [L, L] marks with True the positions
        # that may NOT be attended to.
        h, _ = self.attn(h, h, h, attn_mask=causal_mask, need_weights=False)
        x = x + g1.unsqueeze(1) * h

        # ffn
        h2 = modulate(self.norm2(x), s2.unsqueeze(1), sc2.unsqueeze(1))
        h2 = self.ffn(h2)
        x = x + g2.unsqueeze(1) * h2
        return x


class FinalLayer(nn.Module):
    """AdaLN-modulated output projection; zero-initialised at construction."""

    def __init__(self, dim, out_channels):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(dim, out_channels, bias=True)
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        nn.init.constant_(self.adaLN[-1].weight, 0)
        nn.init.constant_(self.adaLN[-1].bias, 0)
        nn.init.constant_(self.linear.weight, 0)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x, c):
        """Project ``x`` [B, L, D] to [B, L, out_channels], modulated by ``c`` [B, D]."""
        shift, scale = self.adaLN(c).chunk(2, dim=-1)
        x = modulate(self.norm(x), shift.unsqueeze(1), scale.unsqueeze(1))
        return self.linear(x)


# --------------- Temporal DiT (sequence-aware, causal) ---------------
class TemporalDiTAdaLN(nn.Module):
    """
    DiT-like denoiser that:
      - operates on [B, L, C]
      - uses **causal** attention (each position sees only <= t)
      - accepts (B, L) via set_sequence_layout for flatten↔sequence reshaping
      - returns all positions but we usually **read only the last token** for streaming
    """

    def __init__(self, in_channels, model_channels, out_channels, z_channels,
                 depth, n_heads=8, mlp_ratio=4.0, grad_checkpointing=False):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.z_channels = z_channels
        self.depth = depth
        self.n_heads = n_heads
        self.grad_checkpointing = grad_checkpointing

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.pos = SinPos1D(model_channels)

        self.blocks = nn.ModuleList([
            TemporalDiTBlock(model_channels, n_heads=n_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth)
        ])
        self.final = FinalLayer(model_channels, out_channels)

        # Sequence layout used to reshape flat [B * L, C] inputs; None means
        # "treat every row as its own length-1 sequence".
        self._seq_B = None
        self._seq_L = None

        self._init_weights()

    def _init_weights(self):
        """Xavier-init all linear layers, then restore the zero-initialised ones."""
        def _xav(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        self.apply(_xav)
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        # `self.apply(_xav)` above also visits the AdaLN projections and the
        # output layer, overwriting the zero init their constructors applied.
        # Re-zero them so the zero init set up in TemporalDiTBlock / FinalLayer
        # actually survives model construction.
        for blk in self.blocks:
            nn.init.constant_(blk.adaLN[-1].weight, 0)
            nn.init.constant_(blk.adaLN[-1].bias, 0)
        nn.init.constant_(self.final.adaLN[-1].weight, 0)
        nn.init.constant_(self.final.adaLN[-1].bias, 0)
        nn.init.constant_(self.final.linear.weight, 0)
        nn.init.constant_(self.final.linear.bias, 0)

    def set_sequence_layout(self, B, L):
        """Record the (batch, length) layout used to un-flatten later inputs."""
        self._seq_B = int(B)
        self._seq_L = int(L)

    def _flatten_to_seq(self, x_flat, c_flat):
        """Reshape flat ``[B * L, *]`` tensors to ``[B, L, *]``.

        Without a configured layout every row becomes a length-1 sequence.

        Raises:
            AssertionError: if the configured layout does not match the
                number of rows in ``x_flat``.
        """
        if self._seq_B is None or self._seq_L is None:
            B, L = x_flat.shape[0], 1
        else:
            B, L = self._seq_B, self._seq_L
        assert B * L == x_flat.shape[0], f"set_sequence_layout({B},{L}) mismatch"
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x_flat.reshape(B, L, -1)
        c = c_flat.reshape(B, L, -1)
        return x, c

    @staticmethod
    def _causal_mask(L, device):
        """Return an ``[L, L]`` bool mask that is True strictly above the diagonal.

        True marks disallowed (future) positions; the mask is passed as-is to
        ``nn.MultiheadAttention`` as a boolean ``attn_mask``.
        """
        return torch.ones(L, L, device=device, dtype=torch.bool).triu(1)

    def forward(self, x_flat, t, c_flat, cfg_scale: float = 1.0):
        """Run the denoiser on flattened inputs.

        Args:
            x_flat: ``[B * L, in_channels]`` noisy inputs.
            t: timesteps with ``B * L`` entries (one per row of ``x_flat``).
            c_flat: ``[B * L, z_channels]`` conditions.
            cfg_scale: accepted for signature parity with
                ``forward_with_cfg``; not used by this method.

        Returns:
            ``[B * L, out_channels]`` predictions.
        """
        x, c = self._flatten_to_seq(x_flat, c_flat)  # [B, L, C], [B, L, Cz]
        B, L, _ = x.shape

        x = self.input_proj(x)
        pos = self.pos(L, x.device, x.dtype)
        x = x + pos.unsqueeze(0)

        # pool cond to a single AdaLN vector per batch (like DiT)
        # NOTE(review): both means run over the whole L axis, so the AdaLN
        # vector used at position i also depends on the conditions (and
        # timesteps) of positions > i, which the causal attention mask does
        # not cover. DiffLoss.forward also draws one timestep per flat row,
        # so for L > 1 distinct per-token timesteps are averaged here.
        # Confirm this is intended before relying on strict causality.
        t_emb = self.time_embed(t).view(B, L, -1).mean(dim=1)  # [B, D]
        c_emb = self.cond_embed(c).mean(dim=1)                 # [B, D]
        y = t_emb + c_emb

        causal_mask = self._causal_mask(L, x.device)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for blk in self.blocks:
                x = checkpoint(blk, x, y, causal_mask)
        else:
            for blk in self.blocks:
                x = blk(x, y, causal_mask)

        out = self.final(x, y)  # [B, L, out_channels]
        return out.view(B * L, -1)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        """Classifier-free-guided forward pass.

        The first half of ``x`` is duplicated so both halves of the batch
        share the same input; the first ``in_channels`` output columns of the
        two halves are then combined as
        ``uncond + cfg_scale * (cond - uncond)`` (first half = cond, second
        half = uncond). Remaining output columns are passed through as-is.
        """
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, c, cfg_scale=cfg_scale)
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([guided, guided], dim=0)
        return torch.cat([eps, rest], dim=1)


# --------------- Wrapper (same training API) + streaming helpers ---------------
class DiffLoss(nn.Module):
    """
    Diffusion loss with **causal, streamable** temporal DiT denoiser.
    Training API unchanged; plus:
      - set_sequence_layout(B, L)
      - sample_next_token(z_seq, temperature=1.0, cfg=1.0) -> [B, C]  (last token)
    """

    def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps,
                 grad_checkpointing=False, learn_sigma=False, n_heads=8, mlp_ratio=4.0):
        super().__init__()
        self.in_channels = target_channels
        self.learn_sigma = learn_sigma

        self.net = TemporalDiTAdaLN(
            in_channels=target_channels,
            model_channels=width,
            # With learn_sigma the network emits twice as many output channels.
            out_channels=target_channels * 2 if learn_sigma else target_channels,
            z_channels=z_channels,
            depth=depth,
            n_heads=n_heads,
            mlp_ratio=mlp_ratio,
            grad_checkpointing=grad_checkpointing
        )

        self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
        self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")

        # cached (B,L) for flatten↔sequence
        self._B = None
        self._L = None

    # --- layout for flatten<->sequence ---
    def set_sequence_layout(self, B, L):
        """Record the (batch, length) layout here and on the wrapped denoiser."""
        self._B, self._L = int(B), int(L)
        self.net.set_sequence_layout(B, L)

    # --- training ---
    def forward(self, target, z, mask=None):
        """Compute the diffusion training loss.

        Args:
            target: clean targets; one timestep is drawn per row
                (``target.shape[0]``).
            z: conditions, forwarded to the denoiser as ``c``.
            mask: optional per-row weights; when given the loss is the
                mask-weighted mean. An all-zero mask yields a loss of 0
                rather than NaN.

        Returns:
            ``(loss, pred_xstart)`` where ``loss`` is a scalar and
            ``pred_xstart`` is taken from the diffusion's loss dict.
        """
        t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
        loss_dict = self.train_diffusion.training_losses(self.net, target, t, dict(c=z))
        loss, pred_xstart = loss_dict["loss"], loss_dict["pred_xstart"]
        if mask is not None:
            denom = mask.sum()
            # Guard only the exact-zero case (0 / 0 -> NaN); any non-zero
            # denominator is used unchanged.
            denom = torch.where(denom == 0, torch.ones_like(denom), denom)
            loss = (loss * mask).sum() / denom
        return loss.mean(), pred_xstart

    # --- full sequence sampling (kept for compatibility) ---
    def sample(self, z, temperature=1.0, cfg=1.0):
        """Sample one output row per row of ``z``.

        With ``cfg != 1.0`` the noise for the first half of the batch is
        duplicated into the second half and the guided forward pass is used.
        """
        if cfg != 1.0:
            noise = torch.randn(z.shape[0] // 2, self.in_channels, device=z.device)
            noise = torch.cat([noise, noise], dim=0)
            sample_fn = self.net.forward_with_cfg
            kwargs = dict(c=z, cfg_scale=cfg)
        else:
            noise = torch.randn(z.shape[0], self.in_channels, device=z.device)
            sample_fn = self.net.forward
            kwargs = dict(c=z)

        return self.gen_diffusion.p_sample_loop(
            sample_fn, noise.shape, noise, clip_denoised=False,
            model_kwargs=kwargs, progress=False, temperature=temperature
        )

    # --- STREAMING: sample only the **last token** of current window ---
    @torch.no_grad()
    def sample_next_token(self, z_seq, temperature=1.0, cfg=1.0):
        """
        z_seq: [B, L, Cz]  AR conditions for the current streaming window (history + 1 step).
        Call set_sequence_layout(B, L) first.
        Returns: next_token: [B, C] (the last position’s denoised sample).
        Mechanism: denoise **entire window** with causal attention and read the last index only.

        Raises:
            AssertionError: if no layout was set or ``z_seq`` does not match it.
            ValueError: if ``cfg != 1.0`` and ``B`` is odd.
        """
        assert self._B is not None and self._L is not None, "Call set_sequence_layout(B, L) first."
        B, L, Cz = z_seq.shape
        assert B == self._B and L == self._L, "z_seq shape must match set_sequence_layout."

        z_flat = z_seq.reshape(B * L, Cz)

        if cfg != 1.0:
            # The guided path splits the flat batch into two halves; with an
            # odd B (and even L) that split would land in the middle of a
            # sequence and silently mix rows of different sequences.
            if B % 2:
                raise ValueError(
                    f"cfg != 1.0 requires an even batch size (cond + uncond halves), got B={B}"
                )
            noise = torch.randn((B * L) // 2, self.in_channels, device=z_seq.device)
            noise = torch.cat([noise, noise], dim=0)
            sample_fn = self.net.forward_with_cfg
            kwargs = dict(c=z_flat, cfg_scale=cfg)
        else:
            noise = torch.randn(B * L, self.in_channels, device=z_seq.device)
            sample_fn = self.net.forward
            kwargs = dict(c=z_flat)

        x = self.gen_diffusion.p_sample_loop(
            sample_fn, noise.shape, noise, clip_denoised=False,
            model_kwargs=kwargs, progress=False, temperature=temperature
        )  # [B*L, C]

        x_seq = x.view(B, L, self.in_channels)
        return x_seq[:, -1, :]  # last token only