File size: 43,033 Bytes
3d1c0e1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 
1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 |
import math
from typing import Optional
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
# Global debug flag: when True, WanDiscreteVideoTransformer.forward prints the
# shapes of its inputs, intermediate lists and final logits. Keep False normally.
DEBUG_TRANSFORMER = False
# from .attention import flash_attention
import torch
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
import warnings
# Export list inherited from the standalone attention module.
# NOTE(review): a later `__all__ = ['WanModel']` assignment in this file rebinds
# `__all__`, so this list has no effect on `from ... import *`.
__all__ = [
'flash_attention',
'attention',
]
def _cu_seqlens(lens, device):
    """Return int32 cumulative offsets ``[0, l0, l0 + l1, ...]`` on ``device``.

    This is the ``cu_seqlens_*`` layout expected by the FlashAttention varlen kernels.
    """
    with_zero = torch.cat([lens.new_zeros([1]), lens])
    return with_zero.cumsum(0, dtype=torch.int32).to(device, non_blocking=True)


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    Variable-length attention through FlashAttention 3 (preferred) or 2.

    Args:
        q: [B, Lq, Nq, C1].
        k: [B, Lk, Nk, C1].
        v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
        q_lens: [B] valid query lengths, or None for "all Lq tokens are valid".
        k_lens: [B] valid key/value lengths, or None for "all Lk tokens are valid".
        dropout_p: float. Dropout probability (FA2 only).
        softmax_scale: float. Scaling of QK^T before the softmax.
        q_scale: optional multiplier applied to q before attention.
        causal: bool. Whether to apply a causal attention mask.
        window_size: (left, right). If not (-1, -1), sliding-window local attention (FA2 only).
        deterministic: bool. If True, slightly slower and uses more memory.
        dtype: torch.dtype used for q/k/v that are not already float16/bfloat16.
        version: requested FlashAttention major version (None = best available).

    Returns:
        [B, Lq, Nq, C2] tensor in q's original dtype.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    batch, len_q, len_k = q.size(0), q.size(1), k.size(1)
    out_dtype = q.dtype

    def to_half(t):
        # fp16/bf16 tensors pass through; anything else is cast to `dtype`.
        return t if t.dtype in half_dtypes else t.to(dtype)

    def pack(t, lens):
        # Keep only the first lens[i] tokens of sample i and concatenate them.
        return torch.cat([sample[:n] for sample, n in zip(t, lens)])

    # Queries: either flatten the whole batch or pack the valid prefixes.
    if q_lens is None:
        q = to_half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [len_q] * batch, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = to_half(pack(q, q_lens))

    # Keys / values share the same length vector.
    if k_lens is None:
        k = to_half(k.flatten(0, 1))
        v = to_half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [len_k] * batch, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = to_half(pack(k, k_lens))
        v = to_half(pack(v, k_lens))

    # The kernels need q, k and v in one common dtype.
    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            'Flash attention 3 is not available, use flash attention 2 instead.'
        )

    if FLASH_ATTN_3_AVAILABLE and version in (None, 3):
        # dropout_p and window_size are not supported by FA3 at the moment.
        result = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=_cu_seqlens(q_lens, q.device),
            cu_seqlens_k=_cu_seqlens(k_lens, q.device),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=len_q,
            max_seqlen_k=len_k,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic)
        packed_out = result[0]
    else:
        assert FLASH_ATTN_2_AVAILABLE
        packed_out = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=_cu_seqlens(q_lens, q.device),
            cu_seqlens_k=_cu_seqlens(k_lens, q.device),
            max_seqlen_q=len_q,
            max_seqlen_k=len_k,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic)

    # Restore the [B, Lq, ...] layout and the caller's dtype.
    return packed_out.unflatten(0, (batch, len_q)).type(out_dtype)
def attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
):
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return flash_attention(
q=q,
k=k,
v=v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
q_scale=q_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
dtype=dtype,
version=fa_version,
)
else:
if q_lens is not None or k_lens is not None:
warnings.warn(
'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
)
attn_mask = None
q = q.transpose(1, 2).to(dtype)
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
out = out.transpose(1, 2).contiguous()
return out
# Public export list for star-imports.
# NOTE(review): this rebinds (does not extend) the earlier `__all__`
# (['flash_attention', 'attention']) and does not list
# WanDiscreteVideoTransformer — confirm that is intended.
__all__ = ['WanModel']
def sinusoidal_embedding_1d(dim, position):
    """
    Build a sinusoidal embedding for a 1-D tensor of positions.

    Args:
        dim: embedding width; must be even.
        position: 1-D tensor of positions (e.g. timesteps).

    Returns:
        float64 tensor of shape [len(position), dim]: the first ``dim // 2``
        columns are cosines and the last ``dim // 2`` are sines. Frequencies are
        ``10000 ** (-i / (dim // 2))``.
    """
    assert dim % 2 == 0
    half = dim // 2

    # Everything is computed in float64 on the positions' own device.
    pos = position.to(torch.float64)
    exponents = torch.arange(
        half, dtype=torch.float64, device=position.device).div(half)
    inv_freq = torch.pow(10000, -exponents)

    angles = torch.outer(pos, inv_freq)
    return torch.cat([angles.cos(), angles.sin()], dim=1)
@torch.amp.autocast('cuda', enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    """
    Precompute rotary-embedding phasors.

    Returns a complex128 tensor of shape [max_seq_len, dim // 2]. Entry (p, j) is
    the unit-magnitude complex number exp(i * p / theta ** (2j / dim)).
    """
    assert dim % 2 == 0
    exponent = torch.arange(0, dim, 2).to(torch.float64).div(dim)
    inv_freq = 1.0 / torch.pow(theta, exponent)
    angles = torch.outer(torch.arange(max_seq_len), inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
@torch.amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """
    Apply 3-D (frame, height, width) rotary position embedding.

    Args:
        x: [B, L, N, D] tensor; D must be even.
        grid_sizes: [B, 3] integer tensor of (F, H, W) per sample. Only the
            first F*H*W tokens of each sample are rotated; the rest are copied
            through unchanged.
        freqs: complex tensor with D // 2 columns. The columns are split into
            frame / height / width groups of sizes
            (D//2 - 2*((D//2)//3), (D//2)//3, (D//2)//3).

    Returns:
        Tensor with x's shape and dtype.
    """
    num_heads = x.size(2)
    half_dim = x.size(3) // 2
    in_dtype = x.dtype

    third = half_dim // 3
    freqs_f, freqs_h, freqs_w = freqs.split(
        [half_dim - 2 * third, third, third], dim=1)

    rotated = []
    for idx, (f, h, w) in enumerate(grid_sizes.tolist()):
        n_tok = f * h * w

        # Interpret consecutive channel pairs as complex numbers (in float64).
        pairs = x[idx, :n_tok].to(torch.float64).reshape(n_tok, num_heads, -1, 2)
        x_complex = torch.view_as_complex(pairs)

        # Per-token phasor: frame part varies along F, height along H, width along W.
        rot = torch.cat(
            [
                freqs_f[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                freqs_h[:h].view(1, h, 1, -1).expand(f, h, w, -1),
                freqs_w[:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        ).reshape(n_tok, 1, -1)

        head = torch.view_as_real(x_complex * rot).flatten(2).to(dtype=in_dtype)

        # Padding tokens beyond the grid are left untouched.
        tail = x[idx, n_tok:]
        rotated.append(torch.cat([head, tail]) if tail.numel() > 0 else head)

    return torch.stack(rotated).to(dtype=in_dtype)
class WanRMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learnable gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]

        Returns:
            Tensor of x's dtype. Normalization is computed in float32, then the
            gain (cast to x's dtype) is applied.
        """
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight.type_as(x)

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps) along the channel axis.
        inv_rms = torch.rsqrt(x.square().mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms
class WanLayerNorm(nn.LayerNorm):
    """LayerNorm computed in float32 whose output is cast back to the input dtype."""

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        in_dtype = x.dtype
        # Affine parameters (when present) are upcast to match the float32 input;
        # without elementwise_affine both are None and plain normalization runs.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normed = torch.nn.functional.layer_norm(
            x.float(), self.normalized_shape, weight, bias, self.eps)
        return normed.to(dtype=in_dtype)
class WanSelfAttention(nn.Module):
    """Multi-head self-attention with optional RMS q/k normalization and 3-D RoPE."""

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # Projections (created in this order; it fixes the init RNG sequence).
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        r"""
        Args:
            x(Tensor): Shape [B, L, C] (C == dim)
            seq_lens(Tensor): Shape [B], valid key length per sample
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]

        Returns:
            Tensor of shape [B, L, C].
        """
        batch, length = x.shape[:2]
        head_shape = (batch, length, self.num_heads, self.head_dim)

        query = self.norm_q(self.q(x)).view(*head_shape)
        key = self.norm_k(self.k(x)).view(*head_shape)
        value = self.v(x).view(*head_shape)

        in_dtype = x.dtype
        attended = flash_attention(
            q=rope_apply(query, grid_sizes, freqs),
            k=rope_apply(key, grid_sizes, freqs),
            v=value,
            k_lens=seq_lens,
            window_size=self.window_size)

        # rope_apply / flash_attention may change dtype; restore it, merge heads, project.
        merged = attended.to(dtype=in_dtype).flatten(2)
        return self.o(merged)
class WanCrossAttention(WanSelfAttention):
    """Attention from video tokens (queries) to text context (keys/values); no RoPE."""

    def forward(self, x, context, context_lens):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        """
        heads_view = (x.size(0), -1, self.num_heads, self.head_dim)
        in_dtype = x.dtype

        query = self.norm_q(self.q(x)).view(heads_view)
        key = self.norm_k(self.k(context)).view(heads_view)
        value = self.v(context).view(heads_view)

        attended = flash_attention(query, key, value, k_lens=context_lens)

        # Restore the input dtype, merge heads and apply the output projection.
        return self.o(attended.to(dtype=in_dtype).flatten(2))
class WanAttentionBlock(nn.Module):
    """
    One Wan transformer block: time-modulated self-attention, cross-attention to
    the text context, and a time-modulated feed-forward network, each residual.
    """

    def __init__(self,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # Submodules are registered in this exact order (init RNG / state_dict layout).
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                          eps)
        if cross_attn_norm:
            self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True)
        else:
            self.norm3 = nn.Identity()
        self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
                                            eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

        # Learned offsets added to the six time-conditioned modulation vectors.
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, L1, 6, C]
            seq_lens(Tensor): Shape [B], length of each sequence in batch
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
            context(Tensor): text context passed to cross-attention
            context_lens(Tensor or None): valid context lengths
        """
        # Modulation vectors are formed in float32.
        e32 = e if e.dtype == torch.float32 else e.to(dtype=torch.float32)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            mods = (self.modulation.unsqueeze(0) + e32).chunk(6, dim=2)
        assert mods[0].dtype == torch.float32

        def mod(index, dtype):
            # [B, L1, 1, C] -> [B, L1, C] in the requested dtype.
            return mods[index].squeeze(2).to(dtype=dtype)

        # Self-attention: shift = mods[0], scale = mods[1], gate = mods[2].
        act_dtype = x.dtype
        shift, scale, gate = mod(0, act_dtype), mod(1, act_dtype), mod(2, act_dtype)
        attn_out = self.self_attn(self.norm1(x) * (1 + scale) + shift,
                                  seq_lens, grid_sizes, freqs)
        x = x + (attn_out * gate).to(dtype=act_dtype)

        # Cross-attention (unmodulated).
        x = x + self.cross_attn(self.norm3(x), context, context_lens)

        # Feed-forward: shift = mods[3], scale = mods[4], gate = mods[5].
        act_dtype = x.dtype
        shift, scale, gate = mod(3, act_dtype), mod(4, act_dtype), mod(5, act_dtype)
        ffn_out = self.ffn(self.norm2(x) * (1 + scale) + shift)
        x = x + (ffn_out * gate).to(dtype=act_dtype)
        return x
class Head(nn.Module):
    """Final modulated LayerNorm + linear projection to per-patch output channels."""

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # Each token predicts out_dim channels for every voxel of its patch.
        projected_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, projected_dim)

        # Learned offsets for the (shift, scale) modulation pair.
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, L1, C]

        Returns:
            Tensor of shape [B, L1, prod(patch_size) * out_dim].
        """
        e32 = e if e.dtype == torch.float32 else e.to(dtype=torch.float32)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            shift_mod, scale_mod = (
                self.modulation.unsqueeze(0) + e32.unsqueeze(2)).chunk(2, dim=2)

        act_dtype = x.dtype
        shift = shift_mod.squeeze(2).to(dtype=act_dtype)
        scale = scale_mod.squeeze(2).to(dtype=act_dtype)
        return self.head(self.norm(x) * (1 + scale) + shift)
class WanModel(ModelMixin, ConfigMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - one of 't2v', 'i2v', 'ti2v', 's2v'. 'i2v' requires `y` in forward.
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length that text embeddings are zero-padded to
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """
        super().__init__()

        assert model_type in ['t2v', 'i2v', 'ti2v', 's2v']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        self.blocks = nn.ModuleList([
            WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
                              cross_attn_norm, eps) for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ],
                               dim=1)

        # initialize weights
        self.init_weights()

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        y=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps of shape [B] (one per sample, broadcast over
                the sequence) or [B, seq_len] (one per token)
            context (List[Tensor]):
                List of text embeddings each with shape [L, C] (L <= text_len;
                zero-padded to text_len)
            seq_len (`int`):
                Maximum sequence length; patchified sequences are zero-padded to it
            y (List[Tensor], *optional*):
                Conditional video inputs, concatenated to x along channels
                (required for model_type 'i2v')

        Returns:
            List[Tensor]:
                List of denoised video tensors of shape [C_out, F', H', W'], where
                each of F', H', W' is the input size rounded down to a multiple
                of the corresponding patch size
        """
        if self.model_type == 'i2v':
            assert y is not None
        # params
        device = self.patch_embedding.weight.device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        # Ensure input dtype matches patch_embedding weight dtype
        patch_weight_dtype = self.patch_embedding.weight.dtype
        x = [self.patch_embedding(u.unsqueeze(0).to(dtype=patch_weight_dtype)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        if t.dim() == 1:
            # [B] -> [B, 1] -> [B, seq_len]. Expanding the 1-D tensor directly
            # would align it with the *last* axis and fail for B > 1, B != seq_len.
            t = t.unsqueeze(1).expand(t.size(0), seq_len)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            bt = t.size(0)
            t = t.flatten()
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim,
                                        t).unflatten(0, (bt, seq_len)).float())
            e0 = self.time_projection(e).unflatten(2, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32
        # e and e0 stay float32; WanAttentionBlock.forward and Head.forward cast
        # the modulation vectors to the activation dtype themselves.

        # context
        context_lens = None
        # Ensure context input dtype matches text_embedding weight dtype
        text_weight_dtype = self.text_embedding[0].weight.dtype
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]).to(dtype=text_weight_dtype))

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens)

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (Tensor):
                Patchified features of shape [B, L, C_out * prod(patch_size)];
                only the first F_patches * H_patches * W_patches tokens of each
                sample are used
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape
                [C_out, F_patches * p_t, H_patches * p_h, W_patches * p_w]
        """
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters: Xavier-uniform for linear layers and the
        patch embedding, N(0, 0.02) for text/time embedding linears, and a
        zero-initialized output projection.
        """
        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
class WanDiscreteVideoTransformer(ModelMixin, ConfigMixin):
    r"""
    Wrapper around :class:`WanModel` that makes it usable as a **discrete video diffusion backbone**.

    The goals of this wrapper are:
    - keep the inner :class:`WanModel` architecture and parameter names intact so that Wan-1.3B
      weights can later be loaded directly into ``self.backbone``;
    - expose a simpler interface that takes **discrete codebook indices** (from a 2D VQ-VAE on
      pseudo-video) and returns **logits over the vocabulary** for each spatio‑temporal position.

    Notes
    -----
    - This class does **not** try to be drop‑in compatible with Meissonic's 2D ``Transformer2DModel``.
      It is a parallel, video‑oriented path that still follows the same *discrete diffusion* principle:
      predict per‑token logits given masked tokens + text.
    - Pseudo‑video is represented as a 4D integer tensor ``[B, F, H, W]`` of codebook indices.
      How to get these tokens from the current 2D VQ-VAE (e.g. per‑frame encoding & stacking)
      is left to the higher‑level training / pipeline code.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        # discrete codebook settings
        codebook_size: int,
        vocab_size: int,
        # video layout
        num_frames: int,
        height: int,
        width: int,
        # Wan backbone hyper‑parameters (mirrors WanModel.__init__)
        model_type: str = 't2v',
        patch_size: tuple = (1, 2, 2),
        text_len: int = 512,
        in_dim: int = 16,
        dim: int = 2048,
        ffn_dim: int = 8192,
        freq_dim: int = 256,
        text_dim: int = 4096,
        out_dim: int = 16,
        num_heads: int = 16,
        num_layers: int = 32,
        window_size: tuple = (-1, -1),
        qk_norm: bool = True,
        cross_attn_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()

        # save a minimal set of attributes useful for downstream tooling
        self.codebook_size = codebook_size
        self.vocab_size = vocab_size
        self.num_frames = num_frames
        self.height = height
        self.width = width

        # 1) backbone: keep WanModel intact for future weight loading
        self.backbone = WanModel(
            model_type=model_type,
            patch_size=patch_size,
            text_len=text_len,
            in_dim=in_dim,
            dim=dim,
            ffn_dim=ffn_dim,
            freq_dim=freq_dim,
            text_dim=text_dim,
            out_dim=out_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            window_size=window_size,
            qk_norm=qk_norm,
            cross_attn_norm=cross_attn_norm,
            eps=eps,
        )

        # 2) discrete token embedding -> continuous video volume
        #
        # Input: tokens [B, F, H, W] with values in [0, vocab_size) where:
        #   - [0, codebook_size-1] = actual Cosmos codes (direct mapping, no shift)
        #   - codebook_size = mask_token_id (reserved for masking)
        # Output: list of length B with tensors [in_dim, F, H, W]
        #
        # Kept outside the backbone so that loading official Wan 1.3B weights
        # into self.backbone does not clash with it.
        # Note: vocab_size = codebook_size + 1 to accommodate mask_token_id = codebook_size
        self.token_embedding = nn.Embedding(vocab_size, in_dim)

        # 3) projection from continuous video output -> logits over the vocabulary
        #
        # Backbone output: list of B tensors [out_dim, F', H', W']
        # A 1x1x1 Conv3d maps it to [vocab_size, F', H', W'].
        # Note: vocab_size = codebook_size + 1, where codebook_size is reserved for mask_token_id
        self.logits_head = nn.Conv3d(out_dim, vocab_size, kernel_size=1)

        # Gradient checkpointing support
        self.gradient_checkpointing = False

    def _tokens_to_video(self, tokens: torch.LongTensor) -> list:
        r"""
        Convert discrete tokens ``[B, F, H, W]`` into a list of length ``B``. Each element
        is a dense video tensor ``[in_dim, F, H, W]`` suitable for :class:`WanModel`.

        Note:
            Input dimensions are not checked against the configured num_frames / height /
            width; any spatial size accepted by the backbone's patch embedding works.
        """
        assert tokens.dim() == 4, f"expected [B, F, H, W] tokens, got {tokens.shape}"

        # [B, F, H, W, in_dim], in the embedding weight's dtype
        x = self.token_embedding(tokens)
        x = x.to(dtype=self.token_embedding.weight.dtype)

        # [B, in_dim, F, H, W]
        x = x.permute(0, 4, 1, 2, 3).contiguous()

        # WanModel expects a list of [C_in, F, H, W]
        return [x_i for x_i in x]

    def _text_to_list(self, encoder_hidden_states: torch.Tensor) -> list:
        r"""
        Convert batched text embeddings ``[B, L, C]`` into the list-of-tensors format
        expected by :class:`WanModel`.
        """
        assert encoder_hidden_states.dim() == 3, (
            f"expected encoder_hidden_states [B, L, C], got {encoder_hidden_states.shape}")
        return [e for e in encoder_hidden_states]

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        """Set gradient checkpointing for the module."""
        self.gradient_checkpointing = enable

    def forward(
        self,
        tokens: torch.LongTensor,
        timesteps: torch.LongTensor,
        encoder_hidden_states: torch.FloatTensor,
        y: Optional[list] = None,
    ) -> torch.FloatTensor:
        r"""
        Forward pass of the **discrete video transformer**.

        Args:
            tokens (`torch.LongTensor` of shape `[B, F, H, W]`):
                Discrete codebook indices (e.g. from a 2D VQ-VAE applied frame‑wise).
            timesteps (`torch.LongTensor` of shape `[B]` or `[B, seq_len]`):
                Diffusion timestep(s), following the same semantics as Meissonic's scalar
                timesteps. ``seq_len = (F // p_t) * (H // p_h) * (W // p_w)`` for the
                backbone's ``patch_size = (p_t, p_h, p_w)``.
            encoder_hidden_states (`torch.FloatTensor` of shape `[B, L, C_text]`):
                Text embeddings (e.g. from CLIP). Each sample corresponds to one video.
            y (`Optional[list]`):
                Optional conditional video list passed to the underlying :class:`WanModel`
                for i2v / ti2v / s2v variants. For now this is surfaced as a raw passthrough
                and can be left as ``None`` for pure text‑to‑video.

        Returns:
            `torch.FloatTensor`:
                Logits of shape `[B, vocab_size, F_out, H_out, W_out]` (vocab_size includes
                the mask token). The backbone un-patchifies back to the input resolution,
                so each of F_out, H_out, W_out is the input size rounded down to a multiple
                of the corresponding patch size. For the default ``patch_size=(1, 2, 2)``
                and even H, W this equals ``(F, H, W)``.

        Raises:
            ValueError: if ``timesteps`` is neither 1-D nor 2-D.
        """
        device = tokens.device
        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] Input: tokens.shape={tokens.shape}, encoder_hidden_states.shape={encoder_hidden_states.shape}, timesteps.shape={timesteps.shape}")

        x_list = self._tokens_to_video(tokens)
        context_list = self._text_to_list(encoder_hidden_states)

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] After conversion: len(x_list)={len(x_list)}, len(context_list)={len(context_list)}")
            if len(x_list) > 0:
                print(f"[DEBUG-transformer] x_list[0].shape={x_list[0].shape}")
            if len(context_list) > 0:
                print(f"[DEBUG-transformer] context_list[0].shape={context_list[0].shape}")

        # Sequence length after patchification. The backbone's Conv3d patch embedding
        # uses kernel == stride == patch_size, so every axis (including time) is
        # floor-divided by its patch size.
        _, f_in, h_in, w_in = tokens.shape
        p_t, p_h, p_w = self.backbone.patch_size
        seq_len = (f_in // p_t) * (h_in // p_h) * (w_in // p_w)

        # Always hand WanModel an explicit [B, seq_len] timestep tensor.
        if timesteps.dim() == 1:
            # [B] -> [B, 1] -> [B, seq_len] (broadcast along sequence)
            t_model = timesteps.to(device).unsqueeze(1).expand(-1, seq_len)
        elif timesteps.dim() == 2:
            assert timesteps.size(1) == seq_len, (
                f"Expected timesteps second dim == seq_len ({seq_len}), "
                f"but got {timesteps.size(1)}"
            )
            t_model = timesteps.to(device)
        else:
            raise ValueError(
                f"Unsupported timesteps shape {timesteps.shape}; "
                "expected [B] or [B, seq_len]"
            )

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] t_model.shape={t_model.shape}")

        # WanModel.forward expects:
        #   x: List[Tensor [C_in, F, H, W]]
        #   t: Tensor [B] or [B, seq_len]
        #   context: List[Tensor [L, C_text]]
        #   seq_len: int
        #   y: Optional[List[Tensor]]
        if self.training and self.gradient_checkpointing:
            # Explicit import: `import torch` alone does not guarantee the
            # torch.utils.checkpoint submodule is loaded.
            from torch.utils.checkpoint import checkpoint

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # Unpack inputs: x_list, t, context_list, seq_len, y
                    x_in, t_in, context_in, seq_len_in, y_in = inputs
                    return module(x=x_in, t=t_in, context=context_in, seq_len=seq_len_in, y=y_in)
                return custom_forward

            # Use gradient checkpointing for the backbone
            ckpt_kwargs = {"use_reentrant": False}
            out_list = checkpoint(
                create_custom_forward(self.backbone),
                x_list,
                t_model,
                context_list,
                seq_len,
                y,
                **ckpt_kwargs,
            )
        else:
            out_list = self.backbone(
                x=x_list,
                t=t_model,
                context=context_list,
                seq_len=seq_len,
                y=y,
            )

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] After backbone: len(out_list)={len(out_list)}")
            if len(out_list) > 0:
                print(f"[DEBUG-transformer] out_list[0].shape={out_list[0].shape}")

        # out_list: length B, each [C_out, F_out, H_out, W_out]
        vids = torch.stack(out_list, dim=0)  # [B, C_out, F_out, H_out, W_out]

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] After stack: vids.shape={vids.shape}")

        # Match the logits head's weight dtype (backbone output is float32).
        vids = vids.to(dtype=self.logits_head.weight.dtype)
        logits = self.logits_head(vids)  # [B, vocab_size, F_out, H_out, W_out]

        if DEBUG_TRANSFORMER:
            print(f"[DEBUG-transformer] Final logits.shape={logits.shape}")

        return logits
# def _available_device():
# return "cuda" if torch.cuda.is_available() else "cpu"
# def test_wan_discrete_video_transformer_forward_and_shapes():
# """
# Basic smoke test:
# - build a tiny WanDiscreteVideoTransformer
# - run a forward pass with random pseudo-video tokens + random text
# - check output shapes, parameter count and (if CUDA present) memory usage
# """
# device = _available_device()
# # small config to keep the test lightweight
# codebook_size = 128
# vocab_size = codebook_size + 1 # reserve one for mask if needed later
# num_frames = 2
# height = 16
# width = 16
# model = WanDiscreteVideoTransformer(
# codebook_size=codebook_size,
# vocab_size=vocab_size,
# num_frames=num_frames,
# height=height,
# width=width,
# # shrink Wan backbone for the unit test
# in_dim=32,
# dim=64,
# ffn_dim=128,
# freq_dim=32,
# text_dim=64,
# out_dim=32,
# num_heads=4,
# num_layers=2,
# ).to(device)
# model.eval()
# batch_size = 2
# # pseudo-video tokens from 2D VQ-VAE on frames: [B, F, H, W]
# tokens = torch.randint(
# low=0,
# high=codebook_size,
# size=(batch_size, num_frames, height, width),
# dtype=torch.long,
# device=device,
# )
# # text: [B, L, C_text]
# text_seq_len = 8
# encoder_hidden_states = torch.randn(
# batch_size, text_seq_len, model.backbone.text_dim, device=device
# )
# # timesteps: [B]
# timesteps = torch.randint(
# low=0, high=1000, size=(batch_size,), dtype=torch.long, device=device
# )
# # track memory if CUDA is available
# if device == "cuda":
# torch.cuda.reset_peak_memory_stats()
# mem_before = torch.cuda.memory_allocated()
# else:
# mem_before = 0
# with torch.no_grad():
# logits = model(
# tokens=tokens,
# timesteps=timesteps,
# encoder_hidden_states=encoder_hidden_states,
# y=None,
# )
# if device == "cuda":
# mem_after = torch.cuda.memory_allocated()
# peak_mem = torch.cuda.max_memory_allocated()
# else:
# mem_after = mem_before
# peak_mem = mem_before
# # logits: [B, codebook_size, F, H_out, W_out]
# assert logits.shape[0] == batch_size
# assert logits.shape[1] == codebook_size
# assert logits.shape[2] == num_frames
# # WanModel returns unpatchified videos, so spatial size matches the input grid.
# h_out = height
# w_out = width
# assert logits.shape[3] == h_out
# assert logits.shape[4] == w_out
# # parameter count sanity check (just ensure it's > 0 and finite)
# num_params = sum(p.numel() for p in model.parameters())
# assert num_params > 0
# assert math.isfinite(float(num_params))
# # memory sanity check (on CUDA the forward pass should allocate > 0 bytes)
# if device == "cuda":
# assert peak_mem >= mem_after >= mem_before
# import torch
# from safetensors import safe_open
# # from src.transformer_video import WanDiscreteVideoTransformer
# ckpt_path = "/mnt/Meissonic/model/diffusion_pytorch_model.safetensors"
# # 1) Instantiate with hyperparameters matching Wan2.1 (a common config is given here; it must match the checkpoint exactly)
# model = WanDiscreteVideoTransformer(
# codebook_size=128, # custom setting for the discrete-token side
# vocab_size=129,
# num_frames=2,
# height=16,
# width=16,
# # Wan backbone hyperparameters must exactly match the checkpoint
# model_type="t2v",
# patch_size=(1, 2, 2),
# in_dim=16,
# dim=1536,
# ffn_dim=8960,
# freq_dim=256,
# text_dim=4096,
# out_dim=16,
# num_heads=12,
# num_layers=30,
# window_size=(-1, -1),
# qk_norm=True,
# cross_attn_norm=True,
# eps=1e-6,
# )
# # 2) Read the safetensors file
# state_dict = {}
# with safe_open(ckpt_path, framework="pt", device="cpu") as f:
# for k in f.keys():
# state_dict[k] = f.get_tensor(k)
# # 3) Try loading into the backbone only (token_embedding/logits_head are left untouched)
# missing, unexpected = model.backbone.load_state_dict(state_dict, strict=False)
# print("Missing keys:", missing[:50], "... total", len(missing))
# print("Unexpected keys:", unexpected[:50], "... total", len(unexpected))
# print("Backbone params (M):", sum(p.numel() for p in model.backbone.parameters()) / 1e6)
# print("Params (M):", sum(p.numel() for p in model.parameters()) / 1e6)
# # if __name__ == '__main__':
# # # test_wan_discrete_video_transformer_forward_and_shapes()
# # print('WanDiscreteVideoTransformer forward pass test: PASSED')
|