"""
Advanced 8-bit quantization helpers for the Wan 2.2 image-to-video pipeline.
The original project uses an optimization routine that keeps most of the model
in bf16 while applying a lighter weight-only quantization. This module pushes
the memory savings further by aggressively quantizing the heavy transformer
components (and text encoder) to 8‑bit right after LoRA fusion. The utilities
gracefully fall back to the lighter flow when an optional backend is missing,
but the expectation is that `torchao` is available via the project
requirements.
"""
from __future__ import annotations
import os
from typing import Any, Callable, ParamSpec
from pathlib import Path
import warnings
import sys
# The Wan pipeline does not benefit from FP8 paths yet and on some setups they
# even increase memory usage, so we keep those turned off just in case.
# `setdefault` means an explicit value in the user's environment still wins.
os.environ.setdefault("TORCHINDUCTOR_DISABLE_FP8", "1")
os.environ.setdefault("CUDA_DISABLE_FP8", "1")
# Silence torchinductor debug output and torch log spew in runtime logs.
os.environ.setdefault("TORCHINDUCTOR_DEBUG", "0")
os.environ.setdefault("TORCH_LOGS", "")
import torch
# Make the sibling reference project ("wan_more333") importable so its shared
# helpers can be reused without packaging/installing it.
CURRENT_DIR = Path(__file__).resolve().parent
REFERENCE_DIR = CURRENT_DIR.parent / "wan_more333"
if str(REFERENCE_DIR) not in sys.path:
    sys.path.insert(0, str(REFERENCE_DIR))
# Suppress a noisy warning emitted during LoRA loading when the state dict
# carries extra keys (presumably harmless — matched by message text below).
warnings.filterwarnings(
    "ignore",
    message="Loading adapter weights from state_dict led to unexpected keys found in the model",
)
# File name of the distilled I2V LoRA weights this module fuses into the model.
LORA_FILENAME = "lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"
def _resolve_lora_directory() -> Path:
    """Locate the directory holding the required LoRA safetensors file.

    Search order: the ``WAN_LORA_DIR`` environment variable (if set), the
    local and reference ``models`` folders, then the Hugging Face cache.

    Returns:
        Path to the parent directory of the first match found.

    Raises:
        FileNotFoundError: when the file is absent from every candidate.
    """
    search_roots: list[Path] = []
    env_dir = os.getenv("WAN_LORA_DIR")
    if env_dir:
        search_roots.append(Path(env_dir))
    search_roots += [
        CURRENT_DIR / "models",
        REFERENCE_DIR / "models",
        Path.home() / ".cache" / "huggingface" / "hub",
    ]
    for root in search_roots:
        if not root.exists():
            continue
        # Recursive search; `None` default avoids try/StopIteration noise.
        found = next(root.rglob(LORA_FILENAME), None)
        if found is not None:
            return found.parent
    raise FileNotFoundError(
        "Required LoRA weights not found locally. Place 'lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors' "
        "inside projects/wan_moreimages/models/ or set WAN_LORA_DIR to the folder containing it."
    )
# Optional quantization backends: torchao is the preferred path, bitsandbytes
# the fallback. Either may be missing, so record availability flags for the
# helpers below to branch on.
try:
    from torchao.quantization import Int8WeightOnlyConfig, quantize_
    _TORCHAO_AVAILABLE = True
except Exception:  # pragma: no cover - optional dependency
    Int8WeightOnlyConfig = None  # type: ignore[assignment]
    quantize_ = None  # type: ignore[assignment]
    _TORCHAO_AVAILABLE = False
try:  # pragma: no cover - bitsandbytes is an optional extra
    import bitsandbytes as bnb
    _BITSANDBYTES_AVAILABLE = True
except Exception:  # pragma: no cover - optional dependency
    bnb = None  # type: ignore[assignment]
    _BITSANDBYTES_AVAILABLE = False
# ParamSpec used to mirror the wrapped pipeline's call signature.
P = ParamSpec("P")
def _safe_to_bf16(module: torch.nn.Module) -> None:
try:
module.to(torch.bfloat16)
except Exception:
pass
def _quantize_with_torchao(module: torch.nn.Module, module_name: str) -> bool:
    """Attempt torchao int8 weight-only quantization on ``module``.

    Returns True when quantization succeeded, False when torchao is missing
    or raised during quantization (logged as a warning).
    """
    if not _TORCHAO_AVAILABLE:
        return False
    try:
        quantize_(module, Int8WeightOnlyConfig())  # type: ignore[arg-type]
    except Exception as exc:
        print(f"[INT8][WARN] torchao quantization failed for {module_name}: {exc}")
        return False
    print(f"[INT8] torchao weight-only quantization applied to {module_name}")
    return True
def _convert_linear_to_8bit_lt(linear: torch.nn.Linear) -> torch.nn.Module:
    """Build a bitsandbytes ``Linear8bitLt`` replacement for ``linear``.

    The fp weights (and bias, when present) are copied into the new layer on
    the same device the original lives on.

    Raises:
        RuntimeError: if bitsandbytes is not importable.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would turn a missing backend into an opaque
    # AttributeError on `bnb` below.
    if not (_BITSANDBYTES_AVAILABLE and bnb is not None):
        raise RuntimeError("bitsandbytes is required for Linear8bitLt conversion")
    device = linear.weight.device
    has_bias = linear.bias is not None
    eightbit = bnb.nn.Linear8bitLt(
        linear.in_features,
        linear.out_features,
        bias=has_bias,
        has_fp16_weights=False,
        device=device,
    )
    # NOTE(review): copying raw fp data assumes bnb's Int8Params performs the
    # actual int8 quantization when the layer is later moved to CUDA —
    # confirm against the bitsandbytes documentation.
    eightbit.weight.data.copy_(linear.weight.data)
    if has_bias:
        eightbit.bias = torch.nn.Parameter(linear.bias.data.to(device))
    return eightbit
def _quantize_with_bitsandbytes(module: torch.nn.Module, module_name: str) -> bool:
    """Swap every ``nn.Linear`` inside ``module`` for ``Linear8bitLt``.

    Returns True when at least one layer was swapped, False when the backend
    is unavailable, nothing was swapped is reported via a warning, and any
    exception during the swap is logged and reported as failure.
    """
    if not _BITSANDBYTES_AVAILABLE:
        return False
    swapped = False
    try:
        # Iterative walk instead of recursion; snapshot children with list()
        # because we mutate `_modules` while traversing.
        pending = [module]
        while pending:
            parent = pending.pop()
            for child_name, child in list(parent.named_children()):
                if isinstance(child, torch.nn.Linear):
                    parent._modules[child_name] = _convert_linear_to_8bit_lt(child)
                    swapped = True
                else:
                    pending.append(child)
        if swapped:
            print(f"[INT8] bitsandbytes Linear8bitLt swap applied to {module_name}")
        else:
            print(f"[INT8][WARN] No linear layers found in {module_name} for 8-bit swap")
    except Exception as exc:
        print(f"[INT8][WARN] bitsandbytes swap failed for {module_name}: {exc}")
        return False
    return swapped
def _quantize_module(module: torch.nn.Module, module_name: str) -> None:
    """Quantize ``module`` with the first backend that succeeds.

    Preference order: torchao, then bitsandbytes; prints a warning when
    neither backend is usable.
    """
    for backend in (_quantize_with_torchao, _quantize_with_bitsandbytes):
        if backend(module, module_name):
            return
    print(f"[INT8][WARN] 8-bit quantization skipped for {module_name} (no backend)")
def optimize_pipeline_int8(
    pipeline: Callable[P, Any],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Apply bf16 casting + 8-bit quantization while keeping weights off GPU.

    Steps: move the pipeline to CPU, fuse the lightx2v LoRAs into both
    transformers, cast heavy components to bf16 and quantize them to 8-bit,
    then enable the most aggressive CPU-offload mode available.

    ``*args``/``**kwargs`` are accepted for signature compatibility with the
    pipeline-call ParamSpec but are currently unused.
    """
    import gc  # plain import instead of the original `__import__("gc")` hack

    torch.set_float32_matmul_precision("high")
    # Start from a clean CPU placement so quantization never touches the GPU.
    if hasattr(pipeline, "reset_device_map"):
        pipeline.reset_device_map()
    pipeline.to("cpu")
    _fuse_lightx2v_loras(pipeline)
    _cast_and_quantize_components(pipeline)
    gc.collect()
    torch.cuda.empty_cache()  # no-op when CUDA was never initialized
    # Prefer the most memory-frugal offload helper the installed diffusers has.
    if hasattr(pipeline, "enable_sequential_cpu_offload"):
        pipeline.enable_sequential_cpu_offload()
    elif hasattr(pipeline, "enable_model_cpu_offload"):
        pipeline.enable_model_cpu_offload()
    else:
        print("[WARN] Diffusers version lacks CPU offload helpers; keeping pipeline on CPU.")


def _fuse_lightx2v_loras(pipeline: Any) -> None:
    """Load, fuse, and unload the lightx2v LoRA into both transformers."""
    pipeline.load_lora_weights(
        "Kijai/WanVideo_comfy",
        weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
        adapter_name="lightx2v",
    )
    # Second copy targets transformer_2 (the low-noise expert in Wan 2.2).
    pipeline.load_lora_weights(
        "Kijai/WanVideo_comfy",
        weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
        adapter_name="lightx2v_2",
        load_into_transformer_2=True,
    )
    pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
    pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
    pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
    pipeline.unload_lora_weights()


def _cast_and_quantize_components(pipeline: Any) -> None:
    """Cast heavy components to bf16 and 8-bit-quantize each one."""
    _safe_to_bf16(pipeline.transformer)
    _safe_to_bf16(pipeline.transformer_2)
    _quantize_module(pipeline.transformer, "transformer")
    _quantize_module(pipeline.transformer_2, "transformer_2")
    # Optional components: process whichever of these the pipeline exposes.
    # The try only guards the attribute walk (narrower than the original).
    for component_name in ("text_encoder", "vae", "vae.decoder", "vae.encoder"):
        target: Any = pipeline
        try:
            for attr in component_name.split("."):
                target = getattr(target, attr)
        except AttributeError:
            continue
        _safe_to_bf16(target)
        _quantize_module(target, component_name)
    second_encoder = getattr(pipeline, "text_encoder_2", None)
    if second_encoder is not None:
        _safe_to_bf16(second_encoder)
        _quantize_module(second_encoder, "text_encoder_2")