"""
Ahead-of-time optimization for the Wan 2.2 image-to-video pipeline on ZeroGPU:
fuses the Lightx2v distillation LoRAs, quantizes the transformers and text
encoder with torchao, exports both transformers with torch.export, compiles
them with AOTInductor, and optionally loads precompiled artifacts from the Hub.
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import os
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig
from huggingface_hub import hf_hub_download
from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import drain_module_parameters
from optimization_utils import ZeroGPUCompiledModelFromDict  # rebuilds a compiled model from a saved payload dict
P = ParamSpec('P')
# Expose compiled models so app.py can offer them for download
COMPILED_TRANSFORMER_1 = None
COMPILED_TRANSFORMER_2 = None
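# Symbolic dimensions for torch.export: the compiled graphs stay valid across
# the full range of latent frame counts and patched spatial sizes the app uses.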
LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
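# Only hidden_states varies between calls; dims 2/3/4 (frames, height, width)
# are marked dynamic. The 2x factors keep heights and widths even, matching
# the transformer's 2x2 spatial patchification.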
TRANSFORMER_DYNAMIC_SHAPES = {
'hidden_states': {
2: LATENT_FRAMES_DIM,
3: 2 * LATENT_PATCHED_HEIGHT_DIM,
4: 2 * LATENT_PATCHED_WIDTH_DIM,
},
}
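# Inductor settings for AOT compilation; max-autotune and coordinate descent
# tuning trade longer compile times for faster generated kernels.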
INDUCTOR_CONFIGS = {
'conv_1x1_as_mm': True,
'epilogue_fusion': False,
'coordinate_descent_tuning': True,
'coordinate_descent_check_all_directions': True,
'max_autotune': True,
'triton.cudagraphs': True,
}
def load_compiled_transformers_from_hub(
repo_id: str,
filename_1: str = "compiled_transformer_1.pt",
filename_2: str = "compiled_transformer_2.pt",
device: str = "cuda",
):
"""
Loads the payload dicts (created via ZeroGPUCompiledModel.to_serializable_dict() and torch.save)
and rebuilds callable models that will move constants to CUDA on first call.
"""
path_1 = hf_hub_download(repo_id=repo_id, filename=filename_1)
path_2 = hf_hub_download(repo_id=repo_id, filename=filename_2)
payload_1 = torch.load(path_1, map_location="cpu", weights_only=False)
payload_2 = torch.load(path_2, map_location="cpu", weights_only=False)
if not isinstance(payload_1, dict) or not isinstance(payload_2, dict):
raise TypeError("Precompiled files are not payload dicts. Please re-export them with to_serializable_dict().")
compiled_1 = ZeroGPUCompiledModelFromDict(payload_1, device=device)
compiled_2 = ZeroGPUCompiledModelFromDict(payload_2, device=device)
return compiled_1, compiled_2
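# Example usage (hypothetical repo id shown):
#   t1, t2 = load_compiled_transformers_from_hub("some-user/wan22-compiled")
#   pipeline.transformer.forward = t1
#   pipeline.transformer_2.forward = t2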
def _strtobool(v: str | None, default: bool = True) -> bool:
if v is None:
return default
return v.strip().lower() in ("1", "true", "yes", "y", "on")
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
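    """
    Optimize the pipeline in place: fuse the Lightx2v LoRAs, quantize the
    transformers and text encoder, and swap both transformers' forward passes
    for AOT-compiled (or precompiled, Hub-hosted) equivalents.
    """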
global COMPILED_TRANSFORMER_1, COMPILED_TRANSFORMER_2
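    # The heavy compile path runs on a ZeroGPU worker; the 1500 s budget covers
    # tracing, autotuning, and AOTInductor compilation of both transformers.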
@spaces.GPU(duration=1500)
def compile_transformer():
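        # Fuse the Lightx2v distillation LoRA into both transformers (at scales
        # 3.0 and 1.0 respectively), then unload the adapters so the fused
        # weights are what gets quantized and exported.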
pipeline.load_lora_weights(
"Kijai/WanVideo_comfy",
weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
adapter_name="lightx2v",
)
kwargs_lora = {"load_into_transformer_2": True}
pipeline.load_lora_weights(
"Kijai/WanVideo_comfy",
weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
adapter_name="lightx2v_2",
**kwargs_lora,
)
pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
pipeline.unload_lora_weights()
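        # Run the pipeline once while capturing the transformer's call, so that
        # export sees the exact positional args and kwargs used in production.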
with capture_component_call(pipeline, "transformer") as call:
pipeline(*args, **kwargs)
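        # Mark every traced tensor/bool kwarg as static (None), then override
        # hidden_states with the symbolic frame/height/width dims defined above.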
dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
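        # Quantize both transformers to dynamic float8 activations + float8
        # weights before export, so the compiled graphs run the quantized kernels.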
quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
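        # torch.export traces each transformer into an ExportedProgram from the
        # captured call's real args/kwargs; aoti_compile then lowers it with
        # AOTInductor into a ZeroGPU-loadable artifact.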
exported_1 = torch.export.export(
mod=pipeline.transformer,
args=call.args,
kwargs=call.kwargs,
dynamic_shapes=dynamic_shapes,
)
exported_2 = torch.export.export(
mod=pipeline.transformer_2,
args=call.args,
kwargs=call.kwargs,
dynamic_shapes=dynamic_shapes,
)
compiled_1 = aoti_compile(exported_1, INDUCTOR_CONFIGS)
compiled_2 = aoti_compile(exported_2, INDUCTOR_CONFIGS)
return compiled_1, compiled_2
    # Quantize the text encoder to int8 weight-only to reduce its memory footprint
quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
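    # Optional fast path: load AOT-compiled transformers published on the Hub
    # instead of spending ZeroGPU time recompiling them on every startup.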
    # WAN_USE_PRECOMPILED is an assumed opt-in env var; it defaults to off,
    # matching the previously hard-coded False.
    use_precompiled = _strtobool(os.getenv("WAN_USE_PRECOMPILED"), default=False)
    precompiled_repo = os.getenv("WAN_PRECOMPILED_REPO", "Fabrice-TIERCELIN/Wan_2.2_compiled")
if use_precompiled:
try:
compiled_transformer_1, compiled_transformer_2 = load_compiled_transformers_from_hub(
repo_id=precompiled_repo,
device="cuda",
)
except Exception as e:
# fallback if payload format is wrong / outdated
print(f"[WARN] Failed to load precompiled artifacts ({e}). Falling back to GPU compilation.")
compiled_transformer_1, compiled_transformer_2 = compile_transformer()
else:
compiled_transformer_1, compiled_transformer_2 = compile_transformer()
# expose for downloads
COMPILED_TRANSFORMER_1 = compiled_transformer_1
COMPILED_TRANSFORMER_2 = compiled_transformer_2
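    # Swap each transformer's forward for its compiled callable, then drop the
    # now-unused eager parameters to free memory.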
pipeline.transformer.forward = compiled_transformer_1
drain_module_parameters(pipeline.transformer)
pipeline.transformer_2.forward = compiled_transformer_2
    drain_module_parameters(pipeline.transformer_2)