""" Advanced 8-bit quantization helpers for the Wan 2.2 image-to-video pipeline. The original project uses an optimization routine that keeps most of the model in bf16 while applying a lighter weight-only quantization. This module pushes the memory savings further by aggressively quantizing the heavy transformer components (and text encoder) to 8‑bit right after LoRA fusion. The utilities gracefully fall back to the lighter flow when an optional backend is missing, but the expectation is that `torchao` is available via the project requirements. """ from __future__ import annotations import os from typing import Any, Callable, ParamSpec from pathlib import Path import warnings import sys # The Wan pipeline does not benefit from FP8 paths yet and on some setups they # even increase memory usage, so we keep those turned off just in case. os.environ.setdefault("TORCHINDUCTOR_DISABLE_FP8", "1") os.environ.setdefault("CUDA_DISABLE_FP8", "1") os.environ.setdefault("TORCHINDUCTOR_DEBUG", "0") os.environ.setdefault("TORCH_LOGS", "") import torch CURRENT_DIR = Path(__file__).resolve().parent REFERENCE_DIR = CURRENT_DIR.parent / "wan_more333" if str(REFERENCE_DIR) not in sys.path: sys.path.insert(0, str(REFERENCE_DIR)) warnings.filterwarnings( "ignore", message="Loading adapter weights from state_dict led to unexpected keys found in the model", ) LORA_FILENAME = "lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors" def _resolve_lora_directory() -> Path: explicit_dir = os.getenv("WAN_LORA_DIR") candidates: list[Path] = [] if explicit_dir: candidates.append(Path(explicit_dir)) candidates.extend( [ CURRENT_DIR / "models", REFERENCE_DIR / "models", Path.home() / ".cache" / "huggingface" / "hub", ] ) for directory in candidates: if not directory.exists(): continue try: match = next(directory.rglob(LORA_FILENAME)) return match.parent except StopIteration: continue raise FileNotFoundError( "Required LoRA weights not found locally. Place 'lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors' " "inside projects/wan_moreimages/models/ or set WAN_LORA_DIR to the folder containing it." ) try: from torchao.quantization import Int8WeightOnlyConfig, quantize_ _TORCHAO_AVAILABLE = True except Exception: # pragma: no cover - optional dependency Int8WeightOnlyConfig = None # type: ignore[assignment] quantize_ = None # type: ignore[assignment] _TORCHAO_AVAILABLE = False try: # pragma: no cover - bitsandbytes is an optional extra import bitsandbytes as bnb _BITSANDBYTES_AVAILABLE = True except Exception: # pragma: no cover - optional dependency bnb = None # type: ignore[assignment] _BITSANDBYTES_AVAILABLE = False P = ParamSpec("P") def _safe_to_bf16(module: torch.nn.Module) -> None: try: module.to(torch.bfloat16) except Exception: pass def _quantize_with_torchao(module: torch.nn.Module, module_name: str) -> bool: if not _TORCHAO_AVAILABLE: return False try: quantize_(module, Int8WeightOnlyConfig()) # type: ignore[arg-type] print(f"[INT8] torchao weight-only quantization applied to {module_name}") return True except Exception as exc: print(f"[INT8][WARN] torchao quantization failed for {module_name}: {exc}") return False def _convert_linear_to_8bit_lt(linear: torch.nn.Linear) -> torch.nn.Module: assert _BITSANDBYTES_AVAILABLE and bnb is not None device = linear.weight.device bias = linear.bias is not None eightbit = bnb.nn.Linear8bitLt( linear.in_features, linear.out_features, bias=bias, has_fp16_weights=False, device=device, ) eightbit.weight.data.copy_(linear.weight.data) if bias: eightbit.bias = torch.nn.Parameter(linear.bias.data.to(device)) return eightbit def _quantize_with_bitsandbytes(module: torch.nn.Module, module_name: str) -> bool: if not _BITSANDBYTES_AVAILABLE: return False converted_any = False def _recursive_swap(parent: torch.nn.Module) -> None: nonlocal converted_any for name, child in list(parent.named_children()): if isinstance(child, torch.nn.Linear): converted_child = _convert_linear_to_8bit_lt(child) parent._modules[name] = converted_child converted_any = True else: _recursive_swap(child) try: _recursive_swap(module) if converted_any: print(f"[INT8] bitsandbytes Linear8bitLt swap applied to {module_name}") else: print(f"[INT8][WARN] No linear layers found in {module_name} for 8-bit swap") except Exception as exc: print(f"[INT8][WARN] bitsandbytes swap failed for {module_name}: {exc}") return False return converted_any def _quantize_module(module: torch.nn.Module, module_name: str) -> None: if _quantize_with_torchao(module, module_name): return if _quantize_with_bitsandbytes(module, module_name): return print(f"[INT8][WARN] 8-bit quantization skipped for {module_name} (no backend)") def optimize_pipeline_int8( pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs, ) -> None: """Apply bf16 casting + 8-bit quantization while keeping weights off GPU.""" torch.set_float32_matmul_precision("high") if hasattr(pipeline, "reset_device_map"): pipeline.reset_device_map() pipeline.to("cpu") # This LoRA fusion part remains the same pipeline.load_lora_weights( "Kijai/WanVideo_comfy", weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors", adapter_name="lightx2v" ) kwargs_lora = {} kwargs_lora["load_into_transformer_2"] = True pipeline.load_lora_weights( "Kijai/WanVideo_comfy", weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors", adapter_name="lightx2v_2", **kwargs_lora ) pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0]) pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"]) pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"]) pipeline.unload_lora_weights() _safe_to_bf16(pipeline.transformer) _safe_to_bf16(pipeline.transformer_2) _quantize_module(pipeline.transformer, "transformer") _quantize_module(pipeline.transformer_2, "transformer_2") for component_name in ("text_encoder", "vae", "vae.decoder", "vae.encoder"): module = pipeline try: for attr in component_name.split("."): module = getattr(module, attr) _safe_to_bf16(module) _quantize_module(module, component_name) except AttributeError: continue try: _safe_to_bf16(pipeline.text_encoder_2) # type: ignore[attr-defined] _quantize_module(pipeline.text_encoder_2, "text_encoder_2") except AttributeError: pass gc = __import__("gc") gc.collect() torch.cuda.empty_cache() if hasattr(pipeline, "enable_sequential_cpu_offload"): pipeline.enable_sequential_cpu_offload() elif hasattr(pipeline, "enable_model_cpu_offload"): pipeline.enable_model_cpu_offload() else: print("[WARN] Diffusers version lacks CPU offload helpers; keeping pipeline on CPU.")