lingbot-world-base-cam-nf4 / load_prequant.py
cahlen
Add pre-quantized NF4 weights and complete inference package
8c3bc34
#!/usr/bin/env python3
"""
Load pre-quantized bitsandbytes NF4 models without needing base weights.
This module provides utilities to load pre-quantized WanModel weights
directly from safetensors/pt files, without downloading or loading
the original FP16/BF16 base model weights.
Usage:
from load_prequant import load_quantized_model
model = load_quantized_model("lingbot-world-base-cam/high_noise_model_bnb_nf4")
"""
import json
import os
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
from collections import defaultdict
import torch
import torch.nn as nn
import bitsandbytes as bnb
from bitsandbytes.functional import QuantState
# Add parent to path for wan imports
import sys
sys.path.insert(0, str(Path(__file__).parent))
from wan.modules.model import WanModel
def replace_linears_with_bnb_nf4(
model: nn.Module,
compute_dtype: torch.dtype = torch.bfloat16,
compress_statistics: bool = True,
quant_type: str = "nf4",
) -> Tuple[int, Dict[str, Tuple[int, int]]]:
"""
Replace all nn.Linear layers with empty bnb.nn.Linear4bit layers.
This creates the structure needed to load pre-quantized weights.
The layers are created without weights - they will be populated
by load_state_dict afterwards.
Args:
model: The model to modify in-place
compute_dtype: Compute dtype for the quantized layers
compress_statistics: Whether to use double quantization
quant_type: Quantization type ('nf4' or 'fp4')
Returns:
Tuple of (num_replaced, dict mapping layer_name to (in_features, out_features))
"""
replaced = 0
layer_shapes = {}
# Collect all linear layers first to avoid modifying during iteration
linear_layers = []
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
linear_layers.append((name, module))
for name, module in linear_layers:
# Get parent module
parent_name = '.'.join(name.split('.')[:-1])
child_name = name.split('.')[-1]
if parent_name:
parent = model.get_submodule(parent_name)
else:
parent = model
# Store original shape for reconstruction
layer_shapes[name] = (module.in_features, module.out_features)
# Create empty NF4 linear layer with same shape
nf4_linear = bnb.nn.Linear4bit(
module.in_features,
module.out_features,
bias=module.bias is not None,
compute_dtype=compute_dtype,
compress_statistics=compress_statistics,
quant_type=quant_type,
)
# Replace the layer (weights will be loaded from state_dict)
setattr(parent, child_name, nf4_linear)
replaced += 1
return replaced, layer_shapes
def build_model_from_config(config: Dict[str, Any]) -> WanModel:
"""
Build a WanModel instance from config dictionary.
Args:
config: Dictionary with model configuration
Returns:
Uninitialized WanModel instance
"""
# Extract config values with defaults matching WanModel.__init__
model = WanModel(
model_type=config.get("model_type", "i2v"),
patch_size=tuple(config.get("patch_size", (1, 2, 2))),
text_len=config.get("text_len", 512),
in_dim=config.get("in_dim", 16),
dim=config.get("dim", 2048),
ffn_dim=config.get("ffn_dim", 8192),
freq_dim=config.get("freq_dim", 256),
text_dim=config.get("text_dim", 4096),
out_dim=config.get("out_dim", 16),
num_heads=config.get("num_heads", 16),
num_layers=config.get("num_layers", 32),
window_size=tuple(config.get("window_size", (-1, -1))),
qk_norm=config.get("qk_norm", True),
cross_attn_norm=config.get("cross_attn_norm", True),
eps=config.get("eps", 1e-6),
)
return model
def reconstruct_params4bit_from_components(
weight_components: Dict[str, torch.Tensor],
device: str = "cuda",
) -> bnb.nn.Params4bit:
"""
Reconstruct a Params4bit object from serialized components using QuantState.from_dict.
This uses bitsandbytes' own deserialization method for correctness.
Args:
weight_components: Dict with keys like 'weight', 'absmax', 'quant_map',
'nested_absmax', 'nested_quant_map', 'quant_state_data'
device: Device to load to
Returns:
Reconstructed Params4bit
"""
# Build the dict that QuantState.from_dict expects
qs_dict = {
"absmax": weight_components["absmax"],
"quant_map": weight_components["quant_map"],
}
# Add nested quantization components if present (double quantization)
if "nested_absmax" in weight_components:
qs_dict["nested_absmax"] = weight_components["nested_absmax"]
qs_dict["nested_quant_map"] = weight_components["nested_quant_map"]
# Add the packed quant_state data (contains shape, dtype, etc.)
if "quant_state_data" in weight_components:
qs_dict["quant_state.bitsandbytes__nf4"] = weight_components["quant_state_data"]
# Use bitsandbytes' own deserialization
quant_state = QuantState.from_dict(qs_dict, device=torch.device(device))
# Get quantized weight and move to device
quantized_weight = weight_components["weight"].to(device)
# Create Params4bit with the quantized data
param = bnb.nn.Params4bit(
data=quantized_weight,
requires_grad=False,
quant_state=quant_state,
bnb_quantized=True, # Already quantized, don't re-quantize on .to()
)
return param
def load_quantized_state(
model: nn.Module,
weights_path: str,
layer_shapes: Dict[str, Tuple[int, int]],
device: str = "cpu",
) -> nn.Module:
"""
Load quantized weights into a model with bnb.Linear4bit layers.
This function handles the special bitsandbytes serialization format
where weights are decomposed into quantized data + metadata tensors.
Uses QuantState.from_dict for proper deserialization.
Args:
model: Model with Linear4bit layers already in place
weights_path: Path to model.safetensors or model.pt
layer_shapes: Dict mapping layer names to (in_features, out_features)
device: Device to load quantized weights to
Returns:
Model with loaded weights
"""
if weights_path.endswith(".safetensors"):
from safetensors.torch import load_file
sd = load_file(weights_path)
else:
sd = torch.load(weights_path, map_location="cpu", weights_only=False)
# Group keys by their base weight name
weight_components = defaultdict(dict)
other_keys = {}
# Quantization-related suffixes
quant_suffixes = [".absmax", ".quant_map", ".nested_absmax", ".nested_quant_map", ".quant_state.bitsandbytes__nf4"]
for key, tensor in sd.items():
base_key = None
component = None
if ".weight.absmax" in key:
base_key = key.replace(".weight.absmax", "")
component = "absmax"
elif ".weight.quant_map" in key:
base_key = key.replace(".weight.quant_map", "")
component = "quant_map"
elif ".weight.nested_absmax" in key:
base_key = key.replace(".weight.nested_absmax", "")
component = "nested_absmax"
elif ".weight.nested_quant_map" in key:
base_key = key.replace(".weight.nested_quant_map", "")
component = "nested_quant_map"
elif ".weight.quant_state.bitsandbytes__nf4" in key:
base_key = key.replace(".weight.quant_state.bitsandbytes__nf4", "")
component = "quant_state_data"
elif key.endswith(".weight"):
# Check if this is a quantized linear weight or regular weight
potential_base = key[:-7] # Remove ".weight"
has_quant_metadata = any(f"{potential_base}.weight{suffix}" in sd for suffix in quant_suffixes)
if has_quant_metadata:
base_key = potential_base
component = "weight"
else:
other_keys[key] = tensor
continue
else:
other_keys[key] = tensor
continue
if base_key and component:
weight_components[base_key][component] = tensor
# Load quantized weights into model
loaded_count = 0
for name, module in model.named_modules():
if isinstance(module, bnb.nn.Linear4bit):
if name in weight_components and name in layer_shapes:
components = weight_components[name]
if "weight" in components:
# Use bitsandbytes' own deserialization via QuantState.from_dict
param = reconstruct_params4bit_from_components(components, device=device)
module.weight = param
loaded_count += 1
# Load bias if present
bias_key = f"{name}.bias"
if bias_key in other_keys and module.bias is not None:
module.bias.data.copy_(other_keys[bias_key].to(device))
# Load non-quantized weights (embeddings, norms, biases, etc.)
non_linear_sd = {}
for key, tensor in other_keys.items():
non_linear_sd[key] = tensor
if non_linear_sd:
missing, unexpected = model.load_state_dict(non_linear_sd, strict=False)
expected_missing = {f"{name}.weight" for name in layer_shapes.keys()}
critical_missing = [k for k in missing if k not in expected_missing and not k.endswith("freqs")]
if critical_missing:
print(f"Warning: Missing non-quantized keys: {critical_missing[:10]}...")
print(f" Loaded {loaded_count} quantized linear layers")
return model
def load_quantized_model(
model_dir: str,
device: str = "cuda",
compute_dtype: torch.dtype = torch.bfloat16,
) -> WanModel:
"""
Load a pre-quantized WanModel from a directory.
This function:
1. Reads config.json to get model architecture
2. Builds an empty WanModel
3. Replaces Linear layers with bnb.Linear4bit
4. Loads the pre-quantized weights with proper reconstruction
5. Moves model to device
Args:
model_dir: Directory containing config.json and model.safetensors/model.pt
device: Device to load model to
compute_dtype: Compute dtype for quantized layers
Returns:
Loaded and ready WanModel
"""
model_dir = Path(model_dir)
# Load config
config_path = model_dir / "config.json"
if not config_path.exists():
raise FileNotFoundError(f"Config not found: {config_path}")
with open(config_path, "r") as f:
config = json.load(f)
# Check for quantization metadata (optional but recommended)
meta_path = model_dir / "quantization_meta.json"
if meta_path.exists():
with open(meta_path, "r") as f:
meta = json.load(f)
quant_config = meta.get("quant", {})
compute_dtype_str = quant_config.get("compute_dtype", "bfloat16")
compute_dtype = getattr(torch, compute_dtype_str, torch.bfloat16)
# Find weights file (prefer safetensors)
safetensors_path = model_dir / "model.safetensors"
pt_path = model_dir / "model.pt"
if safetensors_path.exists():
weights_path = str(safetensors_path)
elif pt_path.exists():
weights_path = str(pt_path)
else:
raise FileNotFoundError(
f"No weights found in {model_dir}. "
"Expected model.safetensors or model.pt"
)
print(f"Loading pre-quantized model from {model_dir}")
print(f" Config: {config_path}")
print(f" Weights: {weights_path}")
# Build model from config (creates initialized weights we'll replace)
model = build_model_from_config(config)
# Replace Linear → Linear4bit (empty, ready for state_dict)
replaced, layer_shapes = replace_linears_with_bnb_nf4(model, compute_dtype=compute_dtype)
print(f" Replaced {replaced} linear layers with bnb.Linear4bit")
# Load quantized weights with proper reconstruction
# This loads quantized weights directly to the target device
model = load_quantized_state(model, weights_path, layer_shapes, device=device)
# Move non-quantized parts to device and set eval mode
model.to(device)
model.eval()
model.requires_grad_(False)
print(f" Model ready on {device}")
return model
def verify_quantized_model(model: nn.Module) -> Dict[str, Any]:
"""
Verify that a model has been properly quantized.
Args:
model: Model to verify
Returns:
Dictionary with verification results
"""
total_params = 0
quantized_params = 0
linear4bit_count = 0
regular_linear_count = 0
for name, module in model.named_modules():
if isinstance(module, bnb.nn.Linear4bit):
linear4bit_count += 1
if hasattr(module.weight, 'quant_state') and module.weight.quant_state is not None:
quantized_params += module.weight.numel()
elif isinstance(module, nn.Linear):
regular_linear_count += 1
for param in model.parameters():
total_params += param.numel()
return {
"total_params": total_params,
"quantized_params": quantized_params,
"linear4bit_count": linear4bit_count,
"regular_linear_count": regular_linear_count,
"is_quantized": linear4bit_count > 0 and regular_linear_count == 0,
}
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Test loading pre-quantized model")
parser.add_argument("model_dir", type=str, help="Path to quantized model directory")
parser.add_argument("--device", type=str, default="cuda", help="Device to load to")
args = parser.parse_args()
model = load_quantized_model(args.model_dir, device=args.device)
info = verify_quantized_model(model)
print(f"\nVerification:")
print(f" Linear4bit layers: {info['linear4bit_count']}")
print(f" Regular Linear layers: {info['regular_linear_count']}")
print(f" Is properly quantized: {info['is_quantized']}")