import argparse
from typing import Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args():
    p = argparse.ArgumentParser(description="Compare base vs. fine-tuned LoRA outputs.")
    p.add_argument("--base_model", required=True, help="Base model id or local path.")
    p.add_argument("--lora_dir", required=True, help="Path to LoRA adapter folder (e.g., outputs/tinyllama-lora).")
    p.add_argument("--prompt", required=True, help="Prompt to generate with.")
    p.add_argument("--max_new_tokens", type=int, default=128)
    p.add_argument("--temperature", type=float, default=0.7)
    p.add_argument("--top_p", type=float, default=0.9)
    p.add_argument(
        "--torch_dtype",
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="Force dtype for model load.",
    )
    p.add_argument(
        "--device",
        default="auto",
        choices=["auto", "cpu", "cuda", "mps"],
        help="Force device map; on CPU use this to keep everything on host.",
    )
    p.add_argument(
        "--offload_dir",
        default=None,
        help="Optional offload directory when using device_map='auto' on constrained devices.",
    )
    return p.parse_args()


def resolve_dtype(name: str) -> Optional[torch.dtype]:
    if name == "auto":
        return None  # Let transformers pick the checkpoint's native dtype.
    return {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[name]


def resolve_device_map(device: str):
    # None keeps the model on the host; "auto" lets Accelerate place (and offload) shards.
    return None if device == "cpu" else "auto"


def generate(model, tokenizer, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,  # avoid the missing-pad-token warning
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)


def main():
    args = parse_args()
    force_cpu = args.device == "cpu"
    torch_dtype = torch.float32 if force_cpu else resolve_dtype(args.torch_dtype)
    device_map = resolve_device_map(args.device)

    # Assumes the tokenizer was saved alongside the adapter during fine-tuning.
    tokenizer = AutoTokenizer.from_pretrained(args.lora_dir, use_fast=False)

    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map=device_map,
        torch_dtype=torch_dtype,
        offload_folder=args.offload_dir,
    )
    # Load a second copy of the base weights, then attach the LoRA adapter to it,
    # so the two models can be compared side by side on the same prompt.
    lora_wrapped = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map=device_map,
        torch_dtype=torch_dtype,
        offload_folder=args.offload_dir,
    )
    lora_wrapped = PeftModel.from_pretrained(lora_wrapped, args.lora_dir)

    if force_cpu:
        # Avoid Accelerate dispatch/offload; keep everything on CPU.
        base_model.to("cpu")
        lora_wrapped.to("cpu")

    base_out = generate(
        base_model,
        tokenizer,
        args.prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )
    lora_out = generate(
        lora_wrapped,
        tokenizer,
        args.prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )

    print("=== Base model ===")
    print(base_out)
    print("\n=== LoRA model ===")
    print(lora_out)


if __name__ == "__main__":
    main()
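
# A minimal usage sketch. The script filename and the base model id below are
# assumptions (the adapter path mirrors the --lora_dir help text above); adjust
# them to your checkout:
#
#   python compare_lora.py \
#       --base_model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
#       --lora_dir outputs/tinyllama-lora \
#       --prompt "Summarize LoRA in one sentence." \
#       --device cpu
#
# --device cpu forces float32 and a None device_map, which keeps both models
# entirely on the host and sidesteps Accelerate dispatch on machines without
# a GPU.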