wan_more_images / optimization_quantized.py
kylin0421
opt
c0c5d65
"""
Advanced 8-bit quantization helpers for the Wan 2.2 image-to-video pipeline.
The original project uses an optimization routine that keeps most of the model
in bf16 while applying a lighter weight-only quantization. This module pushes
the memory savings further by quantizing the heavy transformer components
(plus the text encoder and VAE) to 8-bit weights right after LoRA fusion.
Quantization tries `torchao` first, falls back to swapping linear layers for
bitsandbytes `Linear8bitLt`, and otherwise leaves the component unquantized
with a warning; the expectation is that `torchao` is available via the
project requirements.
"""
from __future__ import annotations
import os
from typing import Any, Callable, ParamSpec
from pathlib import Path
import warnings
import sys
# The Wan pipeline does not benefit from FP8 paths yet and on some setups they
# even increase memory usage, so we keep those turned off just in case.
# These statements sit before `import torch` below, presumably so torch sees the
# values at import time. setdefault() only fills in missing variables, so any
# value already exported in the caller's environment takes precedence.
# NOTE(review): confirm that TORCHINDUCTOR_DISABLE_FP8 / CUDA_DISABLE_FP8 are
# actually read by the installed torch/CUDA versions; unknown variables are
# silently ignored.
os.environ.setdefault("TORCHINDUCTOR_DISABLE_FP8", "1")
os.environ.setdefault("CUDA_DISABLE_FP8", "1")
os.environ.setdefault("TORCHINDUCTOR_DEBUG", "0")
os.environ.setdefault("TORCH_LOGS", "")
import torch
# Location of this module and of the sibling "wan_more333" reference checkout.
# The checkout is put at the front of sys.path (only once) so its modules can
# be imported from here.
CURRENT_DIR = Path(__file__).resolve().parent
REFERENCE_DIR = CURRENT_DIR.parent / "wan_more333"
if not any(entry == str(REFERENCE_DIR) for entry in sys.path):
    sys.path[:0] = [str(REFERENCE_DIR)]

# Suppress the "unexpected keys" warning emitted while loading adapter (LoRA)
# weights from a state dict.
warnings.filterwarnings(
    action="ignore",
    message="Loading adapter weights from state_dict led to unexpected keys found in the model",
)

# File name of the lightx2v LoRA weights (searched for by _resolve_lora_directory).
LORA_FILENAME = "lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"
def _resolve_lora_directory() -> Path:
    """Return the directory that contains the lightx2v LoRA file.

    Candidate directories are searched in this order:

    1. ``$WAN_LORA_DIR`` (with ``~`` expanded), when the variable is set;
    2. ``<this module's directory>/models``;
    3. ``<REFERENCE_DIR>/models``;
    4. ``~/.cache/huggingface/hub``.

    Each existing candidate is searched recursively for ``LORA_FILENAME``, and
    the parent directory of the first match is returned.

    Raises:
        FileNotFoundError: if no candidate contains ``LORA_FILENAME``. The
            message lists every directory that was searched.
    """
    candidates: list[Path] = []
    explicit_dir = os.getenv("WAN_LORA_DIR")
    if explicit_dir:
        # expanduser() so values such as "~/models" work as users expect.
        candidates.append(Path(explicit_dir).expanduser())
    candidates.extend(
        [
            CURRENT_DIR / "models",
            REFERENCE_DIR / "models",
            Path.home() / ".cache" / "huggingface" / "hub",
        ]
    )

    for directory in candidates:
        if not directory.exists():
            continue
        # rglob() yields lazily, so next() stops at the first hit instead of
        # walking the whole tree (the HF cache can be large).
        match = next(directory.rglob(LORA_FILENAME), None)
        if match is not None:
            return match.parent

    # Report the directories that were actually searched rather than a
    # hard-coded path that may not match this checkout's layout.
    searched = ", ".join(str(directory) for directory in candidates)
    raise FileNotFoundError(
        f"Required LoRA weights not found locally. Place '{LORA_FILENAME}' inside one of "
        f"[{searched}] or set WAN_LORA_DIR to the folder containing it."
    )
# Optional quantization backends. Each import is guarded so that this module
# still imports when a backend is missing (or fails to import). The
# *_AVAILABLE flags record which backends the helpers below may use.
try:
    from torchao.quantization import Int8WeightOnlyConfig, quantize_
    _TORCHAO_AVAILABLE = True
except Exception:  # pragma: no cover - optional dependency
    # Placeholders keep the names defined; _quantize_with_torchao checks
    # _TORCHAO_AVAILABLE before touching them.
    Int8WeightOnlyConfig = None  # type: ignore[assignment]
    quantize_ = None  # type: ignore[assignment]
    _TORCHAO_AVAILABLE = False
# The broad ``except Exception`` (not just ImportError) means any failure while
# importing bitsandbytes disables that backend instead of breaking this module.
try:  # pragma: no cover - bitsandbytes is an optional extra
    import bitsandbytes as bnb
    _BITSANDBYTES_AVAILABLE = True
except Exception:  # pragma: no cover - optional dependency
    bnb = None  # type: ignore[assignment]
    _BITSANDBYTES_AVAILABLE = False
# ParamSpec used in optimize_pipeline_int8's signature to type its *args/**kwargs.
P = ParamSpec("P")
def _safe_to_bf16(module: torch.nn.Module) -> None:
try:
module.to(torch.bfloat16)
except Exception:
pass
def _quantize_with_torchao(module: torch.nn.Module, module_name: str) -> bool:
    """Apply torchao int8 weight-only quantization to *module* in place.

    Returns:
        True when ``quantize_`` completed. False when torchao could not be
        imported, or when quantization raised (a warning is printed in that
        case).
    """
    if _TORCHAO_AVAILABLE:
        try:
            int8_config = Int8WeightOnlyConfig()
            quantize_(module, int8_config)  # type: ignore[arg-type]
            print(f"[INT8] torchao weight-only quantization applied to {module_name}")
        except Exception as exc:
            print(f"[INT8][WARN] torchao quantization failed for {module_name}: {exc}")
        else:
            return True
    return False
def _convert_linear_to_8bit_lt(linear: torch.nn.Linear) -> torch.nn.Module:
    """Build a bitsandbytes ``Linear8bitLt`` replacement for *linear*.

    The original weight tensor is wrapped in ``bnb.nn.Int8Params``. The old
    code instead copied it into the new layer's freshly allocated
    default-dtype (normally float32) weight, which left every converted layer
    holding a float32 copy, twice the size of a bf16 weight. Per the
    bitsandbytes docs, the int8 quantization itself happens when the layer is
    moved to a CUDA device.

    Args:
        linear: the layer to replace. Its bias (if any) and its train/eval
            mode are carried over.

    Returns:
        A ``bnb.nn.Linear8bitLt`` built with ``has_fp16_weights=False``.

    Raises:
        RuntimeError: if bitsandbytes is not available.
    """
    if not _BITSANDBYTES_AVAILABLE or bnb is None:
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        raise RuntimeError("bitsandbytes is required for Linear8bitLt conversion")
    device = linear.weight.device
    has_bias = linear.bias is not None
    eightbit = bnb.nn.Linear8bitLt(
        linear.in_features,
        linear.out_features,
        bias=has_bias,
        has_fp16_weights=False,
        device=device,
    )
    # Reuse the existing tensor instead of copy_() into the placeholder weight
    # so no extra full-precision copy is kept alive.
    eightbit.weight = bnb.nn.Int8Params(
        linear.weight.data,
        requires_grad=False,
        has_fp16_weights=False,
    )
    if has_bias:
        eightbit.bias = torch.nn.Parameter(linear.bias.data.to(device))
    # A new nn.Module starts in training mode; mirror the layer it replaces.
    eightbit.train(linear.training)
    return eightbit
def _quantize_with_bitsandbytes(module: torch.nn.Module, module_name: str) -> bool:
    """Replace every ``torch.nn.Linear`` below *module* with bitsandbytes ``Linear8bitLt``.

    Layers that already are ``bnb.nn.Linear8bitLt`` are skipped.
    ``Linear8bitLt`` subclasses ``torch.nn.Linear``, so without this check,
    quantizing a module and then one of its submodules would wrap layers that
    were already converted a second time. Only children are replaced; *module*
    itself is never swapped.

    Returns:
        True if at least one layer was swapped, or if the module already
        contained ``Linear8bitLt`` layers and nothing else needed swapping.
        False if bitsandbytes is unavailable, if no linear layer was found, or
        if the swap raised. In the last case, layers converted before the
        error stay converted.
    """
    if not _BITSANDBYTES_AVAILABLE:
        return False
    swapped = 0
    already_8bit = 0

    def _recursive_swap(parent: torch.nn.Module) -> None:
        nonlocal swapped, already_8bit
        # Snapshot the children: entries are replaced while we iterate.
        for name, child in list(parent.named_children()):
            if isinstance(child, bnb.nn.Linear8bitLt):
                already_8bit += 1
            elif isinstance(child, torch.nn.Linear):
                # setattr() is the public way to re-register a submodule.
                setattr(parent, name, _convert_linear_to_8bit_lt(child))
                swapped += 1
            else:
                _recursive_swap(child)

    try:
        _recursive_swap(module)
    except Exception as exc:
        print(f"[INT8][WARN] bitsandbytes swap failed for {module_name}: {exc}")
        return False
    if swapped:
        print(f"[INT8] bitsandbytes Linear8bitLt swap applied to {module_name}")
        return True
    if already_8bit:
        print(f"[INT8] {module_name} already uses bitsandbytes Linear8bitLt layers; nothing to swap")
        return True
    print(f"[INT8][WARN] No linear layers found in {module_name} for 8-bit swap")
    return False
def _quantize_module(module: torch.nn.Module, module_name: str) -> bool:
    """Quantize *module* to 8-bit with the first backend that succeeds.

    torchao weight-only int8 is tried first. If it is unavailable or fails,
    the bitsandbytes ``Linear8bitLt`` swap is tried next. Neither failing is
    fatal; a warning is printed instead.

    Returns:
        True if one of the backends reported success, False otherwise.
    """
    if _quantize_with_torchao(module, module_name):
        return True
    if _quantize_with_bitsandbytes(module, module_name):
        return True
    if _TORCHAO_AVAILABLE or _BITSANDBYTES_AVAILABLE:
        # A backend is installed, so "no backend" would be misleading. The
        # backend(s) above have already printed why they failed or skipped.
        reason = "available backends failed or found nothing to quantize"
    else:
        reason = "no backend"
    print(f"[INT8][WARN] 8-bit quantization skipped for {module_name} ({reason})")
    return False
def _fuse_lightx2v_loras(pipeline: Any) -> None:
    """Load the lightx2v LoRA into both transformers, fuse it, then unload adapters."""
    repo_id = "Kijai/WanVideo_comfy"
    weight_name = f"Lightx2v/{LORA_FILENAME}"
    # The same file is loaded twice: once with the default target and once with
    # load_into_transformer_2=True, so each transformer gets its own adapter.
    pipeline.load_lora_weights(repo_id, weight_name=weight_name, adapter_name="lightx2v")
    pipeline.load_lora_weights(
        repo_id,
        weight_name=weight_name,
        adapter_name="lightx2v_2",
        load_into_transformer_2=True,
    )
    pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
    pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
    pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
    # After fusion the adapter weights are redundant; only fused base weights remain.
    pipeline.unload_lora_weights()


def _cast_and_quantize_components(pipeline: Any) -> None:
    """Cast the pipeline's heavy components to bf16 and quantize them to 8-bit."""
    # Both transformers are required (the LoRA was just fused into them), so a
    # missing attribute raises here, as before.
    _safe_to_bf16(pipeline.transformer)
    _safe_to_bf16(pipeline.transformer_2)
    _quantize_module(pipeline.transformer, "transformer")
    _quantize_module(pipeline.transformer_2, "transformer_2")

    # Optional components. "vae" already covers vae.encoder / vae.decoder;
    # listing those submodules separately would quantize the same layers twice.
    for component_name in ("text_encoder", "vae", "text_encoder_2"):
        component = getattr(pipeline, component_name, None)
        if component is None:
            # Attribute absent, or registered as None on this pipeline.
            continue
        _safe_to_bf16(component)
        _quantize_module(component, component_name)


def _enable_cpu_offload(pipeline: Any) -> None:
    """Enable sequential CPU offload, falling back to model-level offload."""
    if hasattr(pipeline, "enable_sequential_cpu_offload"):
        pipeline.enable_sequential_cpu_offload()
    elif hasattr(pipeline, "enable_model_cpu_offload"):
        pipeline.enable_model_cpu_offload()
    else:
        print("[WARN] Diffusers version lacks CPU offload helpers; keeping pipeline on CPU.")


def optimize_pipeline_int8(
    pipeline: Callable[P, Any],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Apply bf16 casting + 8-bit quantization while keeping weights off GPU.

    Steps, in order:

    1. Set float32 matmul precision to "high", reset any device map, and move
       the pipeline to CPU.
    2. Load the lightx2v LoRA into ``transformer`` and ``transformer_2``, fuse
       it (scales 3.0 and 1.0), and unload the adapters.
    3. Cast ``transformer``, ``transformer_2`` and any present ``text_encoder``,
       ``vae`` and ``text_encoder_2`` to bf16, then quantize each to 8-bit.
    4. Run the garbage collector, release cached CUDA memory, and enable
       sequential (or, failing that, model-level) CPU offload.

    Args:
        pipeline: the image-to-video pipeline to modify in place (assumed to be
            a diffusers Wan pipeline exposing the LoRA/offload methods used here).
        *args: accepted for signature compatibility; unused.
        **kwargs: accepted for signature compatibility; unused.
    """
    import gc

    torch.set_float32_matmul_precision("high")
    if hasattr(pipeline, "reset_device_map"):
        pipeline.reset_device_map()
    pipeline.to("cpu")

    _fuse_lightx2v_loras(pipeline)
    _cast_and_quantize_components(pipeline)

    # Drop references freed by fusion/quantization before enabling offload.
    gc.collect()
    torch.cuda.empty_cache()
    _enable_cpu_offload(pipeline)