import torch
import torch.nn as nn
from .wan_video_dit import DiTBlock, CrossAttention, flash_attention
from ..core.gradient import gradient_checkpoint_forward


class _OffloadToCPU(torch.autograd.Function):
"""Move tensor to CPU in forward, move gradient to GPU in backward."""
@staticmethod
def forward(ctx, tensor):
ctx.gpu_device = tensor.device
return tensor.detach().cpu().requires_grad_(True)
@staticmethod
def backward(ctx, grad_output):
return grad_output.to(ctx.gpu_device)


class _RestoreToGPU(torch.autograd.Function):
"""Move CPU tensor back to GPU in forward, move gradient to CPU in backward."""
@staticmethod
def forward(ctx, tensor, device):
ctx.save_for_backward(torch.empty(0))
ctx._device_str = str(device)
return tensor.to(device)
@staticmethod
def backward(ctx, grad_output):
return grad_output.cpu(), None
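

# Illustrative pairing of the two autograd functions above: a hypothetical
# helper (not part of this file) that parks an activation on the CPU and
# restores it on its original device before reuse, while gradients flow back
# through both hops.
#
#     def _roundtrip_offload(h):
#         gpu_device = h.device
#         h_cpu = _OffloadToCPU.apply(h)        # forward: tensor → CPU
#         ...                                   # other GPU work happens here
#         return _RestoreToGPU.apply(h_cpu, gpu_device)  # forward: CPU → GPU

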
def tokenize_target_text(text, max_len=64, vocab_size=8192):
"""Convert target text string to character-level token IDs.
Uses modular hashing of Unicode code points to map any character
(ASCII, CJK, Cyrillic, math symbols, etc.) into a fixed vocabulary.
Token 0 is reserved for padding.
"""
ids = []
for ch in text[:max_len]:
ids.append(ord(ch) % (vocab_size - 1) + 1)
ids += [0] * (max_len - len(ids))
return ids
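

# Worked example: with the defaults (max_len=64, vocab_size=8192),
# tokenize_target_text("NAD") returns
#     [79, 66, 69, 0, 0, ...]   # ord('N')%8191+1, ord('A')%8191+1, ord('D')%8191+1,
#                               # followed by 61 zeros of padding
# Two characters collide only if their code points differ by a multiple of 8191.

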
class ConditionCrossAttention(nn.Module):
"""Cross-attention for injecting condition features into VACE blocks.
Used for both glyph visual features (v2) and character-level text
tokens (v3). Each spatial position in the VACE hidden states queries
the condition tokens to determine what should be rendered at that location.
Zero-initialized output projection ensures the model starts from
pretrained VACE behavior and gradually learns to use the condition.
"""
def __init__(self, dim, num_heads, eps=1e-6):
super().__init__()
self.num_heads = num_heads
self.norm = nn.LayerNorm(dim, eps=eps)
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
# Zero-init output projection so condition attention starts as no-op
nn.init.zeros_(self.o.weight)
nn.init.zeros_(self.o.bias)
def forward(self, x, condition_tokens):
"""
Args:
x: VACE hidden states (B, seq_len, dim)
condition_tokens: condition features (B, num_tokens, dim)
Returns:
x + condition_attn_output (B, seq_len, dim)
"""
residual = x
x = self.norm(x)
q = self.q(x)
k = self.k(condition_tokens)
v = self.v(condition_tokens)
out = flash_attention(q, k, v, self.num_heads)
return residual + self.o(out)
# Backward compatibility alias
GlyphCrossAttention = ConditionCrossAttention
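

# Shape walkthrough for ConditionCrossAttention (illustrative): with
# x: (B, seq_len, dim) and condition_tokens: (B, num_tokens, dim), q is
# (B, seq_len, dim) while k/v are (B, num_tokens, dim); flash_attention then
# yields (B, seq_len, dim), i.e. every spatial position attends over the
# condition tokens. Since self.o is zero-initialized, the attended output is
# zero at the start of training and the block is an exact no-op.

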
class VaceWanAttentionBlock(DiTBlock):
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
self.block_id = block_id
if block_id == 0:
self.before_proj = torch.nn.Linear(self.dim, self.dim)
self.after_proj = torch.nn.Linear(self.dim, self.dim)
def forward(self, c, x, context, t_mod, freqs, condition_tokens=None):
if self.block_id == 0:
c = self.before_proj(c) + x
c = super().forward(c, context, t_mod, freqs)
# Condition cross-attention (injected after text cross-attn + FFN)
# Works for both glyph tokens (v2) and character tokens (v3)
if condition_tokens is not None and hasattr(self, 'condition_cross_attn'):
c = self.condition_cross_attn(c, condition_tokens)
c_skip = self.after_proj(c)
return c_skip, c
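
# Note: each block returns (c_skip, c). c continues along the VACE stream,
# while c_skip (the after_proj output) is collected as a hint; in the full
# model these hints are presumably added into the main DiT stream at the
# layers listed in vace_layers (see vace_layers_mapping below).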


class GlyphEncoder(nn.Module):
"""Lightweight encoder that compresses glyph latents into tokens
for cross-attention in VACE blocks.
    glyph_latent (16ch, from VAE) → Conv3D patch embed → spatial pooling
    → compact token sequence that each VACE block attends to.
"""
def __init__(self, in_channels=16, dim=1536, num_tokens=64, patch_size=(1, 2, 2)):
super().__init__()
self.num_tokens = num_tokens
# Patch embedding (same architecture as vace_patch_embedding)
self.patch_embed = nn.Conv3d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
# Compress spatial sequence to fixed number of tokens via cross-attention pooling
self.query_tokens = nn.Parameter(torch.randn(1, num_tokens, dim) * 0.02)
self.pool_attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
self.pool_norm = nn.LayerNorm(dim)
# Output projection (zero-init for safe initialization)
self.out_proj = nn.Linear(dim, dim)
nn.init.zeros_(self.out_proj.weight)
nn.init.zeros_(self.out_proj.bias)
def forward(self, glyph_latent):
"""
Args:
glyph_latent: (B, 16, T, H, W) from VAE-encoded glyph video
Returns:
glyph_tokens: (B, num_tokens, dim) compressed glyph features
"""
# Patch embed: (B, dim, T, H/2, W/2)
x = self.patch_embed(glyph_latent)
B = x.shape[0]
        # Flatten spatial: (B, dim, N) → (B, N, dim)
        x = x.flatten(2).transpose(1, 2)
        # Cross-attention pooling: compress N spatial tokens → num_tokens
queries = self.query_tokens.expand(B, -1, -1).to(dtype=x.dtype, device=x.device)
x_normed = self.pool_norm(x)
pooled, _ = self.pool_attn(queries, x_normed, x_normed)
return self.out_proj(pooled)
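

# Shape example for GlyphEncoder (illustrative sizes): a (1, 16, 1, 8, 8)
# glyph latent is patch-embedded with stride (1, 2, 2) to (1, dim, 1, 4, 4),
# flattened to N = 16 spatial tokens, then pooled onto the 64 learned queries,
# giving (1, 64, dim). num_tokens is fixed, so downstream cross-attention cost
# is independent of the glyph video's resolution.

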
class TargetTextEncoder(nn.Module):
"""Character-level encoder for target text strings.
Encodes the target text (e.g., "NAD", "CARLIFE", "ε") at the character
level using learned embeddings and a small Transformer. This provides
character-precise identity information that T5's subword tokenization
cannot guarantee.
The output tokens are used as K/V in ConditionCrossAttention within
each VACE block, allowing each spatial position to query which character
it should render.
"""
def __init__(self, vocab_size=8192, max_len=64, dim=1536,
num_layers=2, num_heads=8):
super().__init__()
self.vocab_size = vocab_size
self.max_len = max_len
self.char_embed = nn.Embedding(vocab_size, dim, padding_idx=0)
self.pos_embed = nn.Embedding(max_len, dim)
encoder_layer = nn.TransformerEncoderLayer(
d_model=dim, nhead=num_heads, dim_feedforward=dim * 4,
batch_first=True, norm_first=True, activation='gelu',
)
self.transformer = nn.TransformerEncoder(
encoder_layer, num_layers=num_layers,
)
# Zero-init output projection for safe initialization
self.out_proj = nn.Linear(dim, dim)
nn.init.zeros_(self.out_proj.weight)
nn.init.zeros_(self.out_proj.bias)
def forward(self, token_ids):
"""
Args:
token_ids: (B, max_len) long tensor of character IDs (0 = PAD)
Returns:
text_tokens: (B, max_len, dim)
"""
B, L = token_ids.shape
positions = torch.arange(L, device=token_ids.device).unsqueeze(0).expand(B, -1)
x = self.char_embed(token_ids) + self.pos_embed(positions)
# Padding mask: True means ignore this position
pad_mask = (token_ids == 0)
x = self.transformer(x, src_key_padding_mask=pad_mask)
return self.out_proj(x)
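

# Typical usage (sketch; values are illustrative):
#
#     ids = torch.tensor([tokenize_target_text("CARLIFE")], dtype=torch.long)
#     text_tokens = TargetTextEncoder()(ids)    # (1, 64, 1536)
#
# Because out_proj is zero-initialized, text_tokens is all zeros until
# training updates the encoder, so enabling this pathway cannot perturb a
# pretrained model at step 0.

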
class VaceWanModel(torch.nn.Module):
def __init__(
self,
vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
vace_in_dim=96,
glyph_channels=0,
glyph_num_tokens=64,
use_target_text_encoder=False,
text_vocab_size=8192,
text_max_len=64,
text_num_layers=2,
patch_size=(1, 2, 2),
has_image_input=False,
dim=1536,
num_heads=12,
ffn_dim=8960,
eps=1e-6,
):
super().__init__()
self.vace_layers = vace_layers
self.vace_in_dim = vace_in_dim
self.glyph_channels = glyph_channels
self.use_target_text_encoder = use_target_text_encoder
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
# VACE blocks
self.vace_blocks = torch.nn.ModuleList([
VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
for i in self.vace_layers
])
# VACE patch embedding (original 96 channels, no glyph concat)
self.vace_patch_embedding = torch.nn.Conv3d(
vace_in_dim, dim, kernel_size=patch_size, stride=patch_size
)
        # Glyph pathway (v2: glyph video → GlyphEncoder → cross-attention)
if glyph_channels > 0:
self.glyph_encoder = GlyphEncoder(
in_channels=glyph_channels,
dim=dim,
num_tokens=glyph_num_tokens,
patch_size=patch_size,
)
for block in self.vace_blocks:
block.condition_cross_attn = ConditionCrossAttention(dim, num_heads, eps)
        # Target text pathway (v3: text string → TargetTextEncoder → cross-attention)
if use_target_text_encoder:
self.target_text_encoder = TargetTextEncoder(
vocab_size=text_vocab_size,
max_len=text_max_len,
dim=dim,
num_layers=text_num_layers,
num_heads=num_heads,
)
for block in self.vace_blocks:
block.condition_cross_attn = ConditionCrossAttention(dim, num_heads, eps)
def _has_new_modules(self):
"""Check if this model has modules not present in pretrained checkpoints."""
return self.glyph_channels > 0 or self.use_target_text_encoder
def load_state_dict(self, state_dict, strict=True, assign=False):
if self._has_new_modules():
# New modules won't be in pretrained checkpoints.
# First, materialize any meta-tensor parameters so they can be assigned.
for name, param in self.named_parameters():
if param.is_meta:
materialized = torch.zeros(param.shape, dtype=param.dtype, device='cpu')
parts = name.split('.')
module = self
for p in parts[:-1]:
module = getattr(module, p)
setattr(module, parts[-1], torch.nn.Parameter(materialized))
result = super().load_state_dict(state_dict, strict=False, assign=assign)
# Re-initialize glyph modules (v2 mode)
if hasattr(self, 'glyph_encoder'):
for name, module in self.glyph_encoder.named_modules():
if isinstance(module, (torch.nn.Linear, torch.nn.Conv3d)):
if 'out_proj' not in name:
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
if hasattr(self.glyph_encoder, 'query_tokens'):
torch.nn.init.normal_(self.glyph_encoder.query_tokens, std=0.02)
# Re-initialize target text encoder modules (v3 mode)
if hasattr(self, 'target_text_encoder'):
for name, module in self.target_text_encoder.named_modules():
if isinstance(module, (torch.nn.Linear,)):
if 'out_proj' not in name:
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, torch.nn.Embedding):
torch.nn.init.normal_(module.weight, std=0.02)
if module.padding_idx is not None:
torch.nn.init.zeros_(module.weight[module.padding_idx])
# Re-initialize condition cross-attention modules
for block in self.vace_blocks:
if hasattr(block, 'condition_cross_attn'):
ca = block.condition_cross_attn
for name, module in ca.named_modules():
if isinstance(module, torch.nn.Linear):
if name == 'o':
torch.nn.init.zeros_(module.weight)
torch.nn.init.zeros_(module.bias)
else:
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
            # Cast re-initialized modules to the model dtype (newly constructed
            # modules may still be in PyTorch's default float32)
target_dtype = next((p.dtype for p in self.parameters() if p.dtype != torch.float32), torch.bfloat16)
if hasattr(self, 'glyph_encoder'):
self.glyph_encoder.to(dtype=target_dtype)
if hasattr(self, 'target_text_encoder'):
self.target_text_encoder.to(dtype=target_dtype)
for block in self.vace_blocks:
if hasattr(block, 'condition_cross_attn'):
block.condition_cross_attn.to(dtype=target_dtype)
return result
return super().load_state_dict(state_dict, strict=strict, assign=assign)
def forward(
self, x, vace_context, context, t_mod, freqs,
glyph_latent=None,
target_text_ids=None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
):
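        # Patch-embed each VACE context clip, flatten it into a token
        # sequence, and zero-pad every clip to the main-stream length
        # x.shape[1] so the per-clip sequences stack along the batch dim.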
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))],
dim=1) for u in c
])
# Encode condition tokens for cross-attention
condition_tokens = None
if glyph_latent is not None and hasattr(self, 'glyph_encoder'):
condition_tokens = self.glyph_encoder(glyph_latent)
elif target_text_ids is not None and hasattr(self, 'target_text_encoder'):
condition_tokens = self.target_text_encoder(target_text_ids)
hints = []
for block in self.vace_blocks:
result = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
c, x, context, t_mod, freqs, condition_tokens
)
c_skip, c = result
hints.append(c_skip)
return hints
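

if __name__ == "__main__":
    # Minimal smoke test of the standalone condition encoders. Sizes are
    # illustrative (dim=128 keeps it fast) and are not a training config.
    # Run with `python -m <package>.<module>` since this file uses relative
    # imports.
    enc = GlyphEncoder(in_channels=16, dim=128, num_tokens=8)
    glyph_tokens = enc(torch.randn(2, 16, 1, 8, 8))
    assert glyph_tokens.shape == (2, 8, 128)
    assert glyph_tokens.abs().sum() == 0  # zero-init out_proj → no-op at start

    txt_enc = TargetTextEncoder(vocab_size=8192, max_len=16, dim=128, num_heads=8)
    ids = torch.tensor([tokenize_target_text("NAD", max_len=16)], dtype=torch.long)
    text_tokens = txt_enc(ids)
    assert text_tokens.shape == (1, 16, 128)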