File size: 3,978 Bytes
f75b9e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85b1487
 
 
 
 
 
 
 
 
f75b9e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn
from diffsynth.diffusion.base_pipeline import PipelineUnit


class WanVideoUnit_GBufferEncoder(PipelineUnit):
    """
    Encode G-buffer videos (depth, normal, albedo, etc.) with the pipeline VAE
    and append the resulting latents to ``y`` along the channel axis (dim=1).

    ``gbuffer_videos`` holds one entry per G-buffer modality; each entry is a
    video given as a list of PIL Images (frames), e.g.
    ``[[depth_frame0, ...], [normal_frame0, ...]]``.

    The per-modality latents are concatenated in input order. If an upstream
    unit already produced ``y``, the result is ``cat([y, gbuffer_0, ...,
    gbuffer_{N-1}], dim=1)``; otherwise ``y`` becomes the G-buffer latents
    alone. Presumably (16-channel VAE latents, batch size 1) the G-buffer part
    has shape ``[1, N*16, T, H, W]`` -- confirm against the VAE in use.
    """
    def __init__(self):
        super().__init__(
            input_params=("gbuffer_videos", "y", "tiled", "tile_size", "tile_stride"),
            output_params=("y",),
            onload_model_names=("vae",)
        )

    def process(self, pipe, gbuffer_videos, y, tiled, tile_size, tile_stride):
        """
        Encode every G-buffer video and return ``{"y": ...}``.

        Returns an empty dict (leaving ``y`` untouched) when there is nothing
        to encode, i.e. ``gbuffer_videos`` is None or empty.
        """
        # An empty list is treated like None: torch.cat would otherwise raise
        # on an empty sequence of tensors.
        if gbuffer_videos is None or len(gbuffer_videos) == 0:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        all_latents = []
        for gbuffer_video in gbuffer_videos:
            video_tensor = pipe.preprocess_video(gbuffer_video)
            latent = pipe.vae.encode(
                video_tensor, device=pipe.device,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            ).to(dtype=pipe.torch_dtype, device=pipe.device)
            all_latents.append(latent)
        # Stack modalities along the channel axis, in input order.
        gbuffer_latents = torch.cat(all_latents, dim=1)
        if y is not None:
            # Keep any existing conditioning first; G-buffer channels follow it.
            gbuffer_latents = torch.cat([y, gbuffer_latents], dim=1)
        return {"y": gbuffer_latents}


def expand_patch_embedding(pipe, num_gbuffers, base_in_dim=None, latent_channels=16):
    """
    Expand the DiT's ``patch_embedding`` Conv3d to accept extra G-buffer latent channels.

    The new input channels are appended after the existing ones and their weights
    are zero-initialized, so right after expansion the layer's output depends only
    on the original channels (the model starts equivalent to the original). Bias,
    kernel size, stride, padding, dilation and padding mode are preserved.

    Args:
        pipe: object exposing ``pipe.dit.patch_embedding`` (an ``nn.Conv3d``).
        num_gbuffers: number of G-buffer modalities; values <= 0 are a no-op.
        base_in_dim: input channel count of the *unexpanded* layer. When given
            (or recorded on the DiT by a previous call), it decides reliably
            whether the layer is already expanded. When omitted and nothing is
            recorded, the legacy heuristic ``in_channels > extra_channels`` is used.
        latent_channels: channels contributed per modality (default 16, the
            previously hard-coded value).

    Raises:
        ValueError: if ``base_in_dim`` is known and the current input channel
            count is neither ``base_in_dim`` nor ``base_in_dim + extra``.

    Side effects:
        Replaces ``pipe.dit.patch_embedding``, sets ``pipe.dit.in_dim`` and
        records the base width in ``pipe.dit._gbuffer_base_in_dim``.
    """
    if num_gbuffers <= 0:
        return
    dit = pipe.dit
    old_conv = dit.patch_embedding
    old_weight = old_conv.weight  # [out_channels, in_channels, *kernel_size]
    old_bias = old_conv.bias
    old_in = old_weight.shape[1]

    extra_channels = num_gbuffers * latent_channels

    # Prefer explicit knowledge of the unexpanded width; fall back to the value
    # recorded by an earlier expansion of this same DiT.
    if base_in_dim is None:
        base_in_dim = getattr(dit, "_gbuffer_base_in_dim", None)

    if base_in_dim is not None:
        if old_in == base_in_dim + extra_channels:
            print(f"patch_embedding already has {old_in} input channels ({base_in_dim} base + {extra_channels} G-buffer), already expanded. Skipping.")
            dit._gbuffer_base_in_dim = base_in_dim
            return
        if old_in != base_in_dim:
            raise ValueError(
                f"patch_embedding has {old_in} input channels; expected {base_in_dim} "
                f"(unexpanded) or {base_in_dim + extra_channels} (expanded for "
                f"{num_gbuffers} G-buffers x {latent_channels} channels)."
            )
    elif old_in > extra_channels:
        # Legacy heuristic, kept for backward compatibility (e.g. loading from a
        # GBuffer-trained checkpoint). NOTE: it is wrong whenever the base layer is
        # already wider than extra_channels (e.g. a 36-channel image-to-video DiT with
        # 1-2 G-buffers is treated as expanded). Pass base_in_dim to avoid this.
        print(f"patch_embedding already has {old_in} input channels (> {extra_channels} extra), already expanded. Skipping.")
        return

    new_in_dim = old_in + extra_channels

    # Build directly on the old layer's device/dtype instead of allocating an
    # fp32 CPU copy and moving it afterwards.
    new_conv = nn.Conv3d(
        new_in_dim, old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        dilation=old_conv.dilation,
        padding_mode=old_conv.padding_mode,
        bias=old_bias is not None,
        device=old_weight.device,
        dtype=old_weight.dtype,
    )
    with torch.no_grad():
        # Zero everything, then copy the original weights into the leading
        # channels so the appended G-buffer channels start with no influence.
        new_conv.weight.zero_()
        new_conv.weight[:, :old_in].copy_(old_weight)
        if old_bias is not None:
            new_conv.bias.copy_(old_bias)

    dit.patch_embedding = new_conv
    dit.in_dim = new_in_dim
    # Remember the unexpanded width so repeated calls are detected reliably.
    dit._gbuffer_base_in_dim = old_in
    print(f"Expanded patch_embedding: {old_in} -> {new_in_dim} input channels (+{extra_channels} for {num_gbuffers} G-buffers)")


def inject_gbuffer_unit(pipe):
    """
    Insert a ``WanVideoUnit_GBufferEncoder`` into ``pipe.units`` in place.

    Placement, in order of preference (units are matched by class name):
      1. right after ``WanVideoUnit_ImageEmbedderFused``;
      2. otherwise right before ``WanVideoUnit_FunControl``;
      3. otherwise three positions from the end (clamped to 0), the legacy fallback.

    Calling this more than once is a no-op: if a unit whose class is named
    ``WanVideoUnit_GBufferEncoder`` is already present, nothing is inserted.
    """
    if any(type(u).__name__ == "WanVideoUnit_GBufferEncoder" for u in pipe.units):
        print("WanVideoUnit_GBufferEncoder already present, skipping injection.")
        return
    names = [type(u).__name__ for u in pipe.units]
    if "WanVideoUnit_ImageEmbedderFused" in names:
        insert_idx = names.index("WanVideoUnit_ImageEmbedderFused") + 1
    elif "WanVideoUnit_FunControl" in names:
        insert_idx = names.index("WanVideoUnit_FunControl")
    else:
        # Same position list.insert(-3, ...) would use, computed explicitly so the
        # log below reports a real index instead of None.
        insert_idx = max(len(pipe.units) - 3, 0)
    pipe.units.insert(insert_idx, WanVideoUnit_GBufferEncoder())
    print(f"Injected WanVideoUnit_GBufferEncoder at position {insert_idx}")