# NOTE(review): the three lines that were here ("Spaces:", "Runtime error",
# "Runtime error") were paste residue from the hosting page header, not part
# of the module; recorded here as a comment so the file remains valid Python.
"""
Advanced 8-bit quantization helpers for the Wan 2.2 image-to-video pipeline.

The original project uses an optimization routine that keeps most of the model
in bf16 while applying a lighter weight-only quantization. This module pushes
the memory savings further by aggressively quantizing the heavy transformer
components (and text encoder) to 8-bit right after LoRA fusion. The utilities
gracefully fall back to the lighter flow when an optional backend is missing,
but the expectation is that `torchao` is available via the project
requirements.
"""
from __future__ import annotations
import os
from typing import Any, Callable, ParamSpec
from pathlib import Path
import warnings
import sys

# The Wan pipeline does not benefit from FP8 paths yet and on some setups they
# even increase memory usage, so we keep those turned off just in case.
# NOTE: these environment variables are set BEFORE `import torch` below so the
# inductor/CUDA machinery sees them at import time — do not reorder.
os.environ.setdefault("TORCHINDUCTOR_DISABLE_FP8", "1")
os.environ.setdefault("CUDA_DISABLE_FP8", "1")
os.environ.setdefault("TORCHINDUCTOR_DEBUG", "0")
os.environ.setdefault("TORCH_LOGS", "")

import torch

# Directory of this file, and the sibling reference project whose modules we
# make importable by prepending it to sys.path (prepended so it wins over any
# same-named installed package).
CURRENT_DIR = Path(__file__).resolve().parent
REFERENCE_DIR = CURRENT_DIR.parent / "wan_more333"
if str(REFERENCE_DIR) not in sys.path:
    sys.path.insert(0, str(REFERENCE_DIR))

# Silence a known-benign diffusers/PEFT warning emitted while loading LoRAs.
warnings.filterwarnings(
    "ignore",
    message="Loading adapter weights from state_dict led to unexpected keys found in the model",
)

# File name of the distilled LightX2V LoRA weights searched for on disk by
# `_resolve_lora_directory`.
LORA_FILENAME = "lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"
def _resolve_lora_directory() -> Path:
    """Locate the directory that contains the LightX2V LoRA weights.

    Searches, in order: the directory named by the ``WAN_LORA_DIR``
    environment variable (if set), then ``models/`` next to this file,
    ``models/`` in the reference project, and finally the Hugging Face hub
    cache. Each candidate is searched recursively for ``LORA_FILENAME``.

    Returns:
        The parent directory of the first matching weights file.

    Raises:
        FileNotFoundError: if no candidate directory contains the file.
    """
    candidates: list[Path] = []
    explicit_dir = os.getenv("WAN_LORA_DIR")
    if explicit_dir:
        candidates.append(Path(explicit_dir))
    candidates.extend(
        [
            CURRENT_DIR / "models",
            REFERENCE_DIR / "models",
            Path.home() / ".cache" / "huggingface" / "hub",
        ]
    )
    for directory in candidates:
        if not directory.exists():
            continue
        # `next(iterator, None)` instead of try/except StopIteration: same
        # semantics without exception-driven control flow.
        match = next(directory.rglob(LORA_FILENAME), None)
        if match is not None:
            return match.parent
    raise FileNotFoundError(
        "Required LoRA weights not found locally. Place 'lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors' "
        "inside projects/wan_moreimages/models/ or set WAN_LORA_DIR to the folder containing it."
    )
# Optional quantization backends. torchao is the preferred path; bitsandbytes
# is the fallback. Each `_*_AVAILABLE` flag records whether its backend
# imported successfully, and the failed names are stubbed to None so later
# references stay well-defined.
try:
    from torchao.quantization import Int8WeightOnlyConfig, quantize_

    _TORCHAO_AVAILABLE = True
except Exception:  # pragma: no cover - optional dependency
    Int8WeightOnlyConfig = None  # type: ignore[assignment]
    quantize_ = None  # type: ignore[assignment]
    _TORCHAO_AVAILABLE = False

try:  # pragma: no cover - bitsandbytes is an optional extra
    import bitsandbytes as bnb

    _BITSANDBYTES_AVAILABLE = True
except Exception:  # pragma: no cover - optional dependency
    bnb = None  # type: ignore[assignment]
    _BITSANDBYTES_AVAILABLE = False

# ParamSpec used to mirror the wrapped pipeline's call signature in
# `optimize_pipeline_int8`.
P = ParamSpec("P")
| def _safe_to_bf16(module: torch.nn.Module) -> None: | |
| try: | |
| module.to(torch.bfloat16) | |
| except Exception: | |
| pass | |
def _quantize_with_torchao(module: torch.nn.Module, module_name: str) -> bool:
    """Attempt torchao int8 weight-only quantization on *module*.

    Returns True on success; False when torchao is unavailable or the
    quantization call raises (a warning is printed in that case).
    """
    if not _TORCHAO_AVAILABLE:
        return False
    try:
        quantize_(module, Int8WeightOnlyConfig())  # type: ignore[arg-type]
        print(f"[INT8] torchao weight-only quantization applied to {module_name}")
    except Exception as exc:
        print(f"[INT8][WARN] torchao quantization failed for {module_name}: {exc}")
        return False
    return True
def _convert_linear_to_8bit_lt(linear: torch.nn.Linear) -> torch.nn.Module:
    """Build a bitsandbytes ``Linear8bitLt`` mirroring *linear*.

    Copies the source weights (and bias, when present) into the replacement;
    bitsandbytes performs the actual int8 quantization when the weights are
    later moved to GPU (``has_fp16_weights=False``).

    Raises:
        RuntimeError: if the bitsandbytes backend is unavailable.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, and we must never touch `bnb` attributes when the
    # optional import failed.
    if not _BITSANDBYTES_AVAILABLE or bnb is None:
        raise RuntimeError("bitsandbytes backend is not available")
    device = linear.weight.device
    bias = linear.bias is not None
    eightbit = bnb.nn.Linear8bitLt(
        linear.in_features,
        linear.out_features,
        bias=bias,
        has_fp16_weights=False,
        device=device,
    )
    eightbit.weight.data.copy_(linear.weight.data)
    if bias:
        eightbit.bias = torch.nn.Parameter(linear.bias.data.to(device))
    return eightbit
def _quantize_with_bitsandbytes(module: torch.nn.Module, module_name: str) -> bool:
    """Swap every ``nn.Linear`` under *module* for a ``Linear8bitLt``.

    Returns True when at least one layer was swapped; False when
    bitsandbytes is unavailable, no linear layers exist, or the swap raises
    (a warning is printed for the failure cases).
    """
    if not _BITSANDBYTES_AVAILABLE:
        return False

    swapped = 0

    def _walk(parent: torch.nn.Module) -> None:
        # Depth-first replacement; `list(...)` snapshots children because we
        # mutate `parent._modules` while iterating.
        nonlocal swapped
        for child_name, child in list(parent.named_children()):
            if isinstance(child, torch.nn.Linear):
                parent._modules[child_name] = _convert_linear_to_8bit_lt(child)
                swapped += 1
            else:
                _walk(child)

    try:
        _walk(module)
    except Exception as exc:
        print(f"[INT8][WARN] bitsandbytes swap failed for {module_name}: {exc}")
        return False
    if swapped:
        print(f"[INT8] bitsandbytes Linear8bitLt swap applied to {module_name}")
    else:
        print(f"[INT8][WARN] No linear layers found in {module_name} for 8-bit swap")
    return swapped > 0
def _quantize_module(module: torch.nn.Module, module_name: str) -> None:
    """Quantize *module* to 8-bit, preferring torchao over bitsandbytes.

    Tries each backend in order and stops at the first that succeeds; prints
    a warning when no backend is usable.
    """
    for backend in (_quantize_with_torchao, _quantize_with_bitsandbytes):
        if backend(module, module_name):
            return
    print(f"[INT8][WARN] 8-bit quantization skipped for {module_name} (no backend)")
def optimize_pipeline_int8(
    pipeline: Callable[P, Any],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Apply bf16 casting + 8-bit quantization while keeping weights off GPU.

    Fuses the LightX2V distillation LoRAs into both transformers, casts the
    heavy components (transformers, text encoders, VAE) to bf16, quantizes
    them to 8-bit via the available backend, then re-enables CPU offload.

    Args:
        pipeline: a diffusers Wan image-to-video pipeline instance. ``*args``
            and ``**kwargs`` are accepted for signature compatibility but are
            not used by this routine.
    """
    import gc  # stdlib; plain import instead of the `__import__("gc")` hack

    torch.set_float32_matmul_precision("high")
    # Undo any device map / offload hooks so the weights live on CPU while
    # we mutate them.
    if hasattr(pipeline, "reset_device_map"):
        pipeline.reset_device_map()
    pipeline.to("cpu")

    # --- LoRA fusion (unchanged from the reference flow) ---
    lora_repo = "Kijai/WanVideo_comfy"
    lora_weight = "Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"
    pipeline.load_lora_weights(
        lora_repo,
        weight_name=lora_weight,
        adapter_name="lightx2v",
    )
    pipeline.load_lora_weights(
        lora_repo,
        weight_name=lora_weight,
        adapter_name="lightx2v_2",
        load_into_transformer_2=True,
    )
    pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
    pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
    pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
    pipeline.unload_lora_weights()

    # --- bf16 cast + 8-bit quantization of the heavy components ---
    _safe_to_bf16(pipeline.transformer)
    _safe_to_bf16(pipeline.transformer_2)
    _quantize_module(pipeline.transformer, "transformer")
    _quantize_module(pipeline.transformer_2, "transformer_2")
    for component_name in ("text_encoder", "vae", "vae.decoder", "vae.encoder"):
        module = pipeline
        try:
            for attr in component_name.split("."):
                module = getattr(module, attr)
            _safe_to_bf16(module)
            _quantize_module(module, component_name)
        except AttributeError:
            # Component not present on this pipeline variant; skip it.
            continue
    try:
        _safe_to_bf16(pipeline.text_encoder_2)  # type: ignore[attr-defined]
        _quantize_module(pipeline.text_encoder_2, "text_encoder_2")
    except AttributeError:
        pass

    # Reclaim the higher-precision copies dropped during quantization.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Keep weights off-GPU between forward passes; sequential offload is the
    # more aggressive (lower-memory) option, so prefer it when available.
    if hasattr(pipeline, "enable_sequential_cpu_offload"):
        pipeline.enable_sequential_cpu_offload()
    elif hasattr(pipeline, "enable_model_cpu_offload"):
        pipeline.enable_model_cpu_offload()
    else:
        print("[WARN] Diffusers version lacks CPU offload helpers; keeping pipeline on CPU.")