File size: 5,708 Bytes
30d1371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9080cb5
30d1371
 
 
 
 
 
 
 
 
6ee6428
 
 
 
30d1371
 
 
 
e49ba69
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
Quantization, export and ahead-of-time compilation helpers for the Wan 2.2
video pipeline on ZeroGPU (Hugging Face Spaces), with optional loading of
precompiled transformer artifacts from the Hub.
"""

from typing import Any
from typing import Callable
from typing import ParamSpec

import os
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig
from huggingface_hub import hf_hub_download

from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import drain_module_parameters
from optimization_utils import zerogpu_compiled_from_serializable_dict
from optimization_utils import ZeroGPUCompiledModel


# Captures the wrapped pipeline's call signature in optimize_pipeline_.
P = ParamSpec('P')

# Dynamic-dimension ranges used at torch.export time for the transformer's
# latent input. Bounds are in latent/patched units — TODO confirm against the
# model's supported resolutions.
LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)

# dynamic_shapes spec merged into the captured-call spec before export:
# dims 2-4 of `hidden_states` (frames, height, width) stay dynamic, with the
# spatial dims expressed as twice the patched dims. All other inputs are
# treated as static (see tree_map_only in optimize_pipeline_).
TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        2: LATENT_FRAMES_DIM,
        3: 2 * LATENT_PATCHED_HEIGHT_DIM,
        4: 2 * LATENT_PATCHED_WIDTH_DIM,
    },
}

# torch._inductor configuration applied to both AoT compilations.
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}


def _strtobool(v: str | None, default: bool = True) -> bool:
    if v is None:
        return default
    return v.strip().lower() in ("1", "true", "yes", "y", "on")


def _load_compiled_pt(path: str):
    """Deserialize a compiled-transformer artifact from *path*.

    Two on-disk formats are accepted:
    - a dict produced by ``to_serializable_dict()`` (its ``format`` key is
      ``'zerogpu_aoti_v1'``), rebuilt via
      ``zerogpu_compiled_from_serializable_dict``; or
    - a legacy, directly pickled ``ZeroGPUCompiledModel``.

    Raises:
        ValueError: if the payload matches neither format.
    """
    payload = torch.load(path, map_location="cpu", weights_only=False)

    # Current format: a plain-dict payload tagged with a format marker.
    if isinstance(payload, dict) and payload.get("format") == "zerogpu_aoti_v1":
        return zerogpu_compiled_from_serializable_dict(payload)

    # Legacy format: the compiled model object was pickled directly.
    if isinstance(payload, ZeroGPUCompiledModel):
        return payload

    keys = list(payload.keys()) if isinstance(payload, dict) else None
    raise ValueError(
        f"Unsupported compiled transformer file format at {path}. "
        f"Got type={type(payload)} keys={keys}"
    )


def load_compiled_transformers_from_hub(
    repo_id: str,
    filename_1: str = "compiled_transformer_1.pt",
    filename_2: str = "compiled_transformer_2.pt",
):
    """Download and deserialize both precompiled transformer artifacts.

    IMPORTANT:
    Each file must be either a 'zerogpu_aoti_v1' dict exported via
    ``to_serializable_dict()`` or a directly pickled ``ZeroGPUCompiledModel``
    (see ``_load_compiled_pt``).

    Returns:
        A 2-tuple of deserialized compiled models, in filename order.
    """
    local_paths = (
        hf_hub_download(repo_id=repo_id, filename=filename_1),
        hf_hub_download(repo_id=repo_id, filename=filename_2),
    )
    return tuple(_load_compiled_pt(p) for p in local_paths)


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Quantize, export and AoT-compile the pipeline's transformers in place.

    Fuses the Lightx2v distillation LoRA into both transformers, captures one
    real transformer call to obtain example inputs, float8-quantizes both
    transformers, exports them with dynamic shapes and compiles them with
    AOTInductor. The compiled callables then replace
    ``pipeline.transformer.forward`` / ``pipeline.transformer_2.forward`` and
    the original parameters are drained to free memory.

    Set the ``WAN_USE_PRECOMPILED`` env var to a truthy value to load
    precompiled artifacts from the Hub instead of compiling here; the repo is
    configurable through ``WAN_PRECOMPILED_REPO``.

    Args:
        pipeline: Video pipeline to optimize (mutated in place).
        *args, **kwargs: A representative pipeline call, used once to capture
            the transformer's example inputs for export.
    """
    # Fix: the original assigned these as plain locals under an
    # "expose for downloads" comment, so they never left this scope.
    global COMPILED_TRANSFORMER_1, COMPILED_TRANSFORMER_2

    @spaces.GPU(duration=1500)
    def compile_transformer():
        # Fuse the LoRA weights before export so the compiled graphs bake
        # the fused weights in; unload afterwards to drop adapter state.
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
            adapter_name="lightx2v",
        )
        kwargs_lora = {"load_into_transformer_2": True}
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
            adapter_name="lightx2v_2",
            **kwargs_lora,
        )
        pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
        pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
        pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
        pipeline.unload_lora_weights()

        # Run the pipeline once to record the exact args/kwargs the
        # transformer receives; these become the export example inputs.
        with capture_component_call(pipeline, "transformer") as call:
            pipeline(*args, **kwargs)

        # Default every captured tensor/bool kwarg to a static shape, then
        # override the hidden_states dims that must remain dynamic.
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())

        exported_1 = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )
        exported_2 = torch.export.export(
            mod=pipeline.transformer_2,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )

        compiled_1 = aoti_compile(exported_1, INDUCTOR_CONFIGS)
        compiled_2 = aoti_compile(exported_2, INDUCTOR_CONFIGS)
        return compiled_1, compiled_2

    # Text-encoder int8 weight-only quantization (independent of the
    # transformer compilation path).
    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())

    # Fix: was hardcoded to False, leaving the _strtobool helper unused.
    # Now env-driven with the same default, so behavior is unchanged unless
    # WAN_USE_PRECOMPILED is set.
    use_precompiled = _strtobool(os.getenv("WAN_USE_PRECOMPILED"), default=False)
    precompiled_repo = os.getenv("WAN_PRECOMPILED_REPO", "Fabrice-TIERCELIN/Wan_2.2_compiled")

    if use_precompiled:
        compiled_transformer_1, compiled_transformer_2 = load_compiled_transformers_from_hub(
            repo_id=precompiled_repo
        )
    else:
        compiled_transformer_1, compiled_transformer_2 = compile_transformer()

    # Expose at module level so the artifacts can be downloaded/serialized
    # after this function returns.
    COMPILED_TRANSFORMER_1 = compiled_transformer_1
    COMPILED_TRANSFORMER_2 = compiled_transformer_2

    pipeline.transformer.forward = compiled_transformer_1
    drain_module_parameters(pipeline.transformer)

    pipeline.transformer_2.forward = compiled_transformer_2
    drain_module_parameters(pipeline.transformer_2)