| """Dequantize QLoRA checkpoint → fp32 Linear model (~970K params).""" |
| import os, sys, torch, torch.nn as nn |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) |
| from scripts.model_tiny import TinyModel, QLoRALinear |
|
|
def _replace_submodule(root, dotted_name, new_module):
    """Swap the submodule found at ``dotted_name`` inside ``root`` for ``new_module``."""
    parent_name, _, child_name = dotted_name.rpartition(".")
    # get_submodule("") returns root itself, so top-level children work too.
    setattr(root.get_submodule(parent_name), child_name, new_module)


def _merge_to_linear(module, bias=None):
    """Fold a QLoRALinear's dequantized base weight plus its LoRA delta into an fp32 nn.Linear.

    Args:
        module: A QLoRALinear exposing ``_dequantized_weight()``, ``lora_A``,
            ``lora_B``, ``scaling``, ``in_features`` and ``out_features``.
        bias: Optional bias tensor to carry over; ``None`` yields a bias-free layer.

    Returns:
        A new ``nn.Linear`` whose weight is ``W_dequant + (lora_B @ lora_A) * scaling``.

    Raises:
        ValueError: If the merged weight is not ``(out_features, in_features)``.
    """
    with torch.no_grad():  # pure tensor arithmetic; no autograd graph needed
        delta = (module.lora_B @ module.lora_A) * module.scaling
        weight = (module._dequantized_weight() + delta).to(torch.float32)
        expected = (module.out_features, module.in_features)
        if tuple(weight.shape) != expected:
            raise ValueError(
                f"merged weight has shape {tuple(weight.shape)}, expected {expected}"
            )
        linear = nn.Linear(module.in_features, module.out_features, bias=bias is not None)
        linear.weight = nn.Parameter(weight.contiguous())
        if bias is not None:
            # Copy so the new layer does not alias the checkpoint tensor.
            linear.bias = nn.Parameter(bias.detach().to(torch.float32).clone())
    return linear


def dequantize_model(qlora_ckpt_path, output_path):
    """Convert a QLoRA checkpoint into a plain fp32 ``nn.Linear`` model and save it.

    Steps:
        1. Build a ``TinyModel`` and wrap it with ``apply_qlora``.
        2. Load the QLoRA checkpoint into it (non-strict).
        3. Replace every ``QLoRALinear`` with an ``nn.Linear``. Each weight is the
           dequantized base weight plus the scaled LoRA update. Any ``<name>.bias``
           entry in the checkpoint is carried over as the layer's bias.
        4. Save the resulting ``state_dict`` to ``output_path``.

    Args:
        qlora_ckpt_path: Path to a state_dict saved from the QLoRA-wrapped model.
        output_path: Destination file for the merged fp32 state_dict.

    Returns:
        The converted model (callers that ignore the result are unaffected).

    Raises:
        ValueError: If a merged weight does not match its layer's shape.
    """
    from scripts.model_tiny import apply_qlora

    model = TinyModel()
    model.reset_weights()
    # NOTE(review): r/alpha presumably must match the values used in training — TODO confirm.
    model = apply_qlora(model, r=8, alpha=16, freeze_embeds=False)

    ckpt = torch.load(qlora_ckpt_path, map_location="cpu", weights_only=True)
    load_result = model.load_state_dict(ckpt, strict=False)
    missing = load_result.missing_keys
    if missing:
        # strict=False would otherwise hide a checkpoint that matches nothing.
        preview = ", ".join(missing[:5]) + (" ..." if len(missing) > 5 else "")
        print(
            f"Warning: {len(missing)} model keys not found in {qlora_ckpt_path}: {preview}",
            file=sys.stderr,
        )

    # Snapshot targets up front: the module tree is rewired below, and mutating it
    # while a named_modules() generator is live is fragile.
    qlora_layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, QLoRALinear)
    ]

    for name, module in qlora_layers:
        if f"{name}.bias" in ckpt:
            module._has_bias = True

    model.eval()

    for name, module in qlora_layers:
        # Previously the bias was detected but then discarded (bias=False);
        # keep it so the fp32 model matches the trained one.
        bias = ckpt.get(f"{name}.bias")
        _replace_submodule(model, name, _merge_to_linear(module, bias))

    torch.save(model.state_dict(), output_path)
    n = sum(p.numel() for p in model.parameters())
    print(f"Dequantized: {n:,} params → {output_path}")
    return model
|
|
if __name__ == "__main__":
    # Command-line entry point: merge a QLoRA checkpoint into an fp32 state_dict.
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--input",
        default="best.pt",
        help="QLoRA checkpoint (state_dict) to read (default: %(default)s)",
    )
    parser.add_argument(
        "--output",
        default="best_fp32.pt",
        help="where to write the merged fp32 state_dict (default: %(default)s)",
    )
    args = parser.parse_args()
    dequantize_model(args.input, args.output)
|
|