"""
References:
- DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
- Diffusion Forcing: https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/unet3d.py
- Latte: https://github.com/Vchitect/Latte/blob/main/models/latte.py
"""
from typing import Optional, Literal
import torch
from torch import nn
from .rotary_embedding_torch import RotaryEmbedding
from einops import rearrange
from .attention import SpatialAxialAttention, TemporalAxialAttention
from timm.models.vision_transformer import Mlp
from timm.layers.helpers import to_2tuple
import math
from collections import namedtuple
from typing import Optional, Callable
def modulate(x, shift, scale):
    """Apply adaLN-style affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` are first tiled along dim 0 by the integer ratio
    ``x.shape[0] // <their leading size>``; then singleton axes are inserted at
    position -2 until ``shift`` has as many dims as ``x``.  This lets a
    conditioning tensor such as ``(B, T, D)`` broadcast over ``(B, T, H, W, D)``.
    """
    trailing_ones = [1] * (len(shift.shape) - 1)
    shift = shift.repeat(x.shape[0] // shift.shape[0], *trailing_ones)
    scale = scale.repeat(x.shape[0] // scale.shape[0], *trailing_ones)
    # Both tensors are unsqueezed in lockstep, driven by shift's rank.
    for _ in range(x.dim() - shift.dim()):
        shift = shift.unsqueeze(-2)
        scale = scale.unsqueeze(-2)
    return x * (1 + scale) + shift
def gate(x, g):
    """Multiply ``x`` by gate ``g`` after aligning ``g`` for broadcasting.

    ``g`` is tiled along dim 0 by ``x.shape[0] // g.shape[0]`` and then gets
    singleton axes inserted at position -2 until its rank matches ``x``.
    """
    g = g.repeat(x.shape[0] // g.shape[0], *([1] * (g.dim() - 1)))
    for _ in range(x.dim() - g.dim()):
        g = g.unsqueeze(-2)
    return g * x
class PatchEmbed(nn.Module):
    """Split a 2D feature map into non-overlapping patches and project each to ``embed_dim``.

    The projection is a strided ``Conv2d`` whose kernel and stride equal the
    patch size.  The output is ``(B, H', W', embed_dim)`` or, when ``flatten``
    is true, ``(B, H' * W', embed_dim)``.
    """

    def __init__(
        self,
        img_height=256,
        img_width=256,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        patch_hw = to_2tuple(patch_size)
        self.img_size = (img_height, img_width)
        self.patch_size = patch_hw
        self.grid_size = (img_height // patch_hw[0], img_width // patch_hw[1])
        grid_h, grid_w = self.grid_size
        self.num_patches = grid_h * grid_w
        self.flatten = flatten
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x, random_sample=False):
        """Embed ``x`` of shape ``(B, C, H, W)``.

        The spatial size must match the configured image size unless
        ``random_sample`` is true.
        """
        _, _, height, width = x.shape
        expected_h, expected_w = self.img_size
        assert random_sample or (height == expected_h and width == expected_w), f"Input image size ({height}*{width}) doesn't match model ({expected_h}*{expected_w})."
        pattern = "B C H W -> B (H W) C" if self.flatten else "B C H W -> B H W C"
        return self.norm(rearrange(self.proj(x), pattern))
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A fixed sinusoidal feature of width ``frequency_embedding_size`` (see
    :meth:`timestep_embedding`) is projected to ``hidden_size`` by a two-layer
    MLP with a SiLU in between.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, freq_type='time_step'):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),  # hidden_size is diffusion model hidden size
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        # Frequency schedule forwarded to timestep_embedding on every call.
        self.freq_type = freq_type

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000, freq_type='time_step'):
        """
        Create sinusoidal embeddings of scalar inputs.

        :param t: a 1-D Tensor of N values, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings;
                           only used when ``freq_type == 'time_step'``.
        :param freq_type: frequency schedule, one of
            - ``'time_step'``: geometric frequencies
              ``exp(-log(max_period) * i / half)`` for ``i`` in ``[0, half)``;
            - ``'spatial'``: linear frequencies ``pi * [1, 2, ..., half]``
              (intended for inputs roughly in (-5, 5));
            - ``'angle'``: linear frequencies ``(pi / 180) * [1, 2, ..., half]``,
              i.e. inputs in degrees (roughly 0-360).
        :return: an (N, dim) Tensor laid out as ``[cos(args), sin(args)]``, with
                 one zero column appended when ``dim`` is odd.
        :raises ValueError: if ``freq_type`` is not one of the values above.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        if freq_type == 'time_step':
            freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
        elif freq_type == 'spatial':  # ~(-5 5)
            freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi
        elif freq_type == 'angle':  # 0-360
            freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi / 180
        else:
            # Reject unknown schedules explicitly; otherwise `freqs` would be
            # unbound below and fail with a confusing UnboundLocalError.
            raise ValueError(
                "freq_type must be one of ('time_step', 'spatial', 'angle'), "
                f"got {freq_type!r}"
            )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad one zero column so the output width is exactly `dim`.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """Embed a 1-D tensor ``t`` of N scalars into an ``(N, hidden_size)`` tensor."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size, freq_type=self.freq_type)
        t_emb = self.mlp(t_freq)
        return t_emb
class FinalLayer(nn.Module):
    """Output head of the DiT.

    Applies an affine-free LayerNorm, an adaLN shift/scale predicted from the
    conditioning ``c``, and a linear projection to
    ``patch_size * patch_size * out_channels`` values per token.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        values_per_patch = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, values_per_patch, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Project tokens ``x`` to per-patch outputs, modulated by ``c``."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        normalized = self.norm_final(x)
        return self.linear(modulate(normalized, shift, scale))
# Memory-stream type ids; MEMORY_TYPE_NAMES is indexed by these integer ids.
MEMORY_TYPE_NAMES = ("anchor", "dynamic", "revisit")
MEMORY_TYPE_ANCHOR, MEMORY_TYPE_DYNAMIC, MEMORY_TYPE_REVISIT = 0, 1, 2
class MemoryTokenCrossAttention(nn.Module):
    """adaLN-modulated cross-attention adapter from a latent grid to memory tokens.

    Queries come from ``x`` of shape ``(B, T, H, W, D)`` after an affine-free
    LayerNorm and adaLN shift/scale.  Keys/values are memory tokens, which are
    either shared across frames (rank 3, ``(B, M, D)``) or given per frame
    (rank 4, ``(B, T, M, D)``).

    Optional per-token type ids do two things:
    - they select a learned additive embedding and multiplicative scale for each
      token;
    - they select a conditioning-dependent sigmoid gate, which biases the
      attention logits by ``log(gate)``.

    The attention output and a following MLP are added to ``residual_base`` as
    gated residuals.  ``reset_identity_init`` zeroes the adaLN projection, so a
    freshly initialised adapter returns its residual base unchanged.

    Each forward pass overwrites scalar diagnostics on the instance:
    ``last_gate_mean``, ``last_delta_ratio``, ``last_valid_fraction``,
    ``last_type_gate_mean`` and ``last_type_gate_<name>_mean``.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, num_memory_types=3):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.num_heads = num_heads
        self.num_memory_types = num_memory_types
        # Query norm has no affine params: adaLN supplies shift/scale instead.
        self.norm_q = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm_mem = nn.LayerNorm(hidden_size, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm_mlp = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )
        # Six chunks: shift/scale/gate for the attention branch, then for the MLP branch.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        # Per-type additive embedding and multiplicative scale applied to memory tokens.
        self.memory_type_embed = nn.Embedding(num_memory_types, hidden_size)
        self.memory_type_scale = nn.Parameter(torch.ones(num_memory_types, hidden_size))
        # One logit per memory type from the conditioning; sigmoid is applied in _type_stage_gate.
        self.memory_type_gate = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, num_memory_types, bias=True))
        # Diagnostics, overwritten on every forward pass.
        self.last_gate_mean = None
        self.last_delta_ratio = None
        self.last_valid_fraction = None
        self.last_type_gate_mean = None
        for type_name in MEMORY_TYPE_NAMES[:num_memory_types]:
            setattr(self, f"last_type_gate_{type_name}_mean", None)
        nn.init.normal_(self.memory_type_embed.weight, std=0.02)
        self.reset_identity_init()

    def reset_identity_init(self):
        """Zero the adaLN and type-gate projections.

        With all adaLN outputs at zero, both residual gates are zero and the
        adapter adds nothing.  Every type gate starts at sigmoid(0) = 0.5.
        """
        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.memory_type_gate[-1].weight, 0)
        nn.init.constant_(self.memory_type_gate[-1].bias, 0)

    def _attend(self, query, memory_tokens, memory_token_mask=None, memory_token_gate=None):
        """Run multi-head attention from ``query`` to ``memory_tokens``.

        Args:
            query: ``(N, L, D)`` queries.
            memory_tokens: ``(N, M, D)`` keys/values.
            memory_token_mask: optional ``(N, M)`` mask; falsy entries are ignored.
            memory_token_gate: optional ``(N, M)`` gate.  Entries ``<= 0`` are
                ignored; the remaining entries add ``log(gate)`` to the logits.

        Returns:
            ``(out, valid_rows)``.  ``out`` has the shape of ``query`` and is zero
            for rows with no usable memory token.  ``valid_rows`` is a bool
            ``(N,)`` tensor, or ``None`` when neither mask nor gate was given.

        Raises:
            ValueError: if ``memory_token_gate`` does not have shape ``(N, M)``.
        """
        if memory_token_mask is None and memory_token_gate is None:
            out, _ = self.attn(query, memory_tokens, memory_tokens, need_weights=False)
            return out, None
        if memory_token_mask is None:
            memory_token_mask = torch.ones(
                memory_tokens.shape[:2],
                device=memory_tokens.device,
                dtype=torch.bool,
            )
        else:
            memory_token_mask = memory_token_mask.bool()
        gate_tensor = None
        if memory_token_gate is not None:
            if tuple(memory_token_gate.shape) != tuple(memory_tokens.shape[:2]):
                raise ValueError(
                    f"memory_token_gate must have shape {tuple(memory_tokens.shape[:2])}, "
                    f"got {tuple(memory_token_gate.shape)}"
                )
            gate_tensor = memory_token_gate.to(device=memory_tokens.device, dtype=query.dtype)
            memory_token_mask = memory_token_mask & (gate_tensor > 0)
        valid_rows = memory_token_mask.any(dim=1)
        out = torch.zeros_like(query)
        # Rows whose keys are all masked are skipped: a softmax over an all -inf
        # row would produce NaN.  Their output stays zero.
        if valid_rows.any():
            attn_mask = None
            key_padding_mask = ~memory_token_mask[valid_rows]
            if gate_tensor is not None:
                # Adding log(gate) to the logits multiplies each softmax weight by
                # its gate before renormalisation.  A 3-D attn_mask must be
                # (N * num_heads, L, M) with each sample's heads contiguous, hence
                # repeat_interleave over heads.
                gate_bias = torch.log(gate_tensor[valid_rows].clamp_min(1.0e-6))
                gate_bias = gate_bias[:, None, :].expand(-1, query.shape[1], -1)
                attn_mask = gate_bias.repeat_interleave(self.num_heads, dim=0)
                # Use an additive float padding mask so its type matches attn_mask.
                float_padding_mask = torch.zeros_like(gate_tensor[valid_rows], dtype=query.dtype)
                key_padding_mask = float_padding_mask.masked_fill(key_padding_mask, float("-inf"))
            attended, _ = self.attn(
                query[valid_rows],
                memory_tokens[valid_rows],
                memory_tokens[valid_rows],
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
                need_weights=False,
            )
            out[valid_rows] = attended.to(out.dtype)
        return out, valid_rows

    def _apply_memory_type(self, memory_tokens, memory_type_ids):
        """Return ``memory_tokens * scale[type] + embed[type]``, or the tokens unchanged when no ids are given."""
        if memory_type_ids is None:
            return memory_tokens
        memory_type_ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
        type_embed = self.memory_type_embed(memory_type_ids).to(memory_tokens.dtype)
        type_scale = self.memory_type_scale[memory_type_ids].to(memory_tokens.dtype)
        # Left-pad with singleton axes so (M, D) type params broadcast over leading dims.
        while type_embed.dim() < memory_tokens.dim():
            type_embed = type_embed.unsqueeze(0)
            type_scale = type_scale.unsqueeze(0)
        return memory_tokens * type_scale + type_embed

    def _store_type_gate_diagnostics(self, stage_gate):
        """Record the overall and per-type mean of the sigmoid type gates."""
        with torch.no_grad():
            detached = stage_gate.detach().float()
            self.last_type_gate_mean = detached.mean()
            for type_idx, type_name in enumerate(MEMORY_TYPE_NAMES[: self.num_memory_types]):
                setattr(self, f"last_type_gate_{type_name}_mean", detached[..., type_idx].mean())

    def _type_stage_gate(self, c, memory_tokens, memory_type_ids):
        """Return a per-token gate of shape ``memory_tokens.shape[:-1]``, or ``None`` without type ids.

        ``sigmoid(memory_type_gate(c))`` gives one value per memory type, and each
        token picks the value of its type.  For rank-3 (frame-shared) tokens the
        picked values are averaged over dim 1.  NOTE(review): this assumes ``c``
        is ``(B, T, D)``, so the type-gate tensor is ``(B, T, num_types)`` —
        confirm against callers.

        Raises:
            ValueError: on unsupported ``memory_type_ids`` or ``memory_tokens`` shapes.
        """
        if memory_type_ids is None:
            return None
        memory_type_ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
        stage_gate = torch.sigmoid(self.memory_type_gate(c)).to(memory_tokens.dtype)
        self._store_type_gate_diagnostics(stage_gate)
        if memory_tokens.dim() == 4:
            batch_size, num_frames, num_tokens = memory_tokens.shape[:3]
            if memory_type_ids.dim() == 1:
                gather_ids = memory_type_ids.view(1, 1, num_tokens).expand(batch_size, num_frames, num_tokens)
            elif tuple(memory_type_ids.shape) == (batch_size, num_frames, num_tokens):
                gather_ids = memory_type_ids
            else:
                raise ValueError(
                    "rank-4 memory_type_ids must have shape (M,) or (B,T,M), "
                    f"got {tuple(memory_type_ids.shape)}"
                )
            return torch.gather(stage_gate, dim=-1, index=gather_ids)
        if memory_tokens.dim() == 3:
            batch_size, num_tokens = memory_tokens.shape[:2]
            if memory_type_ids.dim() != 1:
                raise ValueError("rank-3 memory_type_ids must have shape (M,)")
            gather_ids = memory_type_ids.view(1, 1, num_tokens).expand(batch_size, stage_gate.shape[1], num_tokens)
            return torch.gather(stage_gate, dim=-1, index=gather_ids).mean(dim=1)
        raise ValueError(f"memory_tokens must be rank 3 or 4, got rank {memory_tokens.dim()}")

    def _combine_memory_gate(self, memory_tokens, memory_token_gate, type_stage_gate):
        """Multiply the caller's per-token gate with the type gate; either may be ``None``.

        Raises:
            ValueError: if ``memory_token_gate`` is not shaped like ``memory_tokens.shape[:-1]``.
        """
        combined_gate = type_stage_gate
        if memory_token_gate is not None:
            if tuple(memory_token_gate.shape) != tuple(memory_tokens.shape[:-1]):
                raise ValueError(
                    f"memory_token_gate must have shape {tuple(memory_tokens.shape[:-1])}, "
                    f"got {tuple(memory_token_gate.shape)}"
                )
            stream_gate = memory_token_gate.to(device=memory_tokens.device, dtype=memory_tokens.dtype)
            combined_gate = stream_gate if combined_gate is None else combined_gate * stream_gate
        return combined_gate

    def _valid_mask(self, valid_rows, batch_size, num_frames, dtype, device):
        """Reshape a ``(B,)`` or ``(B*T,)`` valid-row vector so it broadcasts over ``(B, T, H, W, D)``."""
        if valid_rows is None:
            return None
        valid_rows = valid_rows.to(device=device, dtype=dtype)
        if valid_rows.numel() == batch_size:
            return valid_rows.view(batch_size, 1, 1, 1, 1)
        if valid_rows.numel() == batch_size * num_frames:
            return rearrange(valid_rows, "(b t) -> b t", b=batch_size, t=num_frames)[:, :, None, None, None]
        raise ValueError(f"valid_rows has incompatible shape: {tuple(valid_rows.shape)}")

    def _gate_valid_mask(self, valid_rows, batch_size, num_frames, dtype, device):
        """Reshape a ``(B,)`` or ``(B*T,)`` valid-row vector to rank 3 (``(B,1,1)`` or ``(B,T,1)``)."""
        if valid_rows is None:
            return None
        valid_rows = valid_rows.to(device=device, dtype=dtype)
        if valid_rows.numel() == batch_size:
            return valid_rows.view(batch_size, 1, 1)
        if valid_rows.numel() == batch_size * num_frames:
            return rearrange(valid_rows, "(b t) -> b t", b=batch_size, t=num_frames)[:, :, None]
        raise ValueError(f"valid_rows has incompatible shape: {tuple(valid_rows.shape)}")

    def _residual_gate(self, residual_gate, batch_size, num_frames, dtype, device):
        """Normalise ``residual_gate`` to a rank-5 tensor that broadcasts over ``(B, T, H, W, D)``.

        The gate may be:
        - a Python number or 0-d tensor;
        - a ``(B,)`` or ``(B*T,)`` vector;
        - a ``(B, T)`` tensor;
        - a rank-3 tensor starting with ``(B, T)``.

        Higher ranks are right-padded with singleton axes up to rank 5.
        """
        if residual_gate is None:
            return None
        if not torch.is_tensor(residual_gate):
            return torch.tensor(float(residual_gate), dtype=dtype, device=device).view(1, 1, 1, 1, 1)
        gate_tensor = residual_gate.to(device=device, dtype=dtype)
        if gate_tensor.dim() == 0:
            gate_tensor = gate_tensor.view(1, 1, 1, 1, 1)
        elif gate_tensor.dim() == 1:
            if gate_tensor.numel() == batch_size:
                gate_tensor = gate_tensor.view(batch_size, 1, 1, 1, 1)
            elif gate_tensor.numel() == batch_size * num_frames:
                gate_tensor = rearrange(gate_tensor, "(b t) -> b t", b=batch_size, t=num_frames)[:, :, None, None, None]
            else:
                raise ValueError(f"residual_gate has incompatible shape: {tuple(gate_tensor.shape)}")
        elif gate_tensor.dim() == 2:
            if tuple(gate_tensor.shape) != (batch_size, num_frames):
                raise ValueError(f"residual_gate must have shape (B,T), got {tuple(gate_tensor.shape)}")
            gate_tensor = gate_tensor[:, :, None, None, None]
        elif gate_tensor.dim() == 3:
            if tuple(gate_tensor.shape[:2]) != (batch_size, num_frames):
                raise ValueError(f"residual_gate must start with (B,T), got {tuple(gate_tensor.shape)}")
            gate_tensor = gate_tensor[:, :, :, None, None]
        else:
            while gate_tensor.dim() < 5:
                gate_tensor = gate_tensor.unsqueeze(-1)
        return gate_tensor

    def _store_diagnostics(self, output, base, gate_msa, gate_mlp, valid_rows):
        """Record scalar diagnostics from the latest forward pass (without autograd).

        Sets three attributes:
        - ``last_gate_mean``: the mean absolute adaLN gate over both branches,
          restricted to valid rows when ``valid_rows`` is given;
        - ``last_valid_fraction``;
        - ``last_delta_ratio``: ``||output - base|| / (||base|| + 1e-6)``.
        """
        with torch.no_grad():
            batch_size, num_frames = base.shape[:2]
            gate_values = torch.cat(
                [gate_msa.detach().float().abs(), gate_mlp.detach().float().abs()],
                dim=-1,
            )
            gate_mask = self._gate_valid_mask(
                valid_rows,
                batch_size,
                num_frames,
                dtype=gate_values.dtype,
                device=gate_values.device,
            )
            if gate_mask is not None:
                gate_values = gate_values * gate_mask
                self.last_valid_fraction = valid_rows.detach().float().mean()
                # The divisor must count every gate entry the mask keeps.  A
                # per-sample (B, 1, 1) mask covers T frames x channels, so its
                # broadcast count is used rather than sum(mask) * channels.
                valid_count = gate_mask.expand_as(gate_values).sum().clamp_min(1.0)
                self.last_gate_mean = gate_values.sum() / valid_count
            else:
                self.last_valid_fraction = base.detach().new_tensor(1.0, dtype=torch.float32)
                self.last_gate_mean = gate_values.mean()
            delta_norm = (output.detach().float() - base.detach().float()).norm()
            base_norm = base.detach().float().norm()
            self.last_delta_ratio = delta_norm / (base_norm + 1e-6)

    def forward(
        self,
        x,
        c,
        memory_tokens,
        memory_token_mask=None,
        residual_base=None,
        return_delta=False,
        residual_gate=None,
        memory_type_ids=None,
        memory_token_gate=None,
    ):
        """Cross-attend from ``x`` to memory tokens and add the gated result to ``residual_base``.

        Args:
            x: latent grid ``(B, T, H, W, D)``; it supplies the queries.
            c: conditioning for the adaLN and type-gate projections.
            memory_tokens: ``(B, M, D)`` (shared by all frames; queries from all
                frames attend jointly) or ``(B, T, M, D)`` (per frame).
            memory_token_mask: optional mask shaped ``memory_tokens.shape[:-1]``;
                falsy entries are ignored.
            residual_base: tensor the deltas are added to; defaults to ``x``.
            return_delta: if true, return only ``attn_delta + mlp_delta``.
            residual_gate: optional extra multiplier on both deltas (see
                ``_residual_gate`` for accepted shapes).
            memory_type_ids: optional per-token type ids, ``(M,)`` or, for
                rank-4 tokens, ``(B, T, M)``.
            memory_token_gate: optional per-token gate shaped
                ``memory_tokens.shape[:-1]``.

        Returns:
            ``residual_base + attn_delta + mlp_delta``, or only the summed
            deltas when ``return_delta`` is true.  Rows that have no usable
            memory token contribute zero delta.

        Raises:
            ValueError: on mismatched shapes or unsupported tensor ranks.
        """
        B, T, H, W, D = x.shape
        if residual_base is None:
            residual_base = x
        m_shift_msa, m_scale_msa, m_gate_msa, m_shift_mlp, m_scale_mlp, m_gate_mlp = (
            self.adaLN_modulation(c).chunk(6, dim=-1)
        )
        query_source = modulate(self.norm_q(x), m_shift_msa, m_scale_msa)
        type_stage_gate = self._type_stage_gate(c, memory_tokens, memory_type_ids)
        effective_token_gate = self._combine_memory_gate(memory_tokens, memory_token_gate, type_stage_gate)
        if memory_tokens.dim() == 3:
            # Frame-shared memory: all T*H*W positions of a sample form one query sequence.
            query = rearrange(query_source, "b t h w d -> b (t h w) d")
            memory_tokens = self._apply_memory_type(self.norm_mem(memory_tokens), memory_type_ids)
            valid_rows = None
            if memory_token_mask is not None:
                if tuple(memory_token_mask.shape) != tuple(memory_tokens.shape[:2]):
                    raise ValueError(
                        f"legacy memory mask must have shape {tuple(memory_tokens.shape[:2])}, "
                        f"got {tuple(memory_token_mask.shape)}"
                    )
            out, valid_rows = self._attend(
                query,
                memory_tokens,
                memory_token_mask=memory_token_mask,
                memory_token_gate=effective_token_gate,
            )
            out = rearrange(out, "b (t h w) d -> b t h w d", t=T, h=H, w=W)
        elif memory_tokens.dim() == 4:
            if tuple(memory_tokens.shape[:2]) != (B, T):
                raise ValueError(
                    f"per-frame memory tokens must have shape (B, T, M, D), got {tuple(memory_tokens.shape)}"
                )
            # Per-frame memory: each (sample, frame) pair is an independent attention batch row.
            query = rearrange(query_source, "b t h w d -> (b t) (h w) d")
            memory_tokens = self._apply_memory_type(self.norm_mem(memory_tokens), memory_type_ids)
            memory_tokens = rearrange(memory_tokens, "b t m d -> (b t) m d")
            if effective_token_gate is not None:
                effective_token_gate = rearrange(effective_token_gate, "b t m -> (b t) m")
            valid_rows = None
            if memory_token_mask is not None:
                expected_mask_shape = (B, T, memory_tokens.shape[1])
                if tuple(memory_token_mask.shape) != expected_mask_shape:
                    raise ValueError(
                        f"per-frame memory mask must have shape {expected_mask_shape}, "
                        f"got {tuple(memory_token_mask.shape)}"
                    )
                memory_token_mask = rearrange(memory_token_mask.bool(), "b t m -> (b t) m")
            out, valid_rows = self._attend(
                query,
                memory_tokens,
                memory_token_mask=memory_token_mask,
                memory_token_gate=effective_token_gate,
            )
            out = rearrange(out, "(b t) (h w) d -> b t h w d", b=B, t=T, h=H, w=W)
        else:
            raise ValueError(f"memory_tokens must be rank 3 or 4, got rank {memory_tokens.dim()}")
        valid_mask = self._valid_mask(valid_rows, B, T, dtype=out.dtype, device=out.device)
        residual_gate_tensor = self._residual_gate(residual_gate, B, T, dtype=out.dtype, device=out.device)
        attn_delta = gate(out, m_gate_msa)
        if valid_mask is not None:
            attn_delta = attn_delta * valid_mask
        if residual_gate_tensor is not None:
            attn_delta = attn_delta * residual_gate_tensor
        output = residual_base + attn_delta
        mlp_delta = gate(self.mlp(modulate(self.norm_mlp(output), m_shift_mlp, m_scale_mlp)), m_gate_mlp)
        # Rows without memory get no MLP update either, so they pass through unchanged.
        if valid_mask is not None:
            mlp_delta = mlp_delta * valid_mask
        if residual_gate_tensor is not None:
            mlp_delta = mlp_delta * residual_gate_tensor
        output = output + mlp_delta
        self._store_diagnostics(output, residual_base, m_gate_msa, m_gate_mlp, valid_rows)
        if return_delta:
            return attn_delta + mlp_delta
        return output
class SpatioTemporalDiTBlock(nn.Module):
    """DiT block: spatial attention, temporal attention, and an optional memory adapter.

    For ``x`` of shape ``(B, T, H, W, D)`` the block applies three stages:

    1. Spatial axial attention and an MLP, each as an adaLN-gated residual driven by ``c``.
    2. Temporal axial attention and an MLP, adaLN-modulated by ``c_action_cond``
       when it is given and by ``c`` otherwise.
    3. Optionally, cross-attention to the packed typed memory streams through
       :class:`MemoryTokenCrossAttention`.

    ``ref_mode`` controls how stages 2 and 3 combine:

    * ``'sequential'``: the memory adapter reads the stage-2 output, and the
      block returns its result.
    * ``'parallel'``: the memory adapter reads the stage-1 output.  The block
      returns ``x_t + parallel_map(y)``, where ``x_t`` is the stage-2 output and
      ``y`` is the adapter output (or the stage-1 output when no adapter runs).

    Any other ``ref_mode`` raises ``ValueError``.  Neither branch would apply,
    and the temporal output would be silently discarded.
    """

    _VALID_REF_MODES = ('sequential', 'parallel')

    def __init__(
        self,
        hidden_size,
        num_heads,
        reference_length,
        mlp_ratio=4.0,
        is_causal=True,
        spatial_rotary_emb: Optional[RotaryEmbedding] = None,
        temporal_rotary_emb: Optional[RotaryEmbedding] = None,
        use_memory_token_cross_attention=False,
        ref_mode='sequential'
    ):
        super().__init__()
        if ref_mode not in self._VALID_REF_MODES:
            raise ValueError(
                f"ref_mode must be one of {self._VALID_REF_MODES}, got {ref_mode!r}"
            )
        self.is_causal = is_causal
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        # Spatial branch.
        self.s_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.s_attn = SpatialAxialAttention(
            hidden_size,
            heads=num_heads,
            dim_head=hidden_size // num_heads,
            rotary_emb=spatial_rotary_emb
        )
        self.s_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.s_mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )
        self.s_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        # Temporal branch.
        self.t_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.t_attn = TemporalAxialAttention(
            hidden_size,
            heads=num_heads,
            dim_head=hidden_size // num_heads,
            is_causal=is_causal,
            rotary_emb=temporal_rotary_emb,
            reference_length=reference_length
        )
        self.t_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.t_mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )
        self.t_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        self.reference_length = reference_length
        # Optional memory adapter; only created for blocks that use it.
        self.use_memory_token_cross_attention = use_memory_token_cross_attention
        if self.use_memory_token_cross_attention:
            self.memory_token_cross_attn = MemoryTokenCrossAttention(hidden_size, num_heads, mlp_ratio=mlp_ratio)
        self.ref_mode = ref_mode
        if self.ref_mode == 'parallel':
            self.parallel_map = nn.Linear(hidden_size, hidden_size)

    def _expand_memory_stream(self, tokens, mask, stream_gate, type_idx, batch_size, num_frames):
        """Normalise one memory stream to per-frame form.

        Returns ``None`` when the stream is absent or holds zero tokens.
        Otherwise returns a tuple ``(tokens, mask, gate, type_ids)``:
        - ``tokens``: ``(B, T, M, D)``; rank-3 ``(B, M, D)`` tokens are repeated
          over frames;
        - ``mask``: bool ``(B, T, M)``;
        - ``gate``: ``(B, T, M)``;
        - ``type_ids``: long ``(M,)``, filled with ``type_idx``.

        Raises:
            ValueError: on incompatible token, mask or gate shapes.
        """
        if tokens is None or tokens.shape[-2] == 0:
            return None
        if tokens.dim() == 3:
            if tokens.shape[0] != batch_size:
                raise ValueError(f"rank-3 memory tokens must start with B={batch_size}, got {tuple(tokens.shape)}")
            tokens = tokens[:, None].expand(-1, num_frames, -1, -1)
            if mask is None:
                mask = torch.ones(tokens.shape[:3], device=tokens.device, dtype=torch.bool)
            elif mask.dim() == 2:
                mask = mask[:, None].expand(-1, num_frames, -1)
            elif mask.dim() != 3:
                raise ValueError(f"rank-3 stream mask must have rank 2 or 3, got {tuple(mask.shape)}")
        elif tokens.dim() == 4:
            if tuple(tokens.shape[:2]) != (batch_size, num_frames):
                raise ValueError(
                    f"rank-4 memory tokens must start with (B,T)={(batch_size, num_frames)}, "
                    f"got {tuple(tokens.shape)}"
                )
            if mask is None:
                mask = torch.ones(tokens.shape[:3], device=tokens.device, dtype=torch.bool)
            elif mask.dim() != 3:
                raise ValueError(f"rank-4 stream mask must have rank 3, got {tuple(mask.shape)}")
        else:
            raise ValueError(f"memory stream tokens must be rank 3 or 4, got rank {tokens.dim()}")
        if tuple(mask.shape) != tuple(tokens.shape[:3]):
            raise ValueError(f"memory stream mask must have shape {tuple(tokens.shape[:3])}, got {tuple(mask.shape)}")
        gate_tensor = self._expand_memory_stream_gate(stream_gate, tokens)
        type_ids = torch.full((tokens.shape[2],), int(type_idx), device=tokens.device, dtype=torch.long)
        return tokens, mask.to(device=tokens.device, dtype=torch.bool), gate_tensor, type_ids

    def _expand_memory_stream_gate(self, stream_gate, tokens):
        """Broadcast a stream gate to ``tokens.shape[:3]`` (``(B, T, M)``).

        Accepted forms:
        - ``None``: all ones;
        - a Python number or 0-d tensor;
        - a ``(B,)`` vector;
        - a ``(B, T)`` or ``(B, M)`` tensor (``(B, T)`` is checked first);
        - a ``(B, T, 1)`` or ``(B, T, M)`` tensor.
        """
        batch_size, num_frames, num_tokens = tokens.shape[:3]
        if stream_gate is None:
            return torch.ones(tokens.shape[:3], device=tokens.device, dtype=tokens.dtype)
        if not torch.is_tensor(stream_gate):
            return torch.full(tokens.shape[:3], float(stream_gate), device=tokens.device, dtype=tokens.dtype)
        gate_tensor = stream_gate.to(device=tokens.device, dtype=tokens.dtype)
        if gate_tensor.dim() == 0:
            return gate_tensor.view(1, 1, 1).expand(batch_size, num_frames, num_tokens)
        if gate_tensor.dim() == 1:
            if gate_tensor.numel() != batch_size:
                raise ValueError(f"rank-1 memory gate must have B={batch_size} values, got {tuple(gate_tensor.shape)}")
            return gate_tensor.view(batch_size, 1, 1).expand(batch_size, num_frames, num_tokens)
        if gate_tensor.dim() == 2:
            if tuple(gate_tensor.shape) == (batch_size, num_frames):
                return gate_tensor[:, :, None].expand(batch_size, num_frames, num_tokens)
            if tuple(gate_tensor.shape) == (batch_size, num_tokens):
                return gate_tensor[:, None, :].expand(batch_size, num_frames, num_tokens)
            raise ValueError(
                f"rank-2 memory gate must have shape (B,T) or (B,M), got {tuple(gate_tensor.shape)}"
            )
        if gate_tensor.dim() == 3:
            if tuple(gate_tensor.shape) == (batch_size, num_frames, 1):
                return gate_tensor.expand(batch_size, num_frames, num_tokens)
            if tuple(gate_tensor.shape) == (batch_size, num_frames, num_tokens):
                return gate_tensor
            raise ValueError(
                f"rank-3 memory gate must have shape (B,T,1) or (B,T,M), got {tuple(gate_tensor.shape)}"
            )
        raise ValueError(f"memory gate rank must be <=3, got rank {gate_tensor.dim()}")

    def _pack_typed_memory_streams(
        self,
        batch_size,
        num_frames,
        memory_tokens=None,
        memory_token_mask=None,
        memory_dynamic_tokens=None,
        memory_dynamic_mask=None,
        memory_retrieval_tokens=None,
        memory_retrieval_mask=None,
        memory_anchor_gate=None,
        memory_dynamic_gate=None,
        memory_retrieval_gate=None,
    ):
        """Concatenate the anchor, dynamic and retrieval streams along the token axis.

        Returns ``None`` when every stream is absent or empty.  Otherwise returns
        ``(tokens, mask, gate, type_ids, residual_gate)``:
        - ``residual_gate``: ``(B, T)``, the largest gate among valid tokens of
          each frame (0 when no token is valid).
        """
        streams = []
        for tokens, mask, stream_gate, type_idx in (
            (memory_tokens, memory_token_mask, memory_anchor_gate, MEMORY_TYPE_ANCHOR),
            (memory_dynamic_tokens, memory_dynamic_mask, memory_dynamic_gate, MEMORY_TYPE_DYNAMIC),
            (memory_retrieval_tokens, memory_retrieval_mask, memory_retrieval_gate, MEMORY_TYPE_REVISIT),
        ):
            expanded = self._expand_memory_stream(tokens, mask, stream_gate, type_idx, batch_size, num_frames)
            if expanded is not None:
                streams.append(expanded)
        if not streams:
            return None
        packed_tokens = torch.cat([item[0] for item in streams], dim=2)
        packed_mask = torch.cat([item[1] for item in streams], dim=2)
        packed_gate = torch.cat([item[2] for item in streams], dim=2)
        packed_type_ids = torch.cat([item[3] for item in streams], dim=0)
        valid_gate = packed_gate.masked_fill(~packed_mask, 0)
        residual_gate = valid_gate.max(dim=2).values
        return packed_tokens, packed_mask, packed_gate, packed_type_ids, residual_gate

    def forward(self, x, c, current_frame=None, timestep=None, is_last_block=False,
                pose_cond=None, mode="training", c_action_cond=None, reference_length=None,
                memory_tokens=None, memory_token_mask=None, memory_dynamic_tokens=None, memory_dynamic_mask=None,
                memory_retrieval_tokens=None, memory_retrieval_mask=None, memory_anchor_gate=None,
                memory_dynamic_gate=None, memory_retrieval_gate=None):
        """Apply the block to ``x`` of shape ``(B, T, H, W, D)``.

        Several parameters are accepted for call-site compatibility and are not
        read here: ``current_frame``, ``timestep``, ``is_last_block``,
        ``pose_cond``, ``mode`` and ``reference_length``.

        The ``memory_*`` streams (tokens, masks, gates) are used only when the
        block was built with ``use_memory_token_cross_attention``; see
        ``_pack_typed_memory_streams`` for how they are combined.
        """
        B, T, H, W, D = x.shape
        # spatial block
        s_shift_msa, s_scale_msa, s_gate_msa, s_shift_mlp, s_scale_mlp, s_gate_mlp = self.s_adaLN_modulation(c).chunk(6, dim=-1)
        x = x + gate(self.s_attn(modulate(self.s_norm1(x), s_shift_msa, s_scale_msa)), s_gate_msa)
        x = x + gate(self.s_mlp(modulate(self.s_norm2(x), s_shift_mlp, s_scale_mlp)), s_gate_mlp)
        # temporal block (action conditioning replaces c for its adaLN when given)
        if c_action_cond is not None:
            t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c_action_cond).chunk(6, dim=-1)
        else:
            t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c).chunk(6, dim=-1)
        x_t = x + gate(self.t_attn(modulate(self.t_norm1(x), t_shift_msa, t_scale_msa)), t_gate_msa)
        x_t = x_t + gate(self.t_mlp(modulate(self.t_norm2(x_t), t_shift_mlp, t_scale_mlp)), t_gate_mlp)
        # In 'parallel' mode x still holds the spatial-block output here.
        if self.ref_mode == 'sequential':
            x = x_t
        if self.use_memory_token_cross_attention:
            memory_base = x
            packed_memory = self._pack_typed_memory_streams(
                B,
                T,
                memory_tokens=memory_tokens,
                memory_token_mask=memory_token_mask,
                memory_dynamic_tokens=memory_dynamic_tokens,
                memory_dynamic_mask=memory_dynamic_mask,
                memory_retrieval_tokens=memory_retrieval_tokens,
                memory_retrieval_mask=memory_retrieval_mask,
                memory_anchor_gate=memory_anchor_gate,
                memory_dynamic_gate=memory_dynamic_gate,
                memory_retrieval_gate=memory_retrieval_gate,
            )
            if packed_memory is not None:
                packed_tokens, packed_mask, packed_gate, packed_type_ids, residual_gate = packed_memory
                x = self.memory_token_cross_attn(
                    memory_base,
                    c,
                    packed_tokens,
                    packed_mask,
                    residual_gate=residual_gate,
                    memory_type_ids=packed_type_ids,
                    memory_token_gate=packed_gate,
                )
        if self.ref_mode == 'parallel':
            x = x_t + self.parallel_map(x)
        return x
class DiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
    self,
    input_h=18,
    input_w=32,
    patch_size=2,
    in_channels=16,
    hidden_size=1024,
    depth=12,
    num_heads=16,
    mlp_ratio=4.0,
    action_cond_dim=25,
    max_frames=32,
    reference_length=8,
    memory_token_cross_attention=False,
    memory_cross_attn_layers=None,
    ref_mode='sequential'
):
    """Build the patch embedder, conditioning embedders, DiT blocks and output head.

    Args:
        input_h, input_w: latent height/width expected by the patch embedder.
        patch_size: square patch side used for embedding and the output head.
        in_channels: latent channels; the output uses the same count.
        hidden_size, depth, num_heads, mlp_ratio: transformer width, number of
            blocks, attention heads and MLP expansion.
        action_cond_dim: input width of the ``external_cond`` linear layer;
            when it is ``<= 0``, ``external_cond`` is an identity.
        max_frames: stored on the model.
        reference_length: forwarded to every block.
        memory_token_cross_attention: enable memory adapters in the blocks.
        memory_cross_attn_layers: optional block indices that get an adapter
            (``None`` means every block).
        ref_mode: forwarded to every block.

    Raises:
        ValueError: if ``memory_cross_attn_layers`` has an index outside ``[0, depth)``.
    """
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = in_channels
    self.patch_size = patch_size
    self.num_heads = num_heads
    self.max_frames = max_frames
    head_dim = hidden_size // num_heads
    # Submodules are created in a fixed order so parameter initialisation is reproducible.
    self.x_embedder = PatchEmbed(input_h, input_w, patch_size, in_channels, hidden_size, flatten=False)
    self.t_embedder = TimestepEmbedder(hidden_size)
    self.spatial_rotary_emb = RotaryEmbedding(dim=head_dim // 2, freqs_for="pixel", max_freq=256)
    self.temporal_rotary_emb = RotaryEmbedding(dim=head_dim)
    if action_cond_dim > 0:
        self.external_cond = nn.Linear(action_cond_dim, hidden_size)
    else:
        self.external_cond = nn.Identity()

    layer_set = None
    if memory_cross_attn_layers is not None:
        layer_set = set(map(int, memory_cross_attn_layers))
        out_of_range = sorted(idx for idx in layer_set if not 0 <= idx < depth)
        if out_of_range:
            raise ValueError(
                f"memory_cross_attn_layers contains invalid indices {out_of_range} for depth={depth}"
            )

    def block_uses_memory(block_idx):
        # A block gets an adapter when the feature is on and it is selected (or no selection was given).
        return memory_token_cross_attention and (layer_set is None or block_idx in layer_set)

    self.blocks = nn.ModuleList(
        SpatioTemporalDiTBlock(
            hidden_size,
            num_heads,
            mlp_ratio=mlp_ratio,
            is_causal=True,
            reference_length=reference_length,
            spatial_rotary_emb=self.spatial_rotary_emb,
            temporal_rotary_emb=self.temporal_rotary_emb,
            use_memory_token_cross_attention=block_uses_memory(block_idx),
            ref_mode=ref_mode,
        )
        for block_idx in range(depth)
    )
    self.memory_token_cross_attention = memory_token_cross_attention
    self.memory_cross_attn_layers = tuple(sorted(layer_set)) if layer_set is not None else None
    self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
    self.initialize_weights()
def initialize_weights(self):
    """(Re)initialise all parameters following the DiT recipe.

    The steps, in order:
    1. Every ``nn.Linear`` gets a Xavier-uniform weight and a zero bias.
    2. The patch-embedding conv is initialised like a Linear.
    3. The timestep MLP weights are drawn from N(0, 0.02).
    4. Every adaLN projection and the final output projection are zeroed.
    5. Each memory adapter is reset to its identity initialisation.

    The random draws happen in a fixed order, so a given seed reproduces the
    same weights.
    """
    def _xavier_linear(module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def _zero_linear(linear):
        nn.init.constant_(linear.weight, 0)
        nn.init.constant_(linear.bias, 0)

    self.apply(_xavier_linear)

    # Treat the patch conv as a Linear over its flattened (C * kh * kw) fan-in.
    proj = self.x_embedder.proj
    nn.init.xavier_uniform_(proj.weight.data.view([proj.weight.shape[0], -1]))
    nn.init.constant_(proj.bias, 0)

    for timestep_linear in (self.t_embedder.mlp[0], self.t_embedder.mlp[2]):
        nn.init.normal_(timestep_linear.weight, std=0.02)

    # Zeroed modulation makes every block start as an identity residual.
    for block in self.blocks:
        _zero_linear(block.s_adaLN_modulation[-1])
        _zero_linear(block.t_adaLN_modulation[-1])
    _zero_linear(self.final_layer.adaLN_modulation[-1])
    _zero_linear(self.final_layer.linear)

    # Step 1 re-randomised the adapters' Linear layers, so restore their zero init.
    if self.memory_token_cross_attention:
        for block in self.blocks:
            adapter = getattr(block, "memory_token_cross_attn", None)
            if adapter is not None:
                adapter.reset_identity_init()
def memory_adapter_delta_diagnostics(self):
    """Summarize the statistics most recently recorded by the memory adapters.

    For every block that has a ``memory_token_cross_attn`` adapter, this reads
    ``last_delta_ratio``, ``last_type_gate_mean`` and
    ``last_type_gate_<type>_mean`` for each name in ``MEMORY_TYPE_NAMES``. Missing
    or ``None`` values are skipped.

    Returns:
        dict[str, float]. A key is present only when at least one adapter
        reported the underlying value:
          - ``memory_adapter_delta_ratio_max`` / ``memory_adapter_delta_ratio_mean``
          - ``memory_adapter_type_gate_mean``
          - ``memory_adapter_type_gate_<type>_mean`` per memory type
    """
    def _as_float(value):
        return torch.as_tensor(value).detach().float()

    delta_ratios = []
    shared_gates = []
    gates_by_type = {name: [] for name in MEMORY_TYPE_NAMES}

    adapters = (getattr(blk, "memory_token_cross_attn", None) for blk in self.blocks)
    for adapter in adapters:
        if adapter is None:
            continue
        last_ratio = getattr(adapter, "last_delta_ratio", None)
        if last_ratio is not None:
            delta_ratios.append(_as_float(last_ratio))
        last_gate = getattr(adapter, "last_type_gate_mean", None)
        if last_gate is not None:
            shared_gates.append(_as_float(last_gate))
        for name in MEMORY_TYPE_NAMES:
            per_type = getattr(adapter, f"last_type_gate_{name}_mean", None)
            if per_type is not None:
                gates_by_type[name].append(_as_float(per_type))

    summary = {}
    if delta_ratios:
        stacked_ratios = torch.stack(delta_ratios)
        summary["memory_adapter_delta_ratio_max"] = float(stacked_ratios.max().item())
        summary["memory_adapter_delta_ratio_mean"] = float(stacked_ratios.mean().item())
    if shared_gates:
        summary["memory_adapter_type_gate_mean"] = float(torch.stack(shared_gates).mean().item())
    for name, collected in gates_by_type.items():
        if not collected:
            continue
        summary[f"memory_adapter_type_gate_{name}_mean"] = float(torch.stack(collected).mean().item())
    return summary
def unpatchify(self, x):
    """Fold per-token patch vectors back into channel-first images.

    Args:
        x: tensor of shape (N, h, w, p * p * C). Here p is
            ``self.x_embedder.patch_size[0]`` (square patches are assumed) and
            C is ``self.out_channels``. The last axis is read as (p, p, C):
            row within the patch, then column within the patch, then channel.

    Returns:
        Tensor of shape (N, C, h * p, w * p). Note that this is channel-first,
        not (N, H, W, C).
    """
    channels = self.out_channels
    p = self.x_embedder.patch_size[0]
    n, h, w = x.shape[:3]
    # Split the flat patch vector into (row-in-patch, col-in-patch, channel).
    x = x.reshape(n, h, w, p, p, channels)
    # Move channels first and interleave the patch grid with the in-patch offsets.
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(n, channels, h * p, w * p)
def forward(
    self,
    x,
    t,
    action_cond=None,
    pose_cond=None,
    current_frame=None,
    mode=None,
    reference_length=None,
    frame_idx=None,
    memory_tokens=None,
    memory_token_mask=None,
    memory_dynamic_tokens=None,
    memory_dynamic_mask=None,
    memory_retrieval_tokens=None,
    memory_retrieval_mask=None,
    memory_anchor_gate=None,
    memory_dynamic_gate=None,
    memory_retrieval_gate=None,
):
    """Run the DiT over a batch of frame sequences.

    Args:
        x: (B, T, C, H, W) spatial inputs (images or their latents).
        t: (B, T) diffusion timesteps.
        action_cond: if it is a tensor, ``self.external_cond(action_cond)`` is
            added to the timestep embedding and passed to each block as
            ``c_action_cond``. Otherwise the blocks receive ``None``.
        pose_cond, frame_idx: accepted but not used by this method.
        current_frame, mode, reference_length: forwarded unchanged to every block.
        memory_*: memory tokens, masks and gates, forwarded unchanged to every block.

    Returns:
        Tensor laid out as (B, T, out_channels, H', W'). H' and W' are presumably
        H and W when both are divisible by the patch size; confirm with x_embedder.
    """
    _, num_frames, _, _, _ = x.shape

    # Patch-embed each frame independently, then restore the (B, T) layout.
    tokens = self.x_embedder(rearrange(x, "b t c h w -> (b t) c h w"))
    tokens = rearrange(tokens, "(b t) h w d -> b t h w d", t=num_frames)

    # Per-frame timestep embedding; the flattened timesteps also go to each block.
    flat_t = rearrange(t, "b t -> (b t)")
    cond = rearrange(self.t_embedder(flat_t).clone(), "(b t) d -> b t d", t=num_frames)

    action = cond + self.external_cond(action_cond) if torch.is_tensor(action_cond) else None

    # Every block receives the same memory inputs.
    memory_kwargs = dict(
        memory_tokens=memory_tokens,
        memory_token_mask=memory_token_mask,
        memory_dynamic_tokens=memory_dynamic_tokens,
        memory_dynamic_mask=memory_dynamic_mask,
        memory_retrieval_tokens=memory_retrieval_tokens,
        memory_retrieval_mask=memory_retrieval_mask,
        memory_anchor_gate=memory_anchor_gate,
        memory_dynamic_gate=memory_dynamic_gate,
        memory_retrieval_gate=memory_retrieval_gate,
    )
    last_index = len(self.blocks) - 1
    for index, block in enumerate(self.blocks):
        tokens = block(
            tokens,
            cond,
            current_frame=current_frame,
            timestep=flat_t,
            is_last_block=(index == last_index),
            mode=mode,
            c_action_cond=action,
            reference_length=reference_length,
            **memory_kwargs,
        )

    # Project to per-patch outputs, then fold patches back into frames.
    patches = self.final_layer(tokens, cond)
    frames = self.unpatchify(rearrange(patches, "b t h w d -> (b t) h w d"))
    return rearrange(frames, "(b t) c h w -> b t c h w", t=num_frames)
def DiT_S_2(
    action_cond_dim,
    reference_length,
    ref_mode,
    memory_token_cross_attention=False,
    memory_cross_attn_layers=None,
):
    """Build the "DiT-S/2" configuration of ``DiT``.

    Fixed architecture: patch size 2, hidden size 1024, 16 blocks, 16 heads.
    All other arguments are passed straight through to ``DiT``.

    NOTE(review): despite the "S" label, these sizes are larger than the small
    config in the original DiT paper (hidden 384, depth 12, 6 heads). This is
    presumably intentional; confirm before relying on the name.
    """
    architecture = dict(patch_size=2, hidden_size=1024, depth=16, num_heads=16)
    return DiT(
        action_cond_dim=action_cond_dim,
        reference_length=reference_length,
        memory_token_cross_attention=memory_token_cross_attention,
        memory_cross_attn_layers=memory_cross_attn_layers,
        ref_mode=ref_mode,
        **architecture,
    )


# Maps model-name strings to the factory functions that build them.
DiT_models = {
    "DiT-S/2": DiT_S_2,
}
|