import torch
import torch.nn as nn

from diffsynth.diffusion.base_pipeline import PipelineUnit


class WanVideoUnit_GBufferEncoder(PipelineUnit):
    """
    Encode G-buffer videos (depth, normal, albedo, etc.) via VAE and concat to y.

    Each G-buffer modality is a list of PIL Images (video frames).
    gbuffer_videos: list of lists, e.g. [[depth_frame0, ...], [normal_frame0, ...]]
    Output shape: [1, N*16, T, H, W] where N = number of G-buffer modalities.
    """

    def __init__(self):
        super().__init__(
            input_params=("gbuffer_videos", "y", "tiled", "tile_size", "tile_stride"),
            output_params=("y",),
            onload_model_names=("vae",)
        )

    def process(self, pipe, gbuffer_videos, y, tiled, tile_size, tile_stride):
        """Encode each G-buffer video to VAE latents and concatenate onto y.

        Returns {} (pipeline no-op) when no G-buffer videos are supplied,
        otherwise {"y": latents} where latents has the existing y channels
        first (when y is not None) followed by N*16 G-buffer latent channels.
        """
        # Guard both None and an empty list: torch.cat would raise on an
        # empty latent list below.
        if not gbuffer_videos:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        all_latents = []
        for gbuffer_video in gbuffer_videos:
            video_tensor = pipe.preprocess_video(gbuffer_video)
            latent = pipe.vae.encode(
                video_tensor,
                device=pipe.device,
                tiled=tiled,
                tile_size=tile_size,
                tile_stride=tile_stride
            ).to(dtype=pipe.torch_dtype, device=pipe.device)
            all_latents.append(latent)
        # Stack modalities along the channel dimension: [1, N*16, T, H, W]
        gbuffer_latents = torch.cat(all_latents, dim=1)
        if y is not None:
            # Existing conditioning channels come first, G-buffers after.
            gbuffer_latents = torch.cat([y, gbuffer_latents], dim=1)
        return {"y": gbuffer_latents}


def expand_patch_embedding(pipe, num_gbuffers):
    """
    Expand DiT's patch_embedding Conv3d to accept additional G-buffer latent channels.
    New channels are zero-initialized so the model starts equivalent to the original.
    """
    if num_gbuffers <= 0:
        return
    dit = pipe.dit
    old_conv = dit.patch_embedding
    old_weight = old_conv.weight  # [out_channels, in_channels, *kernel_size]
    old_bias = old_conv.bias
    extra_channels = num_gbuffers * 16  # 16 VAE latent channels per modality
    # Skip if already expanded (e.g., loading from a GBuffer-trained checkpoint)
    # If current in_channels > extra_channels, it already includes base + gbuffer channels
    # NOTE(review): this heuristic assumes the base in_channels alone never
    # exceeds extra_channels — confirm for models with wide conditioning inputs,
    # otherwise expansion could be silently skipped on a fresh checkpoint.
    if old_weight.shape[1] > extra_channels:
        print(f"patch_embedding already has {old_weight.shape[1]} input channels (> {extra_channels} extra), already expanded. Skipping.")
        return
    new_in_dim = old_weight.shape[1] + extra_channels
    new_conv = nn.Conv3d(
        new_in_dim,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=old_bias is not None,
    )
    with torch.no_grad():
        # Zero-init everything, then copy the pretrained weights into the
        # leading channels; the new trailing channels contribute nothing
        # until trained.
        new_conv.weight.zero_()
        new_conv.weight[:, :old_weight.shape[1]] = old_weight
        if old_bias is not None:
            new_conv.bias.copy_(old_bias)
    dit.patch_embedding = new_conv.to(dtype=old_weight.dtype, device=old_weight.device)
    dit.in_dim = new_in_dim
    print(f"Expanded patch_embedding: {old_weight.shape[1]} -> {new_in_dim} input channels (+{extra_channels} for {num_gbuffers} G-buffers)")


def inject_gbuffer_unit(pipe):
    """
    Insert WanVideoUnit_GBufferEncoder into the pipeline's unit list,
    after ImageEmbedderFused and before FunControl.
    """
    # Skip if already injected
    for u in pipe.units:
        if type(u).__name__ == "WanVideoUnit_GBufferEncoder":
            print("WanVideoUnit_GBufferEncoder already present, skipping injection.")
            return
    unit = WanVideoUnit_GBufferEncoder()
    insert_idx = None
    for i, u in enumerate(pipe.units):
        if type(u).__name__ == "WanVideoUnit_ImageEmbedderFused":
            insert_idx = i + 1
            break
    if insert_idx is None:
        # Fallback: insert before the last few units. Resolve the negative
        # index up front (same clamping as list.insert(-3, ...)) so the log
        # line below reports the real position instead of None.
        insert_idx = max(len(pipe.units) - 3, 0)
    pipe.units.insert(insert_idx, unit)
    print(f"Injected WanVideoUnit_GBufferEncoder at position {insert_idx}")