# Ltx2_2_Wan2.1_VAE_Adapter / adapter_model.py
# (HDHCDev — initial commit, 2701c9d verified; header retained from the
# original hub page listing, commented out so the module parses.)
"""
Latent adapter model: maps LTX-2 latent space (128ch) β†’ WAN 2.1 latent space (16ch).
Handles:
- Channel reduction: 128 β†’ 16
- Spatial upsampling: 4Γ— (LTX uses 32Γ— spatial downscale, WAN uses 8Γ—)
- Temporal upsampling: ~2Γ— (LTX uses 8Γ— temporal downscale, WAN uses 4Γ—)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class CausalConv3d(nn.Module):
    """3D convolution that is causal along time: the input is left-padded in
    the temporal dimension so no output frame depends on future frames."""

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, **kwargs):
        super().__init__()
        # Left-only temporal padding of (kernel_t - 1) frames.
        if isinstance(kernel_size, tuple):
            self.temporal_pad = kernel_size[0] - 1
        else:
            self.temporal_pad = kernel_size - 1
        # Spatial padding is delegated to the conv itself (symmetric).
        # NOTE(review): a tuple `padding` uses padding[1] for both H and W —
        # assumes square spatial padding; confirm if H/W padding ever differ.
        if isinstance(padding, tuple):
            spatial_pad = padding[1]
        else:
            spatial_pad = padding
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(0, spatial_pad, spatial_pad),
            **kwargs,
        )

    def forward(self, x):
        """x: (B, C, T, H, W). Pad only the temporal left edge so output
        frame t sees input frames <= t, then convolve."""
        if self.temporal_pad > 0:
            x = F.pad(x, (0, 0, 0, 0, self.temporal_pad, 0))
        return self.conv(x)
class CausalConvTranspose1d(nn.Module):
    """Learned temporal upsampling via a strided ConvTranspose3d whose output
    is cropped so every frame depends only on past/current inputs."""

    def __init__(self, channels, stride=2):
        super().__init__()
        # kernel 3, no built-in padding: we crop the output manually instead.
        self.conv = nn.ConvTranspose3d(
            channels, channels,
            kernel_size=(3, 1, 1), stride=(stride, 1, 1),
            padding=(0, 0, 0),
        )
        self.stride = stride

    def forward(self, x):
        out = self.conv(x)
        # For stride 2 the transpose conv emits 2*T_in + 1 frames. Dropping
        # the first two and the last one leaves 2*(T_in - 1) frames
        # (e.g. T=15 -> 28, T=4 -> 6) with strictly causal dependencies:
        # [X0,X1] -> [X1] -> [X1,X2] -> [X2] ...
        return out[:, :, 2:-1]
class CausalGroupNorm(nn.Module):
    """GroupNorm applied per frame: time is folded into the batch axis so the
    normalization statistics never mix across frames (no temporal leakage)."""

    def __init__(self, num_groups, num_channels):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        b, c, t, h, w = x.shape
        # (B, C, T, H, W) -> (B*T, C, H, W): each frame is normalized alone.
        frames = x.transpose(1, 2).contiguous().view(b * t, c, h, w)
        frames = self.gn(frames)
        # Restore the original (B, C, T, H, W) layout.
        return frames.view(b, t, c, h, w).transpose(1, 2).contiguous()
class CausalResBlock3d(nn.Module):
    """Residual block with per-frame group norms and causal temporal
    convolutions: (norm -> SiLU -> conv) twice, plus an identity skip."""

    def __init__(self, channels):
        super().__init__()
        groups = min(16, channels)
        self.net = nn.Sequential(
            CausalGroupNorm(groups, channels),
            nn.SiLU(),
            CausalConv3d(channels, channels, kernel_size=3, padding=1),
            CausalGroupNorm(groups, channels),
            nn.SiLU(),
            CausalConv3d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.net(x)
class ResBlock3d(nn.Module):
    """Plain (non-causal) 3D residual block: (GroupNorm -> SiLU -> Conv3d)
    twice, with an identity skip connection."""

    def __init__(self, channels):
        super().__init__()
        groups = min(16, channels)
        self.net = nn.Sequential(
            nn.GroupNorm(groups, channels),
            nn.SiLU(),
            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
            nn.GroupNorm(groups, channels),
            nn.SiLU(),
            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.net(x)
class UpsampleBlock3d(nn.Module):
    """Optional 2x spatial (and optionally temporal) upsampling, followed by
    a channel-changing 3x3x3 conv and a stack of residual blocks.

    Args:
        in_ch / out_ch: input / output channel counts.
        n_res: number of residual blocks after the channel change.
        spatial_up: apply a 2x spatial upsample first.
        temporal_up: also double the temporal size in the upsample step.
        use_conv_transpose: learned ConvTranspose3d upsampling instead of
            trilinear interpolation.
        causal: use causal conv / res-block variants.
    """

    def __init__(self, in_ch, out_ch, n_res=2, spatial_up=True, temporal_up=False,
                 use_conv_transpose=False, causal=False):
        super().__init__()
        stages = []
        if spatial_up:
            scale = (2, 2, 2) if temporal_up else (1, 2, 2)
            if use_conv_transpose:
                stages.append(nn.ConvTranspose3d(in_ch, in_ch, kernel_size=scale, stride=scale))
            else:
                stages.append(nn.Upsample(scale_factor=scale, mode='trilinear', align_corners=False))
        if causal:
            stages.append(CausalConv3d(in_ch, out_ch, kernel_size=3, padding=1))
            stages.extend(CausalResBlock3d(out_ch) for _ in range(n_res))
        else:
            stages.append(nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1))
            stages.extend(ResBlock3d(out_ch) for _ in range(n_res))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
class LatentAdapterV3(nn.Module):
    """
    V3 architecture: ~3.37M params. Maps LTX-2 latents (128ch) to WAN 2.1
    latents (16ch) with a total 4x spatial upsample.
    """

    def __init__(self):
        super().__init__()
        # Feature extraction at input resolution: 128 -> 128.
        self.input_conv = nn.Sequential(
            nn.Conv3d(128, 128, kernel_size=3, padding=1),
            ResBlock3d(128),
            ResBlock3d(128),
        )
        # Two 2x spatial upsampling stages (4x total): 128 -> 64 -> 32.
        self.up1 = UpsampleBlock3d(128, 64, n_res=2, spatial_up=True)
        self.up2 = UpsampleBlock3d(64, 32, n_res=2, spatial_up=True)
        # Refinement head: 32 -> 16, no further upsampling.
        self.output_block = nn.Sequential(
            nn.Conv3d(32, 16, kernel_size=3, padding=1),
            ResBlock3d(16),
            ResBlock3d(16),
            nn.Conv3d(16, 16, kernel_size=3, padding=1),
        )

    def forward(self, z_ltx, target_shape=None):
        """Map (B, 128, T, H, W) -> (B, 16, T, H*4, W*4); if `target_shape`
        (T, H, W) is given and differs, trilinearly resize to it."""
        z = self.input_conv(z_ltx)
        z = self.up2(self.up1(z))
        z = self.output_block(z)
        if target_shape is not None and z.shape[2:] != target_shape:
            z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False)
        return z

    def param_count(self):
        """Total number of parameters in the adapter."""
        return sum(p.numel() for p in self.parameters())
class LatentAdapter(nn.Module):
    """
    Maps LTX-2 latents to WAN 2.1 latents.

    Input:  (B, 128, T_ltx, H_ltx, W_ltx)
    Output: (B, 16, T_wan, H_wan, W_wan) with H_wan = H_ltx*4, W_wan = W_ltx*4;
    the temporal size is matched by F.interpolate at the end.

    V4 architecture: ~14.4M params. Wider channels and 4 res blocks per stage.
    """

    def __init__(self):
        super().__init__()
        # Feature extraction at input resolution: 128 -> 256.
        self.input_conv = nn.Sequential(
            nn.Conv3d(128, 256, kernel_size=3, padding=1),
            *(ResBlock3d(256) for _ in range(4)),
        )
        # Two 2x spatial upsampling stages (4x total): 256 -> 128 -> 64.
        self.up1 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True)
        self.up2 = UpsampleBlock3d(128, 64, n_res=4, spatial_up=True)
        # Refinement head: 64 -> 32 -> 16, no upsampling.
        self.output_block = nn.Sequential(
            nn.Conv3d(64, 32, kernel_size=3, padding=1),
            ResBlock3d(32),
            ResBlock3d(32),
            nn.Conv3d(32, 16, kernel_size=3, padding=1),
        )

    def forward(self, z_ltx, target_shape=None):
        """
        Args:
            z_ltx: (B, 128, T, H, W) LTX-2 latent.
            target_shape: optional (T_target, H_target, W_target) to match
                exact WAN dimensions via trilinear interpolation.
        """
        z = self.input_conv(z_ltx)
        z = self.up2(self.up1(z))
        z = self.output_block(z)
        # Snap to the exact WAN latent grid if requested.
        if target_shape is not None and z.shape[2:] != target_shape:
            z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False)
        return z

    def param_count(self):
        """Total number of parameters in the adapter."""
        return sum(p.numel() for p in self.parameters())
class LatentAdapterV6(nn.Module):
    """
    V6 architecture: ~81.2M params. Maps LTX-2 latents (128ch) to WAN 2.1
    latents (16ch) with a total 4x spatial upsample.
    """

    def __init__(self):
        super().__init__()
        # Feature extraction at input resolution: 128 -> 512.
        self.input_conv = nn.Sequential(
            nn.Conv3d(128, 512, kernel_size=3, padding=1),
            *(ResBlock3d(512) for _ in range(4)),
        )
        # Two 2x spatial upsampling stages (4x total): 512 -> 256 -> 128.
        self.up1 = UpsampleBlock3d(512, 256, n_res=4, spatial_up=True)
        self.up2 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True)
        # Refinement head: 128 -> 64 -> 16, no upsampling.
        self.output_block = nn.Sequential(
            nn.Conv3d(128, 64, kernel_size=3, padding=1),
            ResBlock3d(64),
            ResBlock3d(64),
            nn.Conv3d(64, 16, kernel_size=3, padding=1),
        )

    def forward(self, z_ltx, target_shape=None):
        """Map (B, 128, T, H, W) -> (B, 16, T, H*4, W*4); if `target_shape`
        (T, H, W) is given and differs, trilinearly resize to it."""
        z = self.input_conv(z_ltx)
        z = self.up2(self.up1(z))
        z = self.output_block(z)
        if target_shape is not None and z.shape[2:] != target_shape:
            z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False)
        return z

    def param_count(self):
        """Total number of parameters in the adapter."""
        return sum(p.numel() for p in self.parameters())
class LatentAdapterV6_3(nn.Module):
    """
    V6.3 architecture: ~81.8M params. Same layout as V6 but the spatial
    upsampling is a learned ConvTranspose3d instead of trilinear
    interpolation. Maps LTX-2 latents (128ch) to WAN 2.1 latents (16ch).
    """

    def __init__(self):
        super().__init__()
        # Feature extraction at input resolution: 128 -> 512.
        self.input_conv = nn.Sequential(
            nn.Conv3d(128, 512, kernel_size=3, padding=1),
            *(ResBlock3d(512) for _ in range(4)),
        )
        # Two learned 2x spatial upsampling stages (4x total): 512 -> 256 -> 128.
        self.up1 = UpsampleBlock3d(512, 256, n_res=4, spatial_up=True, use_conv_transpose=True)
        self.up2 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True, use_conv_transpose=True)
        # Refinement head: 128 -> 64 -> 16, no upsampling.
        self.output_block = nn.Sequential(
            nn.Conv3d(128, 64, kernel_size=3, padding=1),
            ResBlock3d(64),
            ResBlock3d(64),
            nn.Conv3d(64, 16, kernel_size=3, padding=1),
        )

    def forward(self, z_ltx, target_shape=None):
        """Map (B, 128, T, H, W) -> (B, 16, T, H*4, W*4); if `target_shape`
        (T, H, W) is given and differs, trilinearly resize to it."""
        z = self.input_conv(z_ltx)
        z = self.up2(self.up1(z))
        z = self.output_block(z)
        if target_shape is not None and z.shape[2:] != target_shape:
            z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False)
        return z

    def param_count(self):
        """Total number of parameters in the adapter."""
        return sum(p.numel() for p in self.parameters())
class LatentAdapterV6_4(nn.Module):
    """
    V6.4 architecture: ~81.8M params. Learned ConvTranspose3d spatial
    upsampling plus causal 3D convolutions throughout, so no output frame
    depends on future input frames (eliminates future-frame ghosting).
    Maps LTX-2 latents (128ch) to WAN 2.1 latents (16ch).
    """

    def __init__(self):
        super().__init__()
        # Feature extraction at input resolution, causal in time: 128 -> 512.
        self.input_conv = nn.Sequential(
            CausalConv3d(128, 512, kernel_size=3, padding=1),
            *(CausalResBlock3d(512) for _ in range(4)),
        )
        # Two causal learned 2x spatial upsampling stages: 512 -> 256 -> 128.
        self.up1 = UpsampleBlock3d(512, 256, n_res=4, spatial_up=True, use_conv_transpose=True, causal=True)
        self.up2 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True, use_conv_transpose=True, causal=True)
        # Causal refinement head: 128 -> 64 -> 16, no upsampling.
        self.output_block = nn.Sequential(
            CausalConv3d(128, 64, kernel_size=3, padding=1),
            CausalResBlock3d(64),
            CausalResBlock3d(64),
            CausalConv3d(64, 16, kernel_size=3, padding=1),
        )

    def forward(self, z_ltx, target_shape=None):
        """Map (B, 128, T, H, W) -> (B, 16, T, H*4, W*4); if `target_shape`
        (T, H, W) is given and differs, trilinearly resize to it."""
        z = self.input_conv(z_ltx)
        z = self.up2(self.up1(z))
        z = self.output_block(z)
        if target_shape is not None and z.shape[2:] != target_shape:
            z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False)
        return z

    def param_count(self):
        """Total number of parameters in the adapter."""
        return sum(p.numel() for p in self.parameters())
class LatentAdapterV6_5(nn.Module):
    """
    V6.5 architecture: ~83M params.

    Causal temporal convolutions throughout, a learned ConvTranspose3d
    temporal upsample (T -> 2*(T-1) frames), and an explicit learned
    'Option B' anchor projection of LTX frame 0 that is prepended to the
    main output, giving 2*(T-1) + 1 frames before any final interpolation.
    """

    def __init__(self):
        super().__init__()
        # Feature extraction at input resolution, causal in time: 128 -> 512.
        self.input_conv = nn.Sequential(
            CausalConv3d(128, 512, kernel_size=3, padding=1),
            *(CausalResBlock3d(512) for _ in range(4)),
        )
        # Learned temporal upsample: T -> 2*(T-1).
        self.temporal_up = nn.Sequential(
            CausalConvTranspose1d(512, stride=2),
            CausalResBlock3d(512),
            CausalResBlock3d(512),
        )
        # Two causal learned 2x spatial upsampling stages: 512 -> 256 -> 128.
        self.up1 = UpsampleBlock3d(512, 256, n_res=4, spatial_up=True, use_conv_transpose=True, causal=True)
        self.up2 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True, use_conv_transpose=True, causal=True)
        # Causal refinement head: 128 -> 64 -> 16, no upsampling.
        self.output_block = nn.Sequential(
            CausalConv3d(128, 64, kernel_size=3, padding=1),
            CausalResBlock3d(64),
            CausalResBlock3d(64),
            CausalConv3d(64, 16, kernel_size=3, padding=1),
        )
        # 'Option B' anchor projection: LTX frame 0 (128ch) straight to one
        # WAN frame (16ch) at 4x spatial resolution.
        self.anchor_proj = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=(1, 4, 4), stride=(1, 4, 4)),
            nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
            nn.SiLU(),
            nn.Conv3d(64, 16, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        )

    def forward(self, z_ltx, target_shape=None):
        """Map (B, 128, T, H, W) -> (B, 16, 2*(T-1)+1, H*4, W*4), then
        optionally resize (T, H, W) of the output to `target_shape`."""
        # 1. Anchor: project frame 0 independently -> (B, 16, 1, H*4, W*4).
        anchor = self.anchor_proj(z_ltx[:, :, :1])
        # 2. Main path: T -> 2*(T-1) frames, 4x spatial upsample.
        z = self.input_conv(z_ltx)
        z = self.temporal_up(z)
        assert z.shape[2] == (z_ltx.shape[2] - 1) * 2, f"temporal_up output T={z.shape[2]}, expected {(z_ltx.shape[2] - 1) * 2}"
        z = self.up2(self.up1(z))
        z = self.output_block(z)
        # 3. Prepend the anchor frame: total 2*(T-1) + 1 frames.
        z = torch.cat([anchor, z], dim=2)
        if target_shape is not None and z.shape[2:] != target_shape:
            z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False)
        return z

    def param_count(self):
        """Total number of parameters in the adapter."""
        return sum(p.numel() for p in self.parameters())
class LatentAdapterV6_6(LatentAdapterV6_5):
    """
    V6.6 architecture: identical network to V6.5 (~83M params).

    Marker subclass only: per the original note, when training this variant
    the framework does not truncate the identity anchor slice out of the
    loss. No weights or layers differ from V6.5.
    """
    pass
if __name__ == "__main__":
    # Smoke test: build the V4 adapter and run two latent shapes through it.
    model = LatentAdapter()
    print(f"Parameter count: {model.param_count():,}")
    print(f"Model size (fp32): {model.param_count() * 4 / 1024 / 1024:.1f} MB")

    # Each case: (label, LTX latent shape, WAN target (T, H, W)).
    # LTX: 32x spatial / 8x temporal downscale; WAN: 8x spatial / 4x temporal.
    cases = [
        ("Portrait test (480Γ—704, 25 frames)", (1, 128, 4, 15, 22), (7, 60, 88)),
        ("Landscape test (480Γ—832, 25 frames)", (1, 128, 4, 15, 26), (7, 60, 104)),
    ]
    for label, in_shape, target in cases:
        z_in = torch.randn(*in_shape)
        z_out = model(z_in, target_shape=target)
        print(f"\n{label}:")
        print(f" Input: {z_in.shape}")
        print(f" Output: {z_out.shape}")
        print(f" Target: (1, 16, {target[0]}, {target[1]}, {target[2]})")