| """
|
| Latent adapter model: maps LTX-2 latent space (128ch) β WAN 2.1 latent space (16ch).
|
|
|
| Handles:
|
| - Channel reduction: 128 β 16
|
| - Spatial upsampling: 4Γ (LTX uses 32Γ spatial downscale, WAN uses 8Γ)
|
| - Temporal upsampling: ~2Γ (LTX uses 8Γ temporal downscale, WAN uses 4Γ)
|
|
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
|
|
class CausalConv3d(nn.Module):
    """3D convolution that is causal in time: pads only with past frames."""

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, **kwargs):
        super().__init__()
        # Left-only temporal padding needed so output T equals input T.
        if isinstance(kernel_size, tuple):
            self.temporal_pad = kernel_size[0] - 1
        else:
            self.temporal_pad = kernel_size - 1
        spatial_pad = padding[1] if isinstance(padding, tuple) else padding

        # Temporal padding is applied manually in forward(); the conv itself
        # pads only spatially (symmetric H/W padding, zero temporal padding).
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(0, spatial_pad, spatial_pad),
            **kwargs,
        )

    def forward(self, x):
        # F.pad order is (W_l, W_r, H_l, H_r, T_l, T_r): pad past frames only,
        # so no kernel tap ever reads a future frame.
        if self.temporal_pad > 0:
            x = F.pad(x, (0, 0, 0, 0, self.temporal_pad, 0))
        return self.conv(x)
|
|
|
|
|
class CausalConvTranspose1d(nn.Module):
    """Temporal upsampling via a strided ConvTranspose3d, cropped for causal alignment.

    NOTE: despite the `1d` name, this operates on 5D (B, C, T, H, W) tensors;
    the kernel and stride act on the time axis only.
    """

    def __init__(self, channels, stride=2):
        super().__init__()
        self.conv = nn.ConvTranspose3d(
            channels,
            channels,
            kernel_size=(3, 1, 1),
            stride=(stride, 1, 1),
            padding=(0, 0, 0),
        )
        self.stride = stride

    def forward(self, x):
        out = self.conv(x)
        # With kernel 3 / stride s, ConvTranspose yields (T - 1) * s + 3 frames.
        # Dropping the first two and the last leaves (T - 1) * s frames.
        return out[:, :, 2:-1]
|
|
|
|
|
class CausalGroupNorm(nn.Module):
    """GroupNorm applied per frame: folds T into the batch so statistics never mix timesteps."""

    def __init__(self, num_groups, num_channels):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        batch, channels, frames, height, width = x.shape

        # (B, C, T, H, W) -> (B*T, C, H, W): each frame is normalized on its own.
        per_frame = x.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
        normed = self.gn(per_frame)

        # Restore the original (B, C, T, H, W) layout.
        return normed.reshape(batch, frames, channels, height, width).permute(0, 2, 1, 3, 4).contiguous()
|
|
|
|
|
class CausalResBlock3d(nn.Module):
    """Residual block built from causal group norms and causal 3D convolutions."""

    def __init__(self, channels):
        super().__init__()
        groups = min(16, channels)
        self.net = nn.Sequential(
            CausalGroupNorm(groups, channels),
            nn.SiLU(),
            CausalConv3d(channels, channels, kernel_size=3, padding=1),
            CausalGroupNorm(groups, channels),
            nn.SiLU(),
            CausalConv3d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Pre-activation residual: identity skip around the norm/act/conv stack.
        return x + self.net(x)
|
|
|
|
|
|
|
class ResBlock3d(nn.Module):
    """Plain (non-causal) 3D residual block: norm -> SiLU -> conv, twice, plus skip."""

    def __init__(self, channels):
        super().__init__()
        groups = min(16, channels)
        self.net = nn.Sequential(
            nn.GroupNorm(groups, channels),
            nn.SiLU(),
            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
            nn.GroupNorm(groups, channels),
            nn.SiLU(),
            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Identity skip connection around the norm/act/conv stack.
        return x + self.net(x)
|
|
|
|
|
class UpsampleBlock3d(nn.Module):
    """2x spatial upsample (optionally temporal too), channel projection, then residual blocks."""

    def __init__(self, in_ch, out_ch, n_res=2, spatial_up=True, temporal_up=False,
                 use_conv_transpose=False, causal=False):
        super().__init__()

        stages = []
        if spatial_up:
            # temporal_up additionally doubles T alongside the 2x H/W upsample.
            factor = (2, 2, 2) if temporal_up else (1, 2, 2)
            if use_conv_transpose:
                # Learned upsampling.
                stages.append(nn.ConvTranspose3d(in_ch, in_ch, kernel_size=factor, stride=factor))
            else:
                # Fixed trilinear interpolation.
                stages.append(nn.Upsample(scale_factor=factor, mode='trilinear', align_corners=False))

        # Pick causal or plain building blocks for projection + residual stack.
        if causal:
            stages.append(CausalConv3d(in_ch, out_ch, kernel_size=3, padding=1))
            stages.extend(CausalResBlock3d(out_ch) for _ in range(n_res))
        else:
            stages.append(nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1))
            stages.extend(ResBlock3d(out_ch) for _ in range(n_res))

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
|
|
|
|
|
class LatentAdapterV3(nn.Module):
    """
    V3 architecture: ~3.37M params. Maps LTX-2 latents -> WAN 2.1 latents.
    """

    def __init__(self):
        super().__init__()

        # Stem: refine the 128-channel LTX latent at native resolution.
        self.input_conv = nn.Sequential(
            nn.Conv3d(128, 128, kernel_size=3, padding=1),
            ResBlock3d(128),
            ResBlock3d(128),
        )

        # Two 2x spatial upsamples (4x total), halving channels each time.
        self.up1 = UpsampleBlock3d(128, 64, n_res=2, spatial_up=True)
        self.up2 = UpsampleBlock3d(64, 32, n_res=2, spatial_up=True)

        # Head: reduce to WAN's 16 latent channels.
        self.output_block = nn.Sequential(
            nn.Conv3d(32, 16, kernel_size=3, padding=1),
            ResBlock3d(16),
            ResBlock3d(16),
            nn.Conv3d(16, 16, kernel_size=3, padding=1),
        )

    def forward(self, z_ltx, target_shape=None):
        out = self.input_conv(z_ltx)
        out = self.up2(self.up1(out))
        out = self.output_block(out)

        # Trilinear resample to the exact WAN grid (covers temporal mismatch too).
        if target_shape is not None and out.shape[2:] != target_shape:
            out = F.interpolate(out, size=target_shape, mode='trilinear', align_corners=False)
        return out

    def param_count(self):
        """Total number of learnable parameters."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
|
class LatentAdapter(nn.Module):
    """
    Maps LTX-2 latents -> WAN 2.1 latents.

    Input:  (B, 128, T_ltx, H_ltx, W_ltx)
    Output: (B, 16, T_wan, H_wan, W_wan)

    Spatial: H_wan = H_ltx * 4, W_wan = W_ltx * 4
    Temporal: handled by F.interpolate at the end

    V4 architecture: ~14.4M params. Wider channels and 4 res blocks per stage.
    """

    def __init__(self):
        super().__init__()

        # Stem: lift 128 -> 256 channels and refine at LTX resolution.
        self.input_conv = nn.Sequential(
            nn.Conv3d(128, 256, kernel_size=3, padding=1),
            *(ResBlock3d(256) for _ in range(4)),
        )

        # Two 2x spatial upsamples (4x total), halving channels each time.
        self.up1 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True)
        self.up2 = UpsampleBlock3d(128, 64, n_res=4, spatial_up=True)

        # Head: reduce to WAN's 16 latent channels.
        self.output_block = nn.Sequential(
            nn.Conv3d(64, 32, kernel_size=3, padding=1),
            ResBlock3d(32),
            ResBlock3d(32),
            nn.Conv3d(32, 16, kernel_size=3, padding=1),
        )

    def forward(self, z_ltx, target_shape=None):
        """
        Args:
            z_ltx: (B, 128, T, H, W) LTX-2 latent
            target_shape: optional (T_target, H_target, W_target) to match exact WAN dimensions
        """
        out = self.input_conv(z_ltx)
        out = self.up2(self.up1(out))
        out = self.output_block(out)

        # Trilinear resample to the exact requested WAN latent grid.
        if target_shape is not None and out.shape[2:] != target_shape:
            out = F.interpolate(out, size=target_shape, mode='trilinear', align_corners=False)
        return out

    def param_count(self):
        """Total number of learnable parameters."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
|
class LatentAdapterV6(nn.Module):
    """
    V6 architecture: ~81.2M params. Maps LTX-2 latents -> WAN 2.1 latents.
    """

    def __init__(self):
        super().__init__()

        # Stem: lift 128 -> 512 channels at LTX resolution.
        self.input_conv = nn.Sequential(
            nn.Conv3d(128, 512, kernel_size=3, padding=1),
            *(ResBlock3d(512) for _ in range(4)),
        )

        # Two trilinear 2x spatial upsamples (4x total).
        self.up1 = UpsampleBlock3d(512, 256, n_res=4, spatial_up=True)
        self.up2 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True)

        # Head: reduce to WAN's 16 latent channels.
        self.output_block = nn.Sequential(
            nn.Conv3d(128, 64, kernel_size=3, padding=1),
            ResBlock3d(64),
            ResBlock3d(64),
            nn.Conv3d(64, 16, kernel_size=3, padding=1),
        )

    def forward(self, z_ltx, target_shape=None):
        out = self.input_conv(z_ltx)
        out = self.up2(self.up1(out))
        out = self.output_block(out)

        # Trilinear resample to the exact WAN grid when requested.
        if target_shape is not None and out.shape[2:] != target_shape:
            out = F.interpolate(out, size=target_shape, mode='trilinear', align_corners=False)
        return out

    def param_count(self):
        """Total number of learnable parameters."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
|
class LatentAdapterV6_3(nn.Module):
    """
    V6.3 architecture: ~81.8M params. Uses ConvTranspose3d for upsampling instead of
    trilinear interpolation. Maps LTX-2 latents -> WAN 2.1 latents.
    """

    def __init__(self):
        super().__init__()

        # Stem: lift 128 -> 512 channels at LTX resolution.
        self.input_conv = nn.Sequential(
            nn.Conv3d(128, 512, kernel_size=3, padding=1),
            *(ResBlock3d(512) for _ in range(4)),
        )

        # Two learned 2x spatial upsamples (4x total) via transposed convolutions.
        self.up1 = UpsampleBlock3d(512, 256, n_res=4, spatial_up=True, use_conv_transpose=True)
        self.up2 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True, use_conv_transpose=True)

        # Head: reduce to WAN's 16 latent channels.
        self.output_block = nn.Sequential(
            nn.Conv3d(128, 64, kernel_size=3, padding=1),
            ResBlock3d(64),
            ResBlock3d(64),
            nn.Conv3d(64, 16, kernel_size=3, padding=1),
        )

    def forward(self, z_ltx, target_shape=None):
        out = self.input_conv(z_ltx)
        out = self.up2(self.up1(out))
        out = self.output_block(out)

        # Trilinear resample to the exact WAN grid when requested.
        if target_shape is not None and out.shape[2:] != target_shape:
            out = F.interpolate(out, size=target_shape, mode='trilinear', align_corners=False)
        return out

    def param_count(self):
        """Total number of learnable parameters."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
|
class LatentAdapterV6_4(nn.Module):
    """
    V6.4 architecture: ~81.8M params. Uses ConvTranspose3d for upsampling and causal
    3D convolutions to entirely eliminate future frame ghosting.
    Maps LTX-2 latents -> WAN 2.1 latents.
    """

    def __init__(self):
        super().__init__()

        # Stem: lift 128 -> 512 channels, causal in time.
        self.input_conv = nn.Sequential(
            CausalConv3d(128, 512, kernel_size=3, padding=1),
            *(CausalResBlock3d(512) for _ in range(4)),
        )

        # Two learned 2x spatial upsamples (4x total), causal variants.
        self.up1 = UpsampleBlock3d(512, 256, n_res=4, spatial_up=True, use_conv_transpose=True, causal=True)
        self.up2 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True, use_conv_transpose=True, causal=True)

        # Head: reduce to WAN's 16 latent channels, still causal.
        self.output_block = nn.Sequential(
            CausalConv3d(128, 64, kernel_size=3, padding=1),
            CausalResBlock3d(64),
            CausalResBlock3d(64),
            CausalConv3d(64, 16, kernel_size=3, padding=1),
        )

    def forward(self, z_ltx, target_shape=None):
        out = self.input_conv(z_ltx)
        out = self.up2(self.up1(out))
        out = self.output_block(out)

        # Trilinear resample to the exact WAN grid when requested.
        if target_shape is not None and out.shape[2:] != target_shape:
            out = F.interpolate(out, size=target_shape, mode='trilinear', align_corners=False)
        return out

    def param_count(self):
        """Total number of learnable parameters."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
|
class LatentAdapterV6_5(nn.Module):
    """
    V6.5 architecture: ~83M params. Maps LTX-2 latents -> WAN 2.1 latents.

    Uses causal temporal convolutions, learned temporal ConvTranspose3d upsampling
    (T -> (T - 1) * 2 frames), and an explicit learned 'Option B' anchor projection
    of the first latent frame (T_0), prepended along time to reconstruct realistic
    (1 + (T - 1) * 2)-frame outputs.
    """

    def __init__(self):
        super().__init__()

        # Stem: lift 128 -> 512 channels, causal in time.
        self.input_conv = nn.Sequential(
            CausalConv3d(128, 512, kernel_size=3, padding=1),
            CausalResBlock3d(512),
            CausalResBlock3d(512),
            CausalResBlock3d(512),
            CausalResBlock3d(512),
        )

        # Learned temporal upsampling: T -> (T - 1) * 2 frames.
        self.temporal_up = nn.Sequential(
            CausalConvTranspose1d(512, stride=2),
            CausalResBlock3d(512),
            CausalResBlock3d(512),
        )

        # Two learned 2x spatial upsamples (4x total), causal variants.
        self.up1 = UpsampleBlock3d(512, 256, n_res=4, spatial_up=True, use_conv_transpose=True, causal=True)
        self.up2 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True, use_conv_transpose=True, causal=True)

        # Head: reduce to WAN's 16 latent channels, still causal.
        self.output_block = nn.Sequential(
            CausalConv3d(128, 64, kernel_size=3, padding=1),
            CausalResBlock3d(64),
            CausalResBlock3d(64),
            CausalConv3d(64, 16, kernel_size=3, padding=1),
        )

        # Anchor projection: maps the first LTX frame straight to WAN space
        # (4x spatial upsample, T stays 1) with purely spatial (1, k, k) kernels.
        self.anchor_proj = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=(1, 4, 4), stride=(1, 4, 4)),
            nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
            nn.SiLU(),
            nn.Conv3d(64, 16, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        )

    def forward(self, z_ltx, target_shape=None):
        """
        Args:
            z_ltx: (B, 128, T, H, W) LTX-2 latent.
            target_shape: optional (T_target, H_target, W_target) for an exact WAN grid.

        Returns:
            (B, 16, 1 + (T - 1) * 2, H * 4, W * 4) WAN latent, optionally resampled
            to ``target_shape``.

        Raises:
            RuntimeError: if the temporal upsampler yields an unexpected frame count.
        """
        # Learned projection of the first latent frame only.
        anchor = self.anchor_proj(z_ltx[:, :, :1])

        z = self.input_conv(z_ltx)
        z = self.temporal_up(z)

        # Explicit runtime check instead of `assert`: asserts are stripped under
        # `python -O`, but this shape contract must always hold.
        expected_t = (z_ltx.shape[2] - 1) * 2
        if z.shape[2] != expected_t:
            raise RuntimeError(f"temporal_up output T={z.shape[2]}, expected {expected_t}")

        z = self.up1(z)
        z = self.up2(z)
        z = self.output_block(z)

        # Prepend the anchor frame along the time axis.
        z = torch.cat([anchor, z], dim=2)

        if target_shape is not None and z.shape[2:] != target_shape:
            z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False)
        return z

    def param_count(self):
        """Total number of learnable parameters."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
|
class LatentAdapterV6_6(LatentAdapterV6_5):
    """
    V6.6 architecture: architecturally identical to V6.5 (~83M params).

    The difference is external to this module: the training framework does not
    truncate the identity anchor slice (the prepended T_0 frame) out of the loss,
    so the anchor projection is trained directly against the target latent.
    """
    pass
|
|
|
|
|
| if __name__ == "__main__":
|
| model = LatentAdapter()
|
| print(f"Parameter count: {model.param_count():,}")
|
| print(f"Model size (fp32): {model.param_count() * 4 / 1024 / 1024:.1f} MB")
|
|
|
|
|
|
|
| z_ltx = torch.randn(1, 128, 4, 15, 22)
|
|
|
| z_wan_target = (7, 60, 88)
|
|
|
| z_out = model(z_ltx, target_shape=z_wan_target)
|
| print(f"\nPortrait test (480Γ704, 25 frames):")
|
| print(f" Input: {z_ltx.shape}")
|
| print(f" Output: {z_out.shape}")
|
| print(f" Target: (1, 16, {z_wan_target[0]}, {z_wan_target[1]}, {z_wan_target[2]})")
|
|
|
|
|
| z_ltx2 = torch.randn(1, 128, 4, 15, 26)
|
| z_wan_target2 = (7, 60, 104)
|
| z_out2 = model(z_ltx2, target_shape=z_wan_target2)
|
| print(f"\nLandscape test (480Γ832, 25 frames):")
|
| print(f" Input: {z_ltx2.shape}")
|
| print(f" Output: {z_out2.shape}")
|
| print(f" Target: (1, 16, {z_wan_target2[0]}, {z_wan_target2[1]}, {z_wan_target2[2]})")
|
|
|