"""
"""
import os
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig
from huggingface_hub import hf_hub_download
from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import drain_module_parameters
from optimization_utils import zerogpu_compiled_from_serializable_dict
from optimization_utils import ZeroGPUCompiledModel
P = ParamSpec('P')
LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
TRANSFORMER_DYNAMIC_SHAPES = {
'hidden_states': {
2: LATENT_FRAMES_DIM,
3: 2 * LATENT_PATCHED_HEIGHT_DIM,
4: 2 * LATENT_PATCHED_WIDTH_DIM,
},
}
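# The 2x factors are derived dims: hidden_states holds the raw latents, whose spatial
# sizes are twice the patched height/width (assuming Wan's usual 2x2 spatial
# patchification), so only the patched dims are marked dynamic and evenness of the
# raw sizes is guaranteed by construction.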
INDUCTOR_CONFIGS = {
'conv_1x1_as_mm': True,
'epilogue_fusion': False,
'coordinate_descent_tuning': True,
'coordinate_descent_check_all_directions': True,
'max_autotune': True,
'triton.cudagraphs': True,
}
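# Aggressive Inductor autotuning: compilation is slow, but the cost is paid once at
# export time rather than on every ZeroGPU request; CUDA graphs then cut per-step
# kernel-launch overhead at inference.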
def _strtobool(v: str | None, default: bool = True) -> bool:
if v is None:
return default
return v.strip().lower() in ("1", "true", "yes", "y", "on")
def _load_compiled_pt(path: str):
"""
Load either:
- a serialized dict produced by to_serializable_dict() (format zerogpu_aoti_v1), or
- an old-style pickled ZeroGPUCompiledModel.
"""
obj = torch.load(path, map_location="cpu", weights_only=False)
# New format: dict payload
if isinstance(obj, dict) and obj.get("format") == "zerogpu_aoti_v1":
return zerogpu_compiled_from_serializable_dict(obj)
# Old format: direct object
if isinstance(obj, ZeroGPUCompiledModel):
return obj
raise ValueError(
f"Unsupported compiled transformer file format at {path}. "
f"Got type={type(obj)} keys={list(obj.keys()) if isinstance(obj, dict) else None}"
)
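# A minimal counterpart sketch for producing the files _load_compiled_pt() accepts.
# It assumes to_serializable_dict() is a method on ZeroGPUCompiledModel (the docstring
# above only says the payload is "produced by to_serializable_dict()"); the helper
# name itself is hypothetical.
def _save_compiled_pt(compiled: ZeroGPUCompiledModel, path: str) -> None:
    payload = compiled.to_serializable_dict()  # expected to carry format='zerogpu_aoti_v1'
    torch.save(payload, path)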
def load_compiled_transformers_from_hub(
repo_id: str,
filename_1: str = "compiled_transformer_1.pt",
filename_2: str = "compiled_transformer_2.pt",
):
"""
    Load the precompiled artifacts from the Hub.
    IMPORTANT:
    The expected files are the ones you export via to_serializable_dict()
    (format 'zerogpu_aoti_v1') OR a direct pickle of ZeroGPUCompiledModel.
"""
path_1 = hf_hub_download(repo_id=repo_id, filename=filename_1)
path_2 = hf_hub_download(repo_id=repo_id, filename=filename_2)
compiled_1 = _load_compiled_pt(path_1)
compiled_2 = _load_compiled_pt(path_2)
return compiled_1, compiled_2
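# Hedged example of publishing the artifacts so the loader above can fetch them.
# huggingface_hub.upload_file is a real API; the helper name is hypothetical and the
# filenames simply mirror the loader's defaults.
def publish_compiled_transformers(repo_id: str, path_1: str, path_2: str) -> None:
    from huggingface_hub import upload_file
    upload_file(path_or_fileobj=path_1, path_in_repo="compiled_transformer_1.pt", repo_id=repo_id)
    upload_file(path_or_fileobj=path_2, path_in_repo="compiled_transformer_2.pt", repo_id=repo_id)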
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
@spaces.GPU(duration=1500)
def compile_transformer():
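        # Fuse the Lightx2v distillation LoRA into both transformers up front so that
        # torch.export later sees plain fused linear layers instead of adapter wrappers.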
pipeline.load_lora_weights(
"Kijai/WanVideo_comfy",
weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
adapter_name="lightx2v",
)
kwargs_lora = {"load_into_transformer_2": True}
pipeline.load_lora_weights(
"Kijai/WanVideo_comfy",
weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
adapter_name="lightx2v_2",
**kwargs_lora,
)
pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
pipeline.unload_lora_weights()
with capture_component_call(pipeline, "transformer") as call:
pipeline(*args, **kwargs)
dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
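        # Quantize before export so the fp8 dynamic-activation/fp8-weight kernels are
        # what actually gets traced into the exported graphs and compiled.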
quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
exported_1 = torch.export.export(
mod=pipeline.transformer,
args=call.args,
kwargs=call.kwargs,
dynamic_shapes=dynamic_shapes,
)
exported_2 = torch.export.export(
mod=pipeline.transformer_2,
args=call.args,
kwargs=call.kwargs,
dynamic_shapes=dynamic_shapes,
)
compiled_1 = aoti_compile(exported_1, INDUCTOR_CONFIGS)
compiled_2 = aoti_compile(exported_2, INDUCTOR_CONFIGS)
return compiled_1, compiled_2
    # Text encoder quantization (unchanged)
quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
    # WAN_USE_PRECOMPILED (variable name assumed, paired with WAN_PRECOMPILED_REPO
    # below) toggles loading precompiled artifacts; it defaults to off so the Space
    # compiles from scratch.
    use_precompiled = _strtobool(os.getenv("WAN_USE_PRECOMPILED"), default=False)
precompiled_repo = os.getenv("WAN_PRECOMPILED_REPO", "Fabrice-TIERCELIN/Wan_2.2_compiled")
if use_precompiled:
compiled_transformer_1, compiled_transformer_2 = load_compiled_transformers_from_hub(
repo_id=precompiled_repo
)
else:
compiled_transformer_1, compiled_transformer_2 = compile_transformer()
    # Expose the compiled artifacts at module scope for downloads; without the global
    # declaration these assignments would only bind function-local names.
    global COMPILED_TRANSFORMER_1, COMPILED_TRANSFORMER_2
    COMPILED_TRANSFORMER_1 = compiled_transformer_1
    COMPILED_TRANSFORMER_2 = compiled_transformer_2
pipeline.transformer.forward = compiled_transformer_1
drain_module_parameters(pipeline.transformer)
pipeline.transformer_2.forward = compiled_transformer_2
    drain_module_parameters(pipeline.transformer_2)
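# Typical call-site sketch (argument names are assumptions mirroring the Wan I2V
# pipeline signature): run one representative generation so capture_component_call
# records real transformer inputs for export, e.g.
#
#     optimize_pipeline_(
#         pipeline,
#         image=first_frame,
#         prompt="example prompt",
#         height=480,
#         width=832,
#         num_frames=81,
#     )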