File size: 5,981 Bytes
b617d2a
 
 
 
 
 
43ce3f6
 
b617d2a
 
43ce3f6
 
 
 
 
 
 
 
 
 
 
b617d2a
 
 
43ce3f6
 
 
 
 
 
 
b617d2a
 
43ce3f6
 
 
 
e988808
b617d2a
 
 
 
 
 
 
 
 
 
 
43ce3f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b617d2a
43ce3f6
b617d2a
43ce3f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b617d2a
43ce3f6
b617d2a
 
43ce3f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b617d2a
43ce3f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e988808
43ce3f6
 
 
 
 
 
 
 
 
 
e988808
43ce3f6
 
e988808
43ce3f6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""

"""

from typing import Any
from typing import Callable
from typing import ParamSpec

import os
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig
from huggingface_hub import hf_hub_download

from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import drain_module_parameters
from optimization_utils import ZeroGPUCompiledModelFromDict  # NEW


P = ParamSpec('P')

# Expose compiled models so app.py can offer them for download
COMPILED_TRANSFORMER_1 = None
COMPILED_TRANSFORMER_2 = None

LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)

TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        2: LATENT_FRAMES_DIM,
        3: 2 * LATENT_PATCHED_HEIGHT_DIM,
        4: 2 * LATENT_PATCHED_WIDTH_DIM,
    },
}

INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}


def load_compiled_transformers_from_hub(
    repo_id: str,
    filename_1: str = "compiled_transformer_1.pt",
    filename_2: str = "compiled_transformer_2.pt",
    device: str = "cuda",
):
    """Fetch two precompiled transformer payloads from the Hub and wrap them.

    Both files are downloaded with ``hf_hub_download``, deserialized onto the
    CPU with ``torch.load`` and, once both are confirmed to be ``dict``
    payloads, passed to ``ZeroGPUCompiledModelFromDict`` together with
    ``device``.

    Per the original author's note, the payloads are expected to have been
    produced via ``ZeroGPUCompiledModel.to_serializable_dict()`` followed by
    ``torch.save``, and the rebuilt callables move their constants to CUDA on
    first call (behaviour implemented in ``optimization_utils``, not here).

    Args:
        repo_id: Hub repository that hosts the two artifacts.
        filename_1: File name of the first transformer's payload.
        filename_2: File name of the second transformer's payload.
        device: Device string forwarded to ``ZeroGPUCompiledModelFromDict``.

    Returns:
        A 2-tuple ``(compiled_1, compiled_2)`` in the same order as the
        file names.

    Raises:
        TypeError: If either deserialized object is not a ``dict``.
    """
    # Stage 1: resolve both artifacts locally before anything is deserialized.
    local_paths = [
        hf_hub_download(repo_id=repo_id, filename=artifact_name)
        for artifact_name in (filename_1, filename_2)
    ]

    # Stage 2: deserialize on CPU.
    # NOTE(security): weights_only=False allows arbitrary unpickling, so only
    # point ``repo_id`` at repositories you trust.
    payloads = [
        torch.load(local_path, map_location="cpu", weights_only=False)
        for local_path in local_paths
    ]

    # Stage 3: both payloads must be plain dicts before any wrapper is built.
    if not all(isinstance(payload, dict) for payload in payloads):
        raise TypeError("Precompiled files are not payload dicts. Please re-export them with to_serializable_dict().")

    # Stage 4: rebuild the callables, first transformer first.
    first_model, second_model = (
        ZeroGPUCompiledModelFromDict(payload, device=device)
        for payload in payloads
    )
    return first_model, second_model


def _strtobool(v: str | None, default: bool = True) -> bool:
    if v is None:
        return default
    return v.strip().lower() in ("1", "true", "yes", "y", "on")


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Quantize ``pipeline`` in place and swap in compiled transformer forwards.

    Steps, in order:

    1. Quantize ``pipeline.text_encoder`` with ``Int8WeightOnlyConfig``.
    2. Obtain a compiled callable for ``pipeline.transformer`` and
       ``pipeline.transformer_2``: either loaded from the Hub (opt-in, see
       below) or built by the nested ``compile_transformer``.
    3. Publish both callables through the module globals
       ``COMPILED_TRANSFORMER_1`` / ``COMPILED_TRANSFORMER_2``.
    4. Replace each transformer's ``forward`` with its compiled callable and
       call ``drain_module_parameters`` on the module.

    Environment variables:
        WAN_USE_PRECOMPILED: Opt-in switch parsed by ``_strtobool``. When
            truthy, precompiled artifacts are loaded from the Hub instead of
            compiling. Unset means "compile", which is what the previously
            hard-coded ``False`` did (that constant made the loading branch
            unreachable).
        WAN_PRECOMPILED_REPO: Hub repository holding the artifacts; defaults
            to ``Fabrice-TIERCELIN/Wan_2.2_compiled``. Only read when
            ``WAN_USE_PRECOMPILED`` is truthy.

    Args:
        pipeline: Callable pipeline exposing ``transformer``,
            ``transformer_2`` and ``text_encoder`` attributes plus the LoRA
            helpers used below.
        *args: Positional arguments for the example pipeline call that is
            run to capture the transformer's inputs.
        **kwargs: Keyword arguments for that same example call.

    Returns:
        None. ``pipeline`` and the two module globals are mutated.
    """
    global COMPILED_TRANSFORMER_1, COMPILED_TRANSFORMER_2

    # NOTE(review): the decorator presumably requests a ZeroGPU allocation
    # for up to 1500 seconds -- confirm against the `spaces` documentation.
    @spaces.GPU(duration=1500)
    def compile_transformer():
        """Fuse the LoRA, quantize, export and AOT-compile both transformers."""
        # The same LoRA weight file is registered twice under two adapter
        # names; the second load passes load_into_transformer_2=True.
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
            adapter_name="lightx2v",
        )
        kwargs_lora = {"load_into_transformer_2": True}
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
            adapter_name="lightx2v_2",
            **kwargs_lora,
        )
        pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
        # Fuse each adapter into its own component (scale 3.0 for
        # "transformer", 1.0 for "transformer_2"), then unload the adapters.
        pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
        pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
        pipeline.unload_lora_weights()

        # Run the example call; ``call`` afterwards provides the ``args`` and
        # ``kwargs`` that are reused as example inputs for export.
        with capture_component_call(pipeline, "transformer") as call:
            pipeline(*args, **kwargs)

        # Every tensor / bool keyword argument starts out as None in the
        # shape spec; the ``hidden_states`` entry is then overridden with the
        # symbolic dimensions.
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        # Quantization happens after LoRA fusion and before export.
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())

        # Both transformers are exported with the inputs captured from
        # ``pipeline.transformer`` and the same dynamic-shape spec.
        exported_1 = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )
        exported_2 = torch.export.export(
            mod=pipeline.transformer_2,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )

        compiled_1 = aoti_compile(exported_1, INDUCTOR_CONFIGS)
        compiled_2 = aoti_compile(exported_2, INDUCTOR_CONFIGS)
        return compiled_1, compiled_2

    # Quantize text encoder
    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())

    # Opt-in switch: defaults to False so behaviour is unchanged unless the
    # environment explicitly asks for precompiled artifacts.
    use_precompiled = _strtobool(os.getenv("WAN_USE_PRECOMPILED"), default=False)

    if use_precompiled:
        precompiled_repo = os.getenv("WAN_PRECOMPILED_REPO", "Fabrice-TIERCELIN/Wan_2.2_compiled")
        try:
            compiled_transformer_1, compiled_transformer_2 = load_compiled_transformers_from_hub(
                repo_id=precompiled_repo,
                device="cuda",
            )
        except Exception as e:
            # Deliberately broad: any download / deserialization / format
            # failure falls back to compiling on the GPU.
            print(f"[WARN] Failed to load precompiled artifacts ({e}). Falling back to GPU compilation.")
            compiled_transformer_1, compiled_transformer_2 = compile_transformer()
    else:
        compiled_transformer_1, compiled_transformer_2 = compile_transformer()

    # expose for downloads
    COMPILED_TRANSFORMER_1 = compiled_transformer_1
    COMPILED_TRANSFORMER_2 = compiled_transformer_2

    # Route each transformer through its compiled callable.
    # NOTE(review): drain_module_parameters presumably releases the eager
    # weights that the compiled callable no longer needs -- confirm in
    # optimization_utils.
    pipeline.transformer.forward = compiled_transformer_1
    drain_module_parameters(pipeline.transformer)

    pipeline.transformer_2.forward = compiled_transformer_2
    drain_module_parameters(pipeline.transformer_2)