lumia-tiny / dequantize_qlora.py
samcheng0's picture
Upload dequantize_qlora.py with huggingface_hub
d94fc9d verified
Raw
History Blame Contribute Delete
2.03 kB
"""Dequantize QLoRA checkpoint → fp32 Linear model (~970K params)."""
import os, sys, torch, torch.nn as nn
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from scripts.model_tiny import TinyModel, QLoRALinear
def dequantize_model(qlora_ckpt_path, output_path):
    """Merge a QLoRA checkpoint into plain ``nn.Linear`` layers and save it.

    A ``TinyModel`` is built and wrapped with ``apply_qlora`` (r=8, alpha=16)
    so its structure matches the checkpoint.  Every ``QLoRALinear`` layer is
    then replaced by an ``nn.Linear`` whose weight is
    ``_dequantized_weight() + (lora_B @ lora_A) * scaling``, and the resulting
    state dict is written to ``output_path``.

    Args:
        qlora_ckpt_path: Path of the QLoRA state dict to load (read on CPU
            with ``weights_only=True``).
        output_path: Destination path for the merged state dict.

    Returns:
        None.  The merged state dict is saved to disk and a one-line summary
        with the parameter count is printed.
    """
    # Function-scope import kept from the original script.
    from scripts.model_tiny import apply_qlora

    model = TinyModel()
    model.reset_weights()
    # Apply QLoRA so the module tree matches the checkpoint being loaded.
    model = apply_qlora(model, r=8, alpha=16, freeze_embeds=False)

    ckpt = torch.load(qlora_ckpt_path, map_location="cpu", weights_only=True)
    # NOTE(review): strict=False silently ignores missing/unexpected keys;
    # a checkpoint from a different architecture will not raise here.
    model.load_state_dict(ckpt, strict=False)

    # Snapshot the QLoRA layers up front: the loop further down replaces
    # submodules, and doing that while the named_modules() generator is still
    # walking the tree is fragile.
    qlora_layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, QLoRALinear)
    ]

    # `_has_bias` is a plain attribute (not part of the state dict), so it
    # has to be restored from the presence of the bias key in the checkpoint.
    for name, module in qlora_layers:
        if f"{name}.bias" in ckpt:
            module._has_bias = True
    model.eval()

    # Pure tensor arithmetic for export; no autograd graph is needed.
    with torch.no_grad():
        for name, module in qlora_layers:
            # `@` and `*` share precedence and bind left to right, so this is
            # base + (lora_B @ lora_A) * scaling.
            weight = (
                module._dequantized_weight()
                + module.lora_B @ module.lora_A * module.scaling
            )

            # Carry the bias over instead of dropping it.
            # NOTE(review): assumes `module.bias` is the layer's additive
            # bias with out_features elements — confirm against QLoRALinear.
            bias = None
            if getattr(module, "_has_bias", False):
                candidate = getattr(module, "bias", None)
                if isinstance(candidate, torch.Tensor):
                    bias = candidate

            linear = nn.Linear(
                module.in_features, module.out_features, bias=bias is not None
            )
            # The exported model is documented as fp32; the cast is a no-op
            # when the tensors are already float32.
            linear.weight = nn.Parameter(weight.float())
            if bias is not None:
                linear.bias = nn.Parameter(bias.detach().clone().float())

            # Walk the dotted module path to the parent and swap the child.
            parent = model
            *parent_parts, child_name = name.split(".")
            for part in parent_parts:
                parent = getattr(parent, part)
            setattr(parent, child_name, linear)

    torch.save(model.state_dict(), output_path)
    n = sum(p.numel() for p in model.parameters())
    print(f"Dequantized: {n:,} params → {output_path}")
if __name__ == "__main__":
    import argparse

    # Command-line entry point: both paths are optional string flags that
    # fall back to the defaults listed below.
    cli = argparse.ArgumentParser()
    for flag, fallback in (("--input", "best.pt"), ("--output", "best_fp32.pt")):
        cli.add_argument(flag, default=fallback)
    opts = cli.parse_args()
    dequantize_model(opts.input, opts.output)