""" Latent adapter model: maps LTX-2 latent space (128ch) → WAN 2.1 latent space (16ch). Handles: - Channel reduction: 128 → 16 - Spatial upsampling: 4× (LTX uses 32× spatial downscale, WAN uses 8×) - Temporal upsampling: ~2× (LTX uses 8× temporal downscale, WAN uses 4×) """ import torch import torch.nn as nn import torch.nn.functional as F class CausalConv3d(nn.Module): """3D Convolution with causal padding in the temporal dimension.""" def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, **kwargs): super().__init__() self.temporal_pad = kernel_size[0] - 1 if isinstance(kernel_size, tuple) else kernel_size - 1 spatial_pad = padding[1] if isinstance(padding, tuple) else padding self.conv = nn.Conv3d( in_channels, out_channels, kernel_size=kernel_size, padding=(0, spatial_pad, spatial_pad), **kwargs ) def forward(self, x): # x is (B, C, T, H, W) # Pad temporal dimension (left: temporal_pad, right: 0) if self.temporal_pad > 0: x = F.pad(x, (0, 0, 0, 0, self.temporal_pad, 0)) return self.conv(x) class CausalConvTranspose1d(nn.Module): """Causal temporal upsampling via ConvTranspose3d.""" def __init__(self, channels, stride=2): super().__init__() self.conv = nn.ConvTranspose3d( channels, channels, kernel_size=(3,1,1), stride=(stride,1,1), padding=(0,0,0), # no padding, we'll crop ) self.stride = stride def forward(self, x): x = self.conv(x) # Perfect Causal Trim for T=15 -> T=30, T=4 -> T=6 # conv outputs 2T_in + 1. We drop first 2 frames and last 1 frame to yield 2*(T_in - 1) # This aligns dependencies perfectly: [X0,X1] -> [X1] -> [X1,X2] -> [X2] ... x = x[:, :, 2:-1] return x class CausalGroupNorm(nn.Module): """GroupNorm that treats Time as part of the Batch dimension to prevent temporal leakage.""" def __init__(self, num_groups, num_channels): super().__init__() self.gn = nn.GroupNorm(num_groups, num_channels) def forward(self, x): B, C, T, H, W = x.shape # Move T to batch dimension: (B, C, T, H, W) -> (B, T, C, H, W) -> (B*T, C, H, W) x_reshaped = x.transpose(1, 2).contiguous().view(B * T, C, H, W) x_norm = self.gn(x_reshaped) # Revert back: (B*T, C, H, W) -> (B, T, C, H, W) -> (B, C, T, H, W) return x_norm.view(B, T, C, H, W).transpose(1, 2).contiguous() class CausalResBlock3d(nn.Module): """3D residual block with group norm and causal temporal convolution.""" def __init__(self, channels): super().__init__() self.net = nn.Sequential( CausalGroupNorm(min(16, channels), channels), nn.SiLU(), CausalConv3d(channels, channels, kernel_size=3, padding=1), CausalGroupNorm(min(16, channels), channels), nn.SiLU(), CausalConv3d(channels, channels, kernel_size=3, padding=1), ) def forward(self, x): return x + self.net(x) class ResBlock3d(nn.Module): """3D residual block with group norm.""" def __init__(self, channels): super().__init__() self.net = nn.Sequential( nn.GroupNorm(min(16, channels), channels), nn.SiLU(), nn.Conv3d(channels, channels, kernel_size=3, padding=1), nn.GroupNorm(min(16, channels), channels), nn.SiLU(), nn.Conv3d(channels, channels, kernel_size=3, padding=1), ) def forward(self, x): return x + self.net(x) class UpsampleBlock3d(nn.Module): """Spatial upsample 2× + channel change + residual blocks.""" def __init__(self, in_ch, out_ch, n_res=2, spatial_up=True, temporal_up=False, use_conv_transpose=False, causal=False): super().__init__() layers = [] if spatial_up: scale = (2, 2, 2) if temporal_up else (1, 2, 2) if use_conv_transpose: layers.append(nn.ConvTranspose3d(in_ch, in_ch, kernel_size=scale, stride=scale)) else: layers.append(nn.Upsample(scale_factor=scale, mode='trilinear', align_corners=False)) if causal: layers.append(CausalConv3d(in_ch, out_ch, kernel_size=3, padding=1)) for _ in range(n_res): layers.append(CausalResBlock3d(out_ch)) else: layers.append(nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1)) for _ in range(n_res): layers.append(ResBlock3d(out_ch)) self.net = nn.Sequential(*layers) def forward(self, x): return self.net(x) class LatentAdapterV3(nn.Module): """ V3 architecture: ~3.37M params. Maps LTX-2 latents → WAN 2.1 latents. """ def __init__(self): super().__init__() # Initial feature extraction: 128 → 128 self.input_conv = nn.Sequential( nn.Conv3d(128, 128, kernel_size=3, padding=1), ResBlock3d(128), ResBlock3d(128), ) # Stage 1: Channel reduce 128 → 64 + spatial 2× upsample self.up1 = UpsampleBlock3d(128, 64, n_res=2, spatial_up=True) # Stage 2: 64 → 32 + spatial 2× upsample (total 4×) self.up2 = UpsampleBlock3d(64, 32, n_res=2, spatial_up=True) # Stage 3: 32 → 16 refinement (no upsample) self.output_block = nn.Sequential( nn.Conv3d(32, 16, kernel_size=3, padding=1), ResBlock3d(16), ResBlock3d(16), nn.Conv3d(16, 16, kernel_size=3, padding=1), ) def forward(self, z_ltx, target_shape=None): z = self.input_conv(z_ltx) z = self.up1(z) z = self.up2(z) z = self.output_block(z) if target_shape is not None: if z.shape[2:] != target_shape: z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False) return z def param_count(self): return sum(p.numel() for p in self.parameters()) class LatentAdapter(nn.Module): """ Maps LTX-2 latents → WAN 2.1 latents. Input: (B, 128, T_ltx, H_ltx, W_ltx) Output: (B, 16, T_wan, H_wan, W_wan) Spatial: H_wan = H_ltx * 4, W_wan = W_ltx * 4 Temporal: handled by F.interpolate at the end V4 architecture: ~14.4M params. Wider channels and 4 res blocks per stage. """ def __init__(self): super().__init__() # Initial feature extraction: 128 → 256 self.input_conv = nn.Sequential( nn.Conv3d(128, 256, kernel_size=3, padding=1), ResBlock3d(256), ResBlock3d(256), ResBlock3d(256), ResBlock3d(256), ) # Stage 1: Channel reduce 256 → 128 + spatial 2× upsample self.up1 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True) # Stage 2: 128 → 64 + spatial 2× upsample (total 4×) self.up2 = UpsampleBlock3d(128, 64, n_res=4, spatial_up=True) # Stage 3: 64 → 16 refinement (no upsample) self.output_block = nn.Sequential( nn.Conv3d(64, 32, kernel_size=3, padding=1), ResBlock3d(32), ResBlock3d(32), nn.Conv3d(32, 16, kernel_size=3, padding=1), ) def forward(self, z_ltx, target_shape=None): """ Args: z_ltx: (B, 128, T, H, W) LTX-2 latent target_shape: optional (T_target, H_target, W_target) to match exact WAN dimensions """ z = self.input_conv(z_ltx) z = self.up1(z) z = self.up2(z) z = self.output_block(z) # Interpolate to exact target dimensions if provided if target_shape is not None: if z.shape[2:] != target_shape: z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False) return z def param_count(self): return sum(p.numel() for p in self.parameters()) class LatentAdapterV6(nn.Module): """ V6 architecture: ~81.2M params. Maps LTX-2 latents → WAN 2.1 latents. """ def __init__(self): super().__init__() # Initial feature extraction: 128 → 512 self.input_conv = nn.Sequential( nn.Conv3d(128, 512, kernel_size=3, padding=1), ResBlock3d(512), ResBlock3d(512), ResBlock3d(512), ResBlock3d(512), ) # Stage 1: Channel reduce 512 → 256 + spatial 2× upsample self.up1 = UpsampleBlock3d(512, 256, n_res=4, spatial_up=True) # Stage 2: 256 → 128 + spatial 2× upsample (total 4×) self.up2 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True) # Stage 3: 128 → 16 refinement (no upsample) self.output_block = nn.Sequential( nn.Conv3d(128, 64, kernel_size=3, padding=1), ResBlock3d(64), ResBlock3d(64), nn.Conv3d(64, 16, kernel_size=3, padding=1), ) def forward(self, z_ltx, target_shape=None): z = self.input_conv(z_ltx) z = self.up1(z) z = self.up2(z) z = self.output_block(z) if target_shape is not None: if z.shape[2:] != target_shape: z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False) return z def param_count(self): return sum(p.numel() for p in self.parameters()) class LatentAdapterV6_3(nn.Module): """ V6.3 architecture: ~81.8M params. Uses ConvTranspose3d for upsampling instead of trilinear interpolation. Maps LTX-2 latents → WAN 2.1 latents. """ def __init__(self): super().__init__() # Initial feature extraction: 128 → 512 self.input_conv = nn.Sequential( nn.Conv3d(128, 512, kernel_size=3, padding=1), ResBlock3d(512), ResBlock3d(512), ResBlock3d(512), ResBlock3d(512), ) # Stage 1: Channel reduce 512 → 256 + spatial 2× upsample via transposed convolution self.up1 = UpsampleBlock3d(512, 256, n_res=4, spatial_up=True, use_conv_transpose=True) # Stage 2: 256 → 128 + spatial 2× upsample via transposed convolution self.up2 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True, use_conv_transpose=True) # Stage 3: 128 → 16 refinement (no upsample) self.output_block = nn.Sequential( nn.Conv3d(128, 64, kernel_size=3, padding=1), ResBlock3d(64), ResBlock3d(64), nn.Conv3d(64, 16, kernel_size=3, padding=1), ) def forward(self, z_ltx, target_shape=None): z = self.input_conv(z_ltx) z = self.up1(z) z = self.up2(z) z = self.output_block(z) if target_shape is not None: if z.shape[2:] != target_shape: z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False) return z def param_count(self): return sum(p.numel() for p in self.parameters()) class LatentAdapterV6_4(nn.Module): """ V6.4 architecture: ~81.8M params. Uses ConvTranspose3d for upsampling and Causal 3D convolutions to entirely eliminate future frame ghosting. Maps LTX-2 latents → WAN 2.1 latents. """ def __init__(self): super().__init__() # Initial feature extraction: 128 → 512 self.input_conv = nn.Sequential( CausalConv3d(128, 512, kernel_size=3, padding=1), CausalResBlock3d(512), CausalResBlock3d(512), CausalResBlock3d(512), CausalResBlock3d(512), ) # Stage 1: Channel reduce 512 → 256 + spatial 2× upsample via transposed convolution self.up1 = UpsampleBlock3d(512, 256, n_res=4, spatial_up=True, use_conv_transpose=True, causal=True) # Stage 2: 256 → 128 + spatial 2× upsample via transposed convolution self.up2 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True, use_conv_transpose=True, causal=True) # Stage 3: 128 → 16 refinement (no upsample) self.output_block = nn.Sequential( CausalConv3d(128, 64, kernel_size=3, padding=1), CausalResBlock3d(64), CausalResBlock3d(64), CausalConv3d(64, 16, kernel_size=3, padding=1), ) def forward(self, z_ltx, target_shape=None): z = self.input_conv(z_ltx) z = self.up1(z) z = self.up2(z) z = self.output_block(z) if target_shape is not None: if z.shape[2:] != target_shape: z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False) return z def param_count(self): return sum(p.numel() for p in self.parameters()) class LatentAdapterV6_5(nn.Module): """ V6.5 architecture: ~83M params. Uses Causal Temporal Convolutions, learned Temporal ConvTranspose3d upsampling (15->30), and an explicit learned 'Option B' anchor projection (T_0) to reconstruct realistic 31-frame outputs. """ def __init__(self): super().__init__() # Initial feature extraction: 128 → 512 (T=15) self.input_conv = nn.Sequential( CausalConv3d(128, 512, kernel_size=3, padding=1), CausalResBlock3d(512), CausalResBlock3d(512), CausalResBlock3d(512), CausalResBlock3d(512), ) # Temporal 2x Upsample: T=15 -> T=30 self.temporal_up = nn.Sequential( CausalConvTranspose1d(512, stride=2), CausalResBlock3d(512), CausalResBlock3d(512), ) # Stage 1: Channel reduce 512 → 256 + spatial 2× upsample via transposed convolution self.up1 = UpsampleBlock3d(512, 256, n_res=4, spatial_up=True, use_conv_transpose=True, causal=True) # Stage 2: 256 → 128 + spatial 2× upsample via transposed convolution self.up2 = UpsampleBlock3d(256, 128, n_res=4, spatial_up=True, use_conv_transpose=True, causal=True) # Stage 3: 128 → 16 refinement (no upsample) self.output_block = nn.Sequential( CausalConv3d(128, 64, kernel_size=3, padding=1), CausalResBlock3d(64), CausalResBlock3d(64), CausalConv3d(64, 16, kernel_size=3, padding=1), ) # Option B Anchor Projection: LTX T_0 (128ch) -> WAN T_0 (16ch, 4x spatial) self.anchor_proj = nn.Sequential( nn.ConvTranspose3d(128, 64, kernel_size=(1, 4, 4), stride=(1, 4, 4)), nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), nn.SiLU(), nn.Conv3d(64, 16, kernel_size=(1, 3, 3), padding=(0, 1, 1)) ) def forward(self, z_ltx, target_shape=None): # 1. Option B Anchor Projection: grab frame 0 and project it independently anchor = self.anchor_proj(z_ltx[:, :, :1]) # (B, 16, 1, H*4, W*4) # 2. Main Network (T_ltx -> (T_ltx-1)*2) z = self.input_conv(z_ltx) z = self.temporal_up(z) assert z.shape[2] == (z_ltx.shape[2] - 1) * 2, f"temporal_up output T={z.shape[2]}, expected {(z_ltx.shape[2] - 1) * 2}" z = self.up1(z) z = self.up2(z) z = self.output_block(z) # (B, 16, 30, H*4, W*4) # 3. Concatenate Anchor T=1 + Main T=30 -> T=31 z = torch.cat([anchor, z], dim=2) if target_shape is not None: if z.shape[2:] != target_shape: z = F.interpolate(z, size=target_shape, mode='trilinear', align_corners=False) return z def param_count(self): return sum(p.numel() for p in self.parameters()) class LatentAdapterV6_6(LatentAdapterV6_5): """ V6.6 architecture: Identical to 6.5 (~83M params), but the training framework does not truncate the identity anchor slice out of the loss. """ pass if __name__ == "__main__": model = LatentAdapter() print(f"Parameter count: {model.param_count():,}") print(f"Model size (fp32): {model.param_count() * 4 / 1024 / 1024:.1f} MB") # Test with 480×704 video (portrait 464x688 → short_side=480) # LTX: 32× spatial, 8× temporal → (128, 4, 15, 22) for 25 frames z_ltx = torch.randn(1, 128, 4, 15, 22) # WAN: 8× spatial, 4× temporal → (16, 7, 60, 88) z_wan_target = (7, 60, 88) z_out = model(z_ltx, target_shape=z_wan_target) print(f"\nPortrait test (480×704, 25 frames):") print(f" Input: {z_ltx.shape}") print(f" Output: {z_out.shape}") print(f" Target: (1, 16, {z_wan_target[0]}, {z_wan_target[1]}, {z_wan_target[2]})") # Test with landscape z_ltx2 = torch.randn(1, 128, 4, 15, 26) z_wan_target2 = (7, 60, 104) z_out2 = model(z_ltx2, target_shape=z_wan_target2) print(f"\nLandscape test (480×832, 25 frames):") print(f" Input: {z_ltx2.shape}") print(f" Output: {z_out2.shape}") print(f" Target: (1, 16, {z_wan_target2[0]}, {z_wan_target2[1]}, {z_wan_target2[2]})")