dee-tulu-train / evaluation /compare_lora.py
Javad Taghia
cpu ok for compare
dbb959c
import argparse
from typing import Optional
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def parse_args():
    """Build the command-line parser and return the parsed namespace.

    Required flags: ``--base_model``, ``--lora_dir`` and ``--prompt``.
    The remaining flags tune sampling (``--max_new_tokens``,
    ``--temperature``, ``--top_p``) and model loading/placement
    (``--torch_dtype``, ``--device``, ``--offload_dir``).
    """
    parser = argparse.ArgumentParser(description="Compare base vs. fine-tuned LoRA outputs.")
    add = parser.add_argument

    # Required inputs.
    add("--base_model", required=True, help="Base model id or local path.")
    add("--lora_dir", required=True, help="Path to LoRA adapter folder (e.g., outputs/tinyllama-lora).")
    add("--prompt", required=True, help="Prompt to generate with.")

    # Sampling knobs.
    add("--max_new_tokens", type=int, default=128)
    add("--temperature", type=float, default=0.7)
    add("--top_p", type=float, default=0.9)

    # Loading / placement knobs.
    add(
        "--torch_dtype",
        default="auto",
        choices=("auto", "float16", "bfloat16", "float32"),
        help="Force dtype for model load.",
    )
    add(
        "--device",
        default="auto",
        choices=("auto", "cpu", "cuda", "mps"),
        help="Force device map; on CPU use this to keep everything on host.",
    )
    add(
        "--offload_dir",
        default=None,
        help="Optional offload directory when using device_map='auto' on constrained devices.",
    )

    return parser.parse_args()
def resolve_dtype(name: str) -> Optional[torch.dtype]:
    """Translate a ``--torch_dtype`` choice into a ``torch.dtype``.

    ``"auto"`` maps to ``None``, leaving dtype selection to
    ``from_pretrained``'s own default. Any name other than ``"auto"``,
    ``"float16"``, ``"bfloat16"`` or ``"float32"`` raises ``KeyError``.
    """
    if name != "auto":
        table = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        return table[name]
    return None
def resolve_device_map(device: str):
    """Translate the ``--device`` choice into a ``device_map`` for ``from_pretrained``.

    - ``"cpu"`` -> ``None``: no Accelerate dispatch, so the model loads on host.
    - ``"cuda"`` / ``"mps"`` -> that device name. transformers treats a plain
      device string as "put the whole model on this device", which actually
      honours the user's choice. Previously both mapped to ``"auto"``, which
      may spread or offload weights elsewhere.
    - anything else (i.e. ``"auto"``) -> ``"auto"``: let Accelerate dispatch
      weights across the available devices.
    """
    if device == "cpu":
        return None
    if device in ("cuda", "mps"):
        return device
    return "auto"
def generate(model, tokenizer, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Generate a completion for ``prompt`` and return it as decoded text.

    The result is the first sequence returned by ``model.generate``, decoded
    with special tokens stripped. For decoder-only models that sequence
    includes the prompt tokens followed by the continuation.

    Args:
        model: Causal LM (optionally PEFT-wrapped) exposing ``.device`` and
            ``.generate``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Input text.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding, because sampling requires a strictly positive
            temperature.
        top_p: Nucleus-sampling threshold; only used when sampling.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    gen_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # transformers raises on temperature <= 0 when do_sample=True;
        # greedy decoding is the natural meaning of "temperature 0".
        gen_kwargs["do_sample"] = False

    # Tokenizers without a pad token: pad with EOS explicitly rather than
    # relying on generate()'s implicit (warning-emitting) fallback.
    pad_id = getattr(tokenizer, "pad_token_id", None)
    eos_id = getattr(tokenizer, "eos_token_id", None)
    if pad_id is None and eos_id is not None:
        gen_kwargs["pad_token_id"] = eos_id

    with torch.inference_mode():
        output = model.generate(**inputs, **gen_kwargs)
    return tokenizer.decode(output[0], skip_special_tokens=True)
def _load_tokenizer(lora_dir: str, base_model: str):
    """Load the tokenizer from the adapter folder, falling back to the base model's."""
    import sys

    try:
        return AutoTokenizer.from_pretrained(lora_dir, use_fast=False)
    except (OSError, ValueError) as err:
        # Adapter folders often contain only adapter weights and config.
        print(
            f"[compare_lora] could not load tokenizer from {lora_dir!r} ({err}); "
            f"falling back to {base_model!r}.",
            file=sys.stderr,
        )
        return AutoTokenizer.from_pretrained(base_model, use_fast=False)


def main():
    """Generate from the base model and the LoRA-adapted model, then print both outputs.

    The base weights are loaded only once. The base output is produced with
    the adapter temporarily disabled via ``PeftModel.disable_adapter()``,
    which halves memory compared with loading two copies. Both generations
    start from the same RNG seed, so sampling differences reflect the
    adapter rather than random noise.
    """
    args = parse_args()
    force_cpu = args.device == "cpu"
    # The CPU path always loads float32 and ignores --torch_dtype (presumably
    # because half precision is slow or unsupported for many CPU kernels).
    torch_dtype = torch.float32 if force_cpu else resolve_dtype(args.torch_dtype)
    device_map = resolve_device_map(args.device)

    tokenizer = _load_tokenizer(args.lora_dir, args.base_model)

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map=device_map,
        torch_dtype=torch_dtype,
        offload_folder=args.offload_dir,
    )
    # PeftModel injects LoRA layers into `model` in place. The base output
    # therefore has to come from disable_adapter(), not from a separate
    # reference to the unwrapped model.
    model = PeftModel.from_pretrained(model, args.lora_dir)
    if force_cpu:
        # Avoid Accelerate dispatch/offload; keep everything on CPU.
        model.to("cpu")

    sample_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )

    # Draw one random seed and reuse it so both runs sample from the same RNG state.
    seed = torch.seed()

    with model.disable_adapter():
        torch.manual_seed(seed)
        base_out = generate(model, tokenizer, args.prompt, **sample_kwargs)

    torch.manual_seed(seed)
    lora_out = generate(model, tokenizer, args.prompt, **sample_kwargs)

    print("=== Base model ===")
    print(base_out)
    print("\n=== LoRA model ===")
    print(lora_out)


if __name__ == "__main__":
    main()