File size: 7,685 Bytes
3bd0d36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d21e35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bd0d36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d21e35
3bd0d36
 
1d21e35
 
3bd0d36
 
c0c5d65
3bd0d36
c0c5d65
 
 
3bd0d36
c0c5d65
 
3bd0d36
c0c5d65
 
 
3bd0d36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d21e35
 
 
 
 
 
 
c1baeff
1d21e35
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""
Advanced 8-bit quantization helpers for the Wan 2.2 image-to-video pipeline.

The original project uses an optimization routine that keeps most of the model
in bf16 while applying a lighter weight-only quantization.  This module pushes
the memory savings further by aggressively quantizing the heavy transformer
components (and text encoder) to 8‑bit right after LoRA fusion.  The utilities
gracefully fall back to the lighter flow when an optional backend is missing,
but the expectation is that `torchao` is available via the project
requirements.
"""

from __future__ import annotations

import os
from typing import Any, Callable, ParamSpec
from pathlib import Path
import warnings

import sys

# The Wan pipeline does not benefit from FP8 paths yet and on some setups they
# even increase memory usage, so we keep those turned off just in case.
# setdefault: respect any value the user already exported.
os.environ.setdefault("TORCHINDUCTOR_DISABLE_FP8", "1")
os.environ.setdefault("CUDA_DISABLE_FP8", "1")
os.environ.setdefault("TORCHINDUCTOR_DEBUG", "0")
os.environ.setdefault("TORCH_LOGS", "")

import torch

# Directory of this file and a sibling reference project whose modules we
# import; the reference dir is added to sys.path exactly once.
CURRENT_DIR = Path(__file__).resolve().parent
REFERENCE_DIR = CURRENT_DIR.parent / "wan_more333"
if str(REFERENCE_DIR) not in sys.path:
    sys.path.insert(0, str(REFERENCE_DIR))

# Silence a known-benign diffusers warning emitted while loading LoRA adapters.
warnings.filterwarnings(
    "ignore",
    message="Loading adapter weights from state_dict led to unexpected keys found in the model",
)

# File name of the lightx2v step-distill LoRA searched for on disk.
LORA_FILENAME = "lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"



def _resolve_lora_directory() -> Path:
    """Locate the folder that contains the lightx2v LoRA weights.

    Search order: the ``WAN_LORA_DIR`` environment variable (if set), then the
    local ``models`` directories next to this project and the reference
    project, then the Hugging Face hub cache.  The first directory whose tree
    contains ``LORA_FILENAME`` wins.

    Returns:
        The parent directory of the first matching safetensors file.

    Raises:
        FileNotFoundError: If no candidate directory contains the file.
    """
    search_roots: list[Path] = []
    override = os.getenv("WAN_LORA_DIR")
    if override:
        search_roots.append(Path(override))
    search_roots += [
        CURRENT_DIR / "models",
        REFERENCE_DIR / "models",
        Path.home() / ".cache" / "huggingface" / "hub",
    ]

    for root in search_roots:
        if not root.exists():
            continue
        found = next(root.rglob(LORA_FILENAME), None)
        if found is not None:
            return found.parent

    raise FileNotFoundError(
        "Required LoRA weights not found locally. Place 'lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors' "
        "inside projects/wan_moreimages/models/ or set WAN_LORA_DIR to the folder containing it."
    )


# Optional 8-bit backends.  torchao is the preferred path (expected via the
# project requirements); bitsandbytes is the fallback.  Each flag records
# whether the import succeeded so the quantization helpers can pick a backend
# at call time.
try:
    from torchao.quantization import Int8WeightOnlyConfig, quantize_

    _TORCHAO_AVAILABLE = True
except Exception:  # pragma: no cover - optional dependency
    Int8WeightOnlyConfig = None  # type: ignore[assignment]
    quantize_ = None  # type: ignore[assignment]
    _TORCHAO_AVAILABLE = False

try:  # pragma: no cover - bitsandbytes is an optional extra
    import bitsandbytes as bnb

    _BITSANDBYTES_AVAILABLE = True
except Exception:  # pragma: no cover - optional dependency
    bnb = None  # type: ignore[assignment]
    _BITSANDBYTES_AVAILABLE = False

# ParamSpec used to type optimize_pipeline_int8's pass-through signature.
P = ParamSpec("P")


def _safe_to_bf16(module: torch.nn.Module) -> None:
    try:
        module.to(torch.bfloat16)
    except Exception:
        pass


def _quantize_with_torchao(module: torch.nn.Module, module_name: str) -> bool:
    """Attempt int8 weight-only quantization of *module* via torchao.

    Args:
        module: The module to quantize in place.
        module_name: Label used only for log messages.

    Returns:
        True when quantization succeeded; False when torchao is unavailable
        or the quantization call raised.
    """
    if not _TORCHAO_AVAILABLE:
        return False
    try:
        quantize_(module, Int8WeightOnlyConfig())  # type: ignore[arg-type]
    except Exception as exc:
        print(f"[INT8][WARN] torchao quantization failed for {module_name}: {exc}")
        return False
    print(f"[INT8] torchao weight-only quantization applied to {module_name}")
    return True


def _convert_linear_to_8bit_lt(linear: torch.nn.Linear) -> torch.nn.Module:
    """Create a bitsandbytes ``Linear8bitLt`` replacement for *linear*.

    The replacement mirrors the layer's in/out features, bias presence and
    device, and copies the existing weight (and bias) data into the new
    module.  ``has_fp16_weights=False`` selects the memory-saving int8 path.

    Args:
        linear: The dense layer to convert.

    Returns:
        The newly constructed 8-bit linear module.

    Raises:
        RuntimeError: If bitsandbytes is not importable.  An explicit check
            replaces the previous ``assert``, which would be stripped under
            ``python -O`` and let a None ``bnb`` crash with an opaque error.
    """
    if not _BITSANDBYTES_AVAILABLE or bnb is None:
        raise RuntimeError("bitsandbytes is not available; cannot convert Linear to Linear8bitLt")
    device = linear.weight.device
    has_bias = linear.bias is not None
    eightbit = bnb.nn.Linear8bitLt(
        linear.in_features,
        linear.out_features,
        bias=has_bias,
        has_fp16_weights=False,
        device=device,
    )
    eightbit.weight.data.copy_(linear.weight.data)
    if has_bias:
        eightbit.bias = torch.nn.Parameter(linear.bias.data.to(device))
    return eightbit


def _quantize_with_bitsandbytes(module: torch.nn.Module, module_name: str) -> bool:
    """Swap every ``torch.nn.Linear`` under *module* for a ``Linear8bitLt``.

    Walks the module tree recursively, replacing linear layers in place via
    ``_convert_linear_to_8bit_lt``.

    Args:
        module: Root module whose children are rewritten in place.
        module_name: Label used only for log messages.

    Returns:
        True when at least one layer was replaced; False when bitsandbytes is
        unavailable, no linear layers were found, or the swap raised.
    """
    if not _BITSANDBYTES_AVAILABLE:
        return False

    swapped = 0

    def _walk(node: torch.nn.Module) -> None:
        nonlocal swapped
        # list(): we mutate node's children while iterating.
        for child_name, child in list(node.named_children()):
            if isinstance(child, torch.nn.Linear):
                node._modules[child_name] = _convert_linear_to_8bit_lt(child)
                swapped += 1
            else:
                _walk(child)

    try:
        _walk(module)
    except Exception as exc:
        print(f"[INT8][WARN] bitsandbytes swap failed for {module_name}: {exc}")
        return False

    if swapped:
        print(f"[INT8] bitsandbytes Linear8bitLt swap applied to {module_name}")
    else:
        print(f"[INT8][WARN] No linear layers found in {module_name} for 8-bit swap")
    return swapped > 0


def _quantize_module(module: torch.nn.Module, module_name: str) -> None:
    """Quantize *module* to 8-bit with the first backend that succeeds.

    Tries torchao first, then bitsandbytes; logs a warning when neither
    backend is available or both fail.
    """
    for backend in (_quantize_with_torchao, _quantize_with_bitsandbytes):
        if backend(module, module_name):
            return
    print(f"[INT8][WARN] 8-bit quantization skipped for {module_name} (no backend)")


def _fuse_lightx2v_loras(pipeline: Any) -> None:
    """Load the lightx2v step-distill LoRA into both transformers and fuse it.

    The same safetensors file is loaded twice under different adapter names so
    it can be fused into ``transformer`` (scale 3.0) and ``transformer_2``
    (scale 1.0) independently, then the adapter bookkeeping is discarded.
    """
    pipeline.load_lora_weights(
        "Kijai/WanVideo_comfy",
        weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
        adapter_name="lightx2v",
    )
    pipeline.load_lora_weights(
        "Kijai/WanVideo_comfy",
        weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
        adapter_name="lightx2v_2",
        load_into_transformer_2=True,
    )
    pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
    pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
    pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
    pipeline.unload_lora_weights()


def optimize_pipeline_int8(
    pipeline: Callable[P, Any],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Apply bf16 casting + 8-bit quantization while keeping weights off GPU.

    Steps:
      1. Move the pipeline to CPU (resetting any device map first).
      2. Fuse the lightx2v LoRA into both transformers.
      3. Cast the heavy components to bf16, then quantize them to 8-bit.
      4. Re-enable CPU offload so weights stream to the GPU on demand.

    Args:
        pipeline: A diffusers-style pipeline object.  ``*args``/``**kwargs``
            are accepted for signature compatibility and are unused here.
    """
    import gc  # plain local import instead of the previous __import__("gc")

    torch.set_float32_matmul_precision("high")
    if hasattr(pipeline, "reset_device_map"):
        pipeline.reset_device_map()
    pipeline.to("cpu")

    _fuse_lightx2v_loras(pipeline)

    # The two transformers always exist on this pipeline; quantize them first.
    _safe_to_bf16(pipeline.transformer)
    _safe_to_bf16(pipeline.transformer_2)
    _quantize_module(pipeline.transformer, "transformer")
    _quantize_module(pipeline.transformer_2, "transformer_2")

    # Optional components: skip silently when the attribute chain is missing.
    # The try only guards the getattr walk, so real failures inside
    # quantization are no longer masked as "component absent".
    for component_name in ("text_encoder", "vae", "vae.decoder", "vae.encoder"):
        module = pipeline
        try:
            for attr in component_name.split("."):
                module = getattr(module, attr)
        except AttributeError:
            continue
        _safe_to_bf16(module)
        _quantize_module(module, component_name)

    try:
        text_encoder_2 = pipeline.text_encoder_2  # type: ignore[attr-defined]
    except AttributeError:
        pass
    else:
        _safe_to_bf16(text_encoder_2)
        _quantize_module(text_encoder_2, "text_encoder_2")

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Prefer sequential offload (lowest VRAM), fall back to model offload.
    if hasattr(pipeline, "enable_sequential_cpu_offload"):
        pipeline.enable_sequential_cpu_offload()
    elif hasattr(pipeline, "enable_model_cpu_offload"):
        pipeline.enable_model_cpu_offload()
    else:
        print("[WARN] Diffusers version lacks CPU offload helpers; keeping pipeline on CPU.")