# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024-2026 The Alibaba Wan Team Authors. All rights reserved.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
__all__ = ["Wan2_2_VAE"]
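# Number of trailing frames kept in feat_cache between chunks; CausalConv3d
# consumes them as temporal context for the next chunk.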
CACHE_T = 2
class ScatterFwdAllGatherBackwardOverlap(torch.autograd.Function):
@staticmethod
def forward(ctx, x, group, overlap_size):
"""
Forward pass: split input tensor along W; each rank processes its local
chunk including overlap regions.
Args:
x: Input tensor, shape [B, C, T, H, W]
group: Distributed communication group
overlap_size: Width of overlap region
"""
W = x.shape[4]
world_size = torch.distributed.get_world_size(group)
rank = torch.distributed.get_rank(group)
# Compute base chunk size
base_chunk_size = (W + world_size - 1) // world_size
# Compute chunk range for current rank
chunk_start = rank * base_chunk_size
chunk_end = min((rank + 1) * base_chunk_size, W)
# Extend range with overlap
overlap_start = max(0, chunk_start - overlap_size)
overlap_end = min(W, chunk_end + overlap_size)
# Slice local chunk
x_chunk = x[:, :, :, :, overlap_start:overlap_end].contiguous()
# Save metadata needed by backward
ctx.save_for_backward(torch.tensor([overlap_start, overlap_end, W], dtype=torch.long, device=x.device))
ctx.group = group
ctx.overlap_size = overlap_size
ctx.world_size = world_size
ctx.rank = rank
ctx.base_chunk_size = base_chunk_size
return x_chunk
@staticmethod
def backward(ctx, grad_output):
"""
Backward pass: all-gather gradients from all ranks and trim overlap.
"""
# Restore saved forward metadata
overlap_start, overlap_end, W = ctx.saved_tensors[0]
overlap_start = overlap_start.item()
overlap_end = overlap_end.item()
W = W.item()
group = ctx.group
overlap_size = ctx.overlap_size
world_size = ctx.world_size
base_chunk_size = ctx.base_chunk_size
# Collect gradients from all ranks via all_gather
grad_output = grad_output.contiguous()
B, C, T, H = grad_output.shape[:4]
grad_shapes = []
for r in range(world_size):
r_chunk_start = r * base_chunk_size
r_chunk_end = min((r + 1) * base_chunk_size, W)
r_overlap_start = max(0, r_chunk_start - overlap_size)
r_overlap_end = min(W, r_chunk_end + overlap_size)
# Compute gradient shape for each rank
chunk_width = r_overlap_end - r_overlap_start
grad_shapes.append((B, C, T, H, chunk_width))
grad_chunks = [
torch.zeros(grad_shape, device=grad_output.device, dtype=grad_output.dtype) for grad_shape in grad_shapes
]
torch.distributed.all_gather(grad_chunks, grad_output, group=group)
# Stitch gathered chunks into full gradient tensor
full_grad = torch.zeros(B, C, T, H, W, device=grad_output.device, dtype=grad_output.dtype)
# Place each rank's gradient chunk at the correct position
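        # Columns inside an overlap region belong to several ranks' chunks,
        # so `+=` accumulates every rank's gradient contribution for them.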
for r in range(world_size):
r_chunk_start = r * base_chunk_size
r_chunk_end = min((r + 1) * base_chunk_size, W)
r_overlap_start = max(0, r_chunk_start - overlap_size)
r_overlap_end = min(W, r_chunk_end + overlap_size)
# Position in full gradient
grad_start_in_full = r_overlap_start
grad_end_in_full = r_overlap_end
# Position inside gathered chunk
grad_start_in_chunk = 0
grad_end_in_chunk = r_overlap_end - r_overlap_start
# Handle left boundary for first rank
if r == 0:
grad_start_in_chunk = 0
grad_end_in_chunk = min(r_chunk_end + overlap_size, W) - r_overlap_start
# Handle right boundary for last rank
elif r == world_size - 1:
grad_start_in_chunk = max(0, r_chunk_start - overlap_size) - r_overlap_start
grad_end_in_chunk = r_overlap_end - r_overlap_start
# Accumulate into full gradient
full_grad[:, :, :, :, grad_start_in_full:grad_end_in_full] += grad_chunks[r][
:, :, :, :, grad_start_in_chunk:grad_end_in_chunk
]
return full_grad, None, None
def scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=0):
return ScatterFwdAllGatherBackwardOverlap.apply(x, group, overlap_size)
class AllGatherFwdScatterBackwardOverlap(torch.autograd.Function):
@staticmethod
def forward(ctx, x, group, overlap_size):
"""
Forward pass: each rank clips local input, then all-gathers clipped chunks.
Args:
x: Input tensor, shape [B, C, T, H, W], already local overlapped chunk per rank
group: Distributed communication group
overlap_size: Width of overlap region
"""
world_size = torch.distributed.get_world_size(group)
rank = torch.distributed.get_rank(group)
# Clip local input first (remove overlap area)
if rank == 0:
valid_start = 0
valid_end = x.shape[-1] - overlap_size
elif rank == world_size - 1:
valid_start = overlap_size
valid_end = x.shape[-1]
else:
valid_start = overlap_size
valid_end = x.shape[-1] - overlap_size
x_clipped = x[..., valid_start:valid_end].contiguous()
clipped_width = x_clipped.shape[-1]
# First all_gather: collect clipped widths across ranks
width_tensor = torch.tensor([clipped_width], dtype=torch.long, device=x.device)
all_widths = [torch.zeros_like(width_tensor) for _ in range(world_size)]
torch.distributed.all_gather(all_widths, width_tensor, group=group)
clipped_widths = [w.item() for w in all_widths]
# Second all_gather: collect clipped data across ranks
B, C, T, H = x_clipped.shape[:4]
x_clipped_chunks = [torch.zeros(B, C, T, H, w, device=x.device, dtype=x.dtype) for w in clipped_widths]
torch.distributed.all_gather(x_clipped_chunks, x_clipped, group=group)
full_x = torch.cat(x_clipped_chunks, dim=-1)
# Save metadata needed by backward
ctx.save_for_backward(torch.tensor([valid_start, valid_end], dtype=torch.long, device=x.device))
ctx.clipped_widths = clipped_widths
ctx.group = group
ctx.overlap_size = overlap_size
ctx.world_size = world_size
ctx.rank = rank
return full_x
@staticmethod
def backward(ctx, grad_output):
"""
Backward pass: each rank restores gradients for its own partition only.
"""
# Restore saved forward metadata
valid_start, valid_end = ctx.saved_tensors[0]
valid_start = valid_start.item()
valid_end = valid_end.item()
clipped_widths = ctx.clipped_widths
overlap_size = ctx.overlap_size
world_size = ctx.world_size
rank = ctx.rank
# Compute current rank offset in full gradient
start_pos = sum(clipped_widths[:rank])
end_pos = start_pos + clipped_widths[rank]
# Extract only current rank gradient slice
grad_clipped = grad_output[:, :, :, :, start_pos:end_pos]
# Pad zeros to recover overlap area for current rank
if rank == 0:
# First rank: pad right
grad_full = F.pad(grad_clipped, (0, overlap_size))
elif rank == world_size - 1:
# Last rank: pad left
grad_full = F.pad(grad_clipped, (overlap_size, 0))
else:
# Middle rank: pad both sides
grad_full = F.pad(grad_clipped, (overlap_size, overlap_size))
return grad_full, None, None
def all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=0):
return AllGatherFwdScatterBackwardOverlap.apply(x, group, overlap_size)
def one_plus_world_size(group):
    """Return True if a distributed group with more than one rank is given."""
    return group is not None and torch.distributed.get_world_size(group) > 1
class CausalConv3d(nn.Conv3d):
"""
    Causal 3D convolution.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
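        # Repack padding into F.pad order (W_left, W_right, H_top, H_bottom,
        # T_front, T_back): all temporal padding is applied in front of the
        # clip to keep the convolution causal, and nn.Conv3d's own padding is
        # disabled so F.pad can handle it (with the cache) in forward().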
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
@torch.compile
def forward(self, x, cache_x=None, group: torch.distributed.ProcessGroup = None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
if one_plus_world_size(group):
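            # Halo width along W that each rank needs from its neighbors to
            # cover the convolution's receptive field at chunk boundaries.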
overlap_size = self.kernel_size[-1] // 2 * self.stride[-1]
x = scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=overlap_size)
x = F.pad(x, padding)
x = super().forward(x)
if one_plus_world_size(group):
x = all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=overlap_size)
return x
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
@torch.compile
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
@torch.compile
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim, 3, padding=1)
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, dim, 3, padding=1),
# nn.Conv2d(dim, dim//2, 3, padding=1)
)
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
@torch.compile
def forward(self, x, feat_cache=None, feat_idx=[0], group: torch.distributed.ProcessGroup = None):
if one_plus_world_size(group):
if self.mode in ["upsample3d", "upsample2d"]:
overlap_size = 1
elif self.mode in ["downsample3d", "downsample2d"]:
overlap_size = 2
else:
overlap_size = 0
x = scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=overlap_size)
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
                        # prepend the last frame of the previous chunk to fill the cache window
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
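                    # time_conv doubled the channel dim; split it into two
                    # temporal phases and interleave them to double T.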
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.resample(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
if one_plus_world_size(group):
if self.mode in ["upsample3d", "upsample2d"]:
overlap_size = overlap_size * 2
elif self.mode in ["downsample3d", "downsample2d"]:
overlap_size = overlap_size // 2
x = all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=overlap_size)
return x
def init_weight(self, conv):
conv_weight = conv.weight.detach().clone()
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1, c2)
        conv_weight.data[:, :, 1, 0, 0] = init_matrix
conv.weight = nn.Parameter(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data.detach().clone()
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
conv.weight = nn.Parameter(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
@torch.compile
def forward(self, x, feat_cache=None, feat_idx=[0], group: torch.distributed.ProcessGroup = None):
if one_plus_world_size(group):
overlap_size = 2
x = scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=overlap_size)
h = self.shortcut(x)
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # prepend the last frame of the previous chunk to fill the cache window
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
x = x + h
if one_plus_world_size(group):
x = all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=overlap_size)
return x
class AttentionBlock(nn.Module):
"""
    Single-head self-attention, applied spatially within each frame.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
@torch.compile
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm(x)
# compute query, key, value
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(q, k, v)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
# output
x = self.proj(x)
x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
x = x + identity
return x
def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(x, "b c f (h q) (w r) -> b (c r q) f h w", q=patch_size, r=patch_size)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(x, "b (c r q) f h w -> b c f (h q) (w r)", q=patch_size, r=patch_size)
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")
    return x
class AvgDown3D(nn.Module):
def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
@torch.compile
def forward(self, x: torch.Tensor) -> torch.Tensor:
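        # Space-to-channel (pixel-unshuffle) over (T, H, W), then a mean over
        # groups of the expanded channels reduces them to out_channels.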
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (0, 0, 0, 0, pad_t, 0)
x = F.pad(x, pad)
B, C, T, H, W = x.shape
x = x.view(
B, C, T // self.factor_t, self.factor_t, H // self.factor_s, self.factor_s, W // self.factor_s, self.factor_s
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(B, C * self.factor, T // self.factor_t, H // self.factor_s, W // self.factor_s)
x = x.view(B, self.out_channels, self.group_size, T // self.factor_t, H // self.factor_s, W // self.factor_s)
x = x.mean(dim=2)
return x
class DupUp3D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, factor_t, factor_s=1):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
@torch.compile
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
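        # Channel-to-space inverse of AvgDown3D: repeat channels, then scatter
        # the repeats across (T, H, W); the first chunk trims the duplicated
        # leading frames below so the causal first frame is not repeated.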
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(x.size(0), self.out_channels, self.factor_t, self.factor_s, self.factor_s, x.size(2), x.size(3), x.size(4))
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0), self.out_channels, x.size(2) * self.factor_t, x.size(4) * self.factor_s, x.size(6) * self.factor_s
)
if first_chunk:
x = x[:, :, self.factor_t - 1 :, :, :]
return x
class Down_ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False):
super().__init__()
# Shortcut path with downsample
self.avg_shortcut = AvgDown3D(
in_dim, out_dim, factor_t=2 if temperal_downsample else 1, factor_s=2 if down_flag else 1
)
# Main path with residual blocks and downsample
downsamples = []
for _ in range(mult):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final downsample block
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
downsamples.append(Resample(out_dim, mode=mode))
self.downsamples = nn.Sequential(*downsamples)
@torch.compile
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
class Up_ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False):
super().__init__()
# Shortcut path with upsample
if up_flag:
self.avg_shortcut = DupUp3D(in_dim, out_dim, factor_t=2 if temperal_upsample else 1, factor_s=2 if up_flag else 1)
else:
self.avg_shortcut = None
# Main path with residual blocks and upsample
upsamples = []
for _ in range(mult):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final upsample block
if up_flag:
mode = "upsample3d" if temperal_upsample else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
self.upsamples = nn.Sequential(*upsamples)
@torch.compile
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False, group: torch.distributed.ProcessGroup = None):
x_main = x.clone()
for module in self.upsamples:
x_main = module(x_main, feat_cache, feat_idx, group=group)
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
else:
return x_main
class Encoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
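        # 12 input channels = 3 (RGB) x 2 x 2 from the spatial patchify step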
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_down_flag = temperal_downsample[i] if i < len(temperal_downsample) else False
downsamples.append(
Down_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks,
temperal_downsample=t_down_flag,
down_flag=i != len(dim_mult) - 1,
)
)
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)
)
        # output blocks
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1))
@torch.compile
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
# downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
# middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
# head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
# scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)
)
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
upsamples.append(
Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks + 1,
temperal_upsample=t_up_flag,
up_flag=i != len(dim_mult) - 1,
)
)
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 12, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False, group: torch.distributed.ProcessGroup = None):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv1(x, feat_cache[idx], group=group)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x, group=group)
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx, group=group)
else:
x = layer(x)
# upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, first_chunk, group=group)
else:
x = layer(x, group=group)
# head
if one_plus_world_size(group):
overlap_size = self.head[2].kernel_size[-1] // 2 * self.head[2].stride[-1]
x = scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=overlap_size)
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
if one_plus_world_size(group):
x = all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=overlap_size)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class WanVAE_(nn.Module):
def __init__(
self,
dim=160,
dec_dim=256,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
def forward(self, x, scale=[0, 1]):
mu = self.encode(x, scale)
x_recon = self.decode(mu, scale)
return x_recon, mu
def encode(self, x, scale):
self.clear_cache()
x = patchify(x, patch_size=2)
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
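        # Encode the first frame alone, then 4-frame chunks: under the 4x
        # temporal compression, each chunk maps to one latent frame.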
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu
def decode(self, z, scale, group: torch.distributed.ProcessGroup = None):
self.clear_cache()
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z, group=group)
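        # Decode one latent frame per iteration; feat_cache carries the causal
        # temporal context between successive chunks.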
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True, group=group
)
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, group=group)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
    def sample(self, imgs, deterministic=False):
        # NOTE: assumes an encode() that returns (mu, log_var); the encode()
        # above returns only mu, so nothing in this module calls sample().
        mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
# params
cfg = dict(
dim=dim,
z_dim=z_dim,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, True],
dropout=0.0,
)
cfg.update(**kwargs)
# init model
with torch.device("meta"):
model = WanVAE_(**cfg)
# load checkpoint
logging.info(f"loading {pretrained_path}")
model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
return model
class Wan2_2_VAE:
def __init__(
self,
z_dim=48,
c_dim=160,
vae_pth=None,
dim_mult=[1, 2, 4, 4],
temperal_downsample=[False, True, True],
dtype=torch.float,
device="cuda",
):
self.dtype = dtype
self.device = device
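        # Per-channel statistics of the 48-dim latent space, used to normalize
        # latents after encoding and to denormalize them before decoding.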
self.mean = torch.tensor(
[
-0.2289,
-0.0052,
-0.1323,
-0.2339,
-0.2799,
0.0174,
0.1838,
0.1557,
-0.1382,
0.0542,
0.2813,
0.0891,
0.1570,
-0.0098,
0.0375,
-0.1825,
-0.2246,
-0.1207,
-0.0698,
0.5109,
0.2665,
-0.2108,
-0.2158,
0.2502,
-0.2055,
-0.0322,
0.1109,
0.1567,
-0.0729,
0.0899,
-0.2799,
-0.1230,
-0.0313,
-0.1649,
0.0117,
0.0723,
-0.2839,
-0.2083,
-0.0520,
0.3748,
0.0152,
0.1957,
0.1433,
-0.2944,
0.3573,
-0.0548,
-0.1681,
-0.0667,
],
dtype=dtype,
device=device,
)
self.std = torch.tensor(
[
0.4765,
1.0364,
0.4514,
1.1677,
0.5313,
0.4990,
0.4818,
0.5013,
0.8158,
1.0344,
0.5894,
1.0901,
0.6885,
0.6165,
0.8454,
0.4978,
0.5759,
0.3523,
0.7135,
0.6804,
0.5833,
1.4146,
0.8986,
0.5659,
0.7069,
0.5338,
0.4889,
0.4917,
0.4069,
0.4999,
0.6866,
0.4093,
0.5709,
0.6065,
0.6415,
0.4944,
0.5726,
1.2042,
0.5458,
1.6887,
0.3971,
1.0600,
0.3943,
0.5537,
0.5444,
0.4089,
0.7468,
0.7744,
],
dtype=dtype,
device=device,
)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.vae = (
_video_vae(
pretrained_path=vae_pth, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample
)
.eval()
.requires_grad_(False)
.to(device)
)
def encode(self, video):
return self.vae.encode(video, self.scale).float()
def to(self, *args, **kwargs):
self.mean = self.mean.to(*args, **kwargs)
self.std = self.std.to(*args, **kwargs)
self.scale = [self.mean, 1.0 / self.std]
self.vae = self.vae.to(*args, **kwargs)
return self
def decode(self, z, group: torch.distributed.ProcessGroup = None):
return self.vae.decode(z, self.scale, group=group).float().clamp_(-1, 1)
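

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the checkpoint path and the
    # input shape below are placeholders, not part of this module.
    vae = Wan2_2_VAE(vae_pth="path/to/Wan2.2_VAE.pth", device="cuda")
    video = torch.randn(1, 3, 21, 256, 256, device="cuda")  # [B, C, T, H, W], T = 1 + 4k
    with torch.no_grad():
        z = vae.encode(video)  # latent: [B, 48, 1 + (T - 1) // 4, H // 16, W // 16]
        recon = vae.decode(z)  # reconstruction in [-1, 1], same shape as video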