| | |
| | """ |
| | Load pre-quantized bitsandbytes NF4 models without needing base weights. |
| | |
| | This module provides utilities to load pre-quantized WanModel weights |
| | directly from safetensors/pt files, without downloading or loading |
| | the original FP16/BF16 base model weights. |
| | |
| | Usage: |
| | from load_prequant import load_quantized_model |
| | |
| | model = load_quantized_model("lingbot-world-base-cam/high_noise_model_bnb_nf4") |
| | """ |
| |
|
| | import json |
| | import os |
| | from pathlib import Path |
| | from typing import Optional, Dict, Any, Tuple |
| | from collections import defaultdict |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import bitsandbytes as bnb |
| | from bitsandbytes.functional import QuantState |
| |
|
| | |
| | import sys |
| | sys.path.insert(0, str(Path(__file__).parent)) |
| |
|
| | from wan.modules.model import WanModel |
| |
|
| |
|
def replace_linears_with_bnb_nf4(
    model: nn.Module,
    compute_dtype: torch.dtype = torch.bfloat16,
    compress_statistics: bool = True,
    quant_type: str = "nf4",
) -> Tuple[int, Dict[str, Tuple[int, int]]]:
    """
    Swap every nn.Linear found in ``model`` for a new bnb.nn.Linear4bit.

    Each replacement is constructed with the same in/out feature counts and
    the same bias setting as the layer it replaces. The new layers are only
    placeholders: the pre-quantized weights are expected to be assigned
    afterwards (see load_quantized_state).

    Args:
        model: Model whose linear layers are replaced in-place
        compute_dtype: Passed to Linear4bit as ``compute_dtype``
        compress_statistics: Passed to Linear4bit as ``compress_statistics``
        quant_type: Passed to Linear4bit as ``quant_type`` ('nf4' or 'fp4')

    Returns:
        Tuple of (number of layers replaced, dict mapping each replaced
        layer's qualified name to (in_features, out_features))
    """
    # Snapshot the targets first so the module tree is not modified while
    # named_modules() is still iterating over it.
    targets = [
        (qualified_name, layer)
        for qualified_name, layer in model.named_modules()
        if isinstance(layer, nn.Linear)
    ]

    layer_shapes = {}
    for qualified_name, layer in targets:
        # "a.b.c" -> owner path "a.b" and attribute "c". A name without a
        # dot gives an empty owner path, meaning the layer hangs off the root.
        owner_path, _, attr_name = qualified_name.rpartition('.')
        owner = model.get_submodule(owner_path) if owner_path else model

        layer_shapes[qualified_name] = (layer.in_features, layer.out_features)

        placeholder = bnb.nn.Linear4bit(
            layer.in_features,
            layer.out_features,
            bias=layer.bias is not None,
            compute_dtype=compute_dtype,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )
        setattr(owner, attr_name, placeholder)

    return len(targets), layer_shapes
| |
|
| |
|
def build_model_from_config(config: Dict[str, Any]) -> WanModel:
    """
    Instantiate a WanModel from a configuration dictionary.

    Every constructor argument is looked up in ``config`` under its own
    name; a missing key falls back to the default listed below.

    Args:
        config: Dictionary with model configuration

    Returns:
        The WanModel constructed from those settings
    """
    fallbacks = {
        "model_type": "i2v",
        "patch_size": (1, 2, 2),
        "text_len": 512,
        "in_dim": 16,
        "dim": 2048,
        "ffn_dim": 8192,
        "freq_dim": 256,
        "text_dim": 4096,
        "out_dim": 16,
        "num_heads": 16,
        "num_layers": 32,
        "window_size": (-1, -1),
        "qk_norm": True,
        "cross_attn_norm": True,
        "eps": 1e-6,
    }
    kwargs = {key: config.get(key, default) for key, default in fallbacks.items()}

    # These two are always handed to WanModel as tuples, whatever sequence
    # type the config supplied (presumably lists when read from JSON).
    for key in ("patch_size", "window_size"):
        kwargs[key] = tuple(kwargs[key])

    return WanModel(**kwargs)
| |
|
| |
|
def reconstruct_params4bit_from_components(
    weight_components: Dict[str, torch.Tensor],
    device: str = "cuda",
) -> bnb.nn.Params4bit:
    """
    Rebuild a Params4bit from the tensors serialized for one quantized weight.

    The quantization statistics are gathered into a dict and passed to
    QuantState.from_dict; the resulting state is attached to the quantized
    weight data, which is moved to ``device``.

    Args:
        weight_components: Dict holding 'weight', 'absmax' and 'quant_map',
            and optionally 'nested_absmax' with 'nested_quant_map', and
            'quant_state_data'
        device: Device for the quant state and the weight data

    Returns:
        Params4bit with requires_grad=False, marked as already quantized
    """
    # 'absmax' and 'quant_map' are mandatory: a missing key raises KeyError.
    state_fields = {
        name: weight_components[name] for name in ("absmax", "quant_map")
    }

    # The nested statistics are only forwarded when 'nested_absmax' is present.
    if "nested_absmax" in weight_components:
        for name in ("nested_absmax", "nested_quant_map"):
            state_fields[name] = weight_components[name]

    # The packed state tensor is forwarded under the key name it carries in
    # the checkpoint.
    packed_state = weight_components.get("quant_state_data")
    if "quant_state_data" in weight_components:
        state_fields["quant_state.bitsandbytes__nf4"] = packed_state

    restored_state = QuantState.from_dict(state_fields, device=torch.device(device))
    packed_weight = weight_components["weight"].to(device)

    return bnb.nn.Params4bit(
        data=packed_weight,
        requires_grad=False,
        quant_state=restored_state,
        bnb_quantized=True,
    )
| |
|
| |
|
def load_quantized_state(
    model: nn.Module,
    weights_path: str,
    layer_shapes: Dict[str, Tuple[int, int]],
    device: str = "cpu",
) -> nn.Module:
    """
    Load quantized weights into a model with bnb.Linear4bit layers.

    The checkpoint stores each quantized weight as ``<layer>.weight`` plus
    metadata tensors named ``<layer>.weight.<stat>``. Those entries are
    grouped per layer, rebuilt into a Params4bit with
    reconstruct_params4bit_from_components and assigned to the matching
    Linear4bit module. All remaining entries (biases and every
    non-quantized tensor) are loaded with ``load_state_dict(strict=False)``.

    Warnings are printed for non-quantized keys the model expects but the
    checkpoint lacks, for checkpoint keys the model does not use, and for
    Linear4bit layers that received no weight at all.

    Args:
        model: Model with Linear4bit layers already in place
        weights_path: Path to model.safetensors or model.pt
        layer_shapes: Dict mapping layer names to (in_features, out_features);
            only layers named here are reconstructed
        device: Device to load quantized weights to

    Returns:
        The same ``model`` instance, with weights loaded
    """
    if weights_path.endswith(".safetensors"):
        from safetensors.torch import load_file
        sd = load_file(weights_path)
    else:
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only use .pt checkpoints from a trusted source; prefer safetensors.
        sd = torch.load(weights_path, map_location="cpu", weights_only=False)

    # Metadata suffix (relative to "<layer>.weight") -> component name
    # understood by reconstruct_params4bit_from_components.
    stat_components = {
        ".absmax": "absmax",
        ".quant_map": "quant_map",
        ".nested_absmax": "nested_absmax",
        ".nested_quant_map": "nested_quant_map",
        ".quant_state.bitsandbytes__nf4": "quant_state_data",
    }

    weight_components = defaultdict(dict)
    other_keys = {}

    for key, tensor in sd.items():
        base_key = component = None

        # Match on the END of the key so a metadata name occurring in the
        # middle of an unrelated key can never be mistaken for metadata.
        for stat_suffix, stat_name in stat_components.items():
            full_suffix = ".weight" + stat_suffix
            if key.endswith(full_suffix):
                base_key, component = key[: -len(full_suffix)], stat_name
                break
        else:
            # A plain ".weight" entry is the packed 4-bit data only when at
            # least one metadata tensor accompanies it in the checkpoint.
            if key.endswith(".weight") and any(
                key + stat_suffix in sd for stat_suffix in stat_components
            ):
                base_key, component = key[: -len(".weight")], "weight"

        if component is None:
            other_keys[key] = tensor
        elif base_key:
            weight_components[base_key][component] = tensor

    loaded_count = 0
    unloaded_layers = []
    for name, module in model.named_modules():
        if not isinstance(module, bnb.nn.Linear4bit):
            continue

        components = weight_components.get(name)
        if components is not None and name in layer_shapes and "weight" in components:
            module.weight = reconstruct_params4bit_from_components(components, device=device)
            loaded_count += 1
        elif f"{name}.weight" not in other_keys:
            # Neither quantized data nor a plain weight exists for this
            # layer, so it would keep its constructor-initialised values.
            unloaded_layers.append(name)

    # Biases are ordinary entries of other_keys, so load_state_dict below
    # loads them together with all other non-quantized tensors.
    if other_keys:
        missing, unexpected = model.load_state_dict(other_keys, strict=False)
        # Quantized weights were assigned above, so they are expected to be
        # absent from this load.
        expected_missing = {f"{name}.weight" for name in layer_shapes}
        critical_missing = [
            k for k in missing
            if k not in expected_missing and not k.endswith("freqs")
        ]
        if critical_missing:
            print(f"Warning: Missing non-quantized keys: {critical_missing[:10]}...")
        if unexpected:
            print(f"Warning: Unexpected keys ignored: {list(unexpected)[:10]}...")

    if unloaded_layers:
        print(
            f"Warning: {len(unloaded_layers)} Linear4bit layers have no weights "
            f"in the checkpoint: {unloaded_layers[:10]}..."
        )

    print(f" Loaded {loaded_count} quantized linear layers")
    return model
| |
|
| |
|
def load_quantized_model(
    model_dir: str,
    device: str = "cuda",
    compute_dtype: torch.dtype = torch.bfloat16,
) -> WanModel:
    """
    Load a pre-quantized WanModel from a directory.

    This function:
    1. Reads config.json to get model architecture
    2. Builds an empty WanModel
    3. Replaces Linear layers with bnb.Linear4bit
    4. Loads the pre-quantized weights with proper reconstruction
    5. Moves model to device, switches to eval mode and disables gradients

    If quantization_meta.json exists and its ``quant.compute_dtype`` names a
    torch dtype (with or without a leading "torch."), that dtype replaces
    the ``compute_dtype`` argument. A missing or unrecognised entry leaves
    the argument untouched.

    Args:
        model_dir: Directory containing config.json and model.safetensors/model.pt
        device: Device to load model to
        compute_dtype: Compute dtype for quantized layers

    Returns:
        Loaded and ready WanModel

    Raises:
        FileNotFoundError: If config.json or both weight files are missing
    """
    model_dir = Path(model_dir)

    config_path = model_dir / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"Config not found: {config_path}")

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    meta_path = model_dir / "quantization_meta.json"
    if meta_path.exists():
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        dtype_name = (meta.get("quant") or {}).get("compute_dtype")
        if dtype_name is not None:
            # Accept "bfloat16" as well as "torch.bfloat16", and make sure
            # the attribute found on torch really is a dtype.
            resolved = getattr(torch, str(dtype_name).split(".")[-1], None)
            if isinstance(resolved, torch.dtype):
                compute_dtype = resolved
            else:
                print(
                    f"Warning: Unknown compute_dtype {dtype_name!r} in {meta_path}; "
                    f"using {compute_dtype}"
                )

    # model.safetensors takes precedence over model.pt when both exist.
    candidates = (model_dir / "model.safetensors", model_dir / "model.pt")
    weights_path = next((str(path) for path in candidates if path.exists()), None)
    if weights_path is None:
        raise FileNotFoundError(
            f"No weights found in {model_dir}. "
            "Expected model.safetensors or model.pt"
        )

    print(f"Loading pre-quantized model from {model_dir}")
    print(f" Config: {config_path}")
    print(f" Weights: {weights_path}")

    model = build_model_from_config(config)

    replaced, layer_shapes = replace_linears_with_bnb_nf4(model, compute_dtype=compute_dtype)
    print(f" Replaced {replaced} linear layers with bnb.Linear4bit")

    # The Linear4bit placeholders must exist before the quantized weights
    # can be assigned to them.
    model = load_quantized_state(model, weights_path, layer_shapes, device=device)

    model.to(device)
    model.eval()
    model.requires_grad_(False)

    print(f" Model ready on {device}")

    return model
| |
|
| |
|
def verify_quantized_model(model: nn.Module) -> Dict[str, Any]:
    """
    Summarise how much of a model consists of quantized linear layers.

    NOTE(review): ``numel()`` is taken from the stored weight tensor; for
    packed 4-bit storage this presumably differs from the logical number of
    weights - confirm before treating the counts as parameter counts.

    Args:
        model: Model to verify

    Returns:
        Dictionary with the keys 'total_params', 'quantized_params',
        'linear4bit_count', 'regular_linear_count' and 'is_quantized'.
        'is_quantized' is True when at least one Linear4bit layer exists
        and no other nn.Linear layer remains.
    """
    linear4bit_count = regular_linear_count = quantized_params = 0

    for layer in model.modules():
        # Linear4bit is tested before nn.Linear on purpose (presumably
        # Linear4bit derives from nn.Linear; the order must be kept).
        if isinstance(layer, bnb.nn.Linear4bit):
            linear4bit_count += 1
            # Only weights that carry a quant_state count as quantized.
            if getattr(layer.weight, 'quant_state', None) is not None:
                quantized_params += layer.weight.numel()
        elif isinstance(layer, nn.Linear):
            regular_linear_count += 1

    total_params = sum(param.numel() for param in model.parameters())
    fully_quantized = linear4bit_count > 0 and regular_linear_count == 0

    return {
        "total_params": total_params,
        "quantized_params": quantized_params,
        "linear4bit_count": linear4bit_count,
        "regular_linear_count": regular_linear_count,
        "is_quantized": fully_quantized,
    }
| |
|
| |
|
if __name__ == "__main__":
    import argparse

    # Command-line smoke test: load a quantized model directory and print a
    # short summary of its linear layers.
    cli = argparse.ArgumentParser(description="Test loading pre-quantized model")
    cli.add_argument("model_dir", type=str, help="Path to quantized model directory")
    cli.add_argument("--device", type=str, default="cuda", help="Device to load to")
    cli_args = cli.parse_args()

    model = load_quantized_model(cli_args.model_dir, device=cli_args.device)
    report = verify_quantized_model(model)

    print(f"\nVerification:")
    summary_rows = (
        ("Linear4bit layers", "linear4bit_count"),
        ("Regular Linear layers", "regular_linear_count"),
        ("Is properly quantized", "is_quantized"),
    )
    for label, field in summary_rows:
        print(f" {label}: {report[field]}")
| |
|