"""Dequantize QLoRA checkpoint → fp32 Linear model (~970K params).""" import os, sys, torch, torch.nn as nn sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from scripts.model_tiny import TinyModel, QLoRALinear def dequantize_model(qlora_ckpt_path, output_path): model = TinyModel() model.reset_weights() # Apply QLoRA to create matching architecture from scripts.model_tiny import apply_qlora model = apply_qlora(model, r=8, alpha=16, freeze_embeds=False) # Load QLoRA checkpoint ckpt = torch.load(qlora_ckpt_path, map_location="cpu", weights_only=True) model.load_state_dict(ckpt, strict=False) # Fix _has_bias flag (not in state dict, but bias buffer was loaded) for name, module in model.named_modules(): if isinstance(module, QLoRALinear): bias_key = f"{name}.bias" if bias_key in ckpt: module._has_bias = True model.eval() # Dequantize: replace QLoRALinear → nn.Linear with dequantized weights for name, module in model.named_modules(): if isinstance(module, QLoRALinear): weight = module._dequantized_weight() + module.lora_B @ module.lora_A * module.scaling linear = nn.Linear(module.in_features, module.out_features, bias=False) linear.weight = nn.Parameter(weight) # Replace in parent parent = model parts = name.split(".") for p in parts[:-1]: parent = getattr(parent, p) setattr(parent, parts[-1], linear) # Save fp32 state dict torch.save(model.state_dict(), output_path) n = sum(p.numel() for p in model.parameters()) print(f"Dequantized: {n:,} params → {output_path}") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--input", default="best.pt") parser.add_argument("--output", default="best_fp32.pt") args = parser.parse_args() dequantize_model(args.input, args.output)