"""Run a quick LoRA inference: load a base model, attach an adapter, generate once."""

import argparse
from typing import Optional

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def resolve_dtype(name: str, device: str) -> Optional[torch.dtype]:
    """Map a dtype CLI name to a concrete torch.dtype.

    For "auto" on CPU we force float32 (half precision is slow/unsupported
    on CPU); for "auto" elsewhere we return None so transformers picks.
    argparse `choices` guarantees `name` is one of the four known values,
    so the dict lookup cannot raise KeyError in practice.
    """
    if name == "auto":
        return torch.float32 if device == "cpu" else None
    return {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }[name]


def resolve_device_map(device: str) -> Optional[str]:
    """Return the `device_map` argument for `from_pretrained`.

    CPU gets None (plain single-device load); any other choice — including
    "auto" — delegates placement to accelerate via device_map="auto".
    """
    return None if device == "cpu" else "auto"


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the inference run."""
    p = argparse.ArgumentParser(description="Run a quick LoRA inference.")
    p.add_argument("--lora_dir", default="outputs/tinyllama-lora",
                   help="Path to LoRA adapter folder.")
    p.add_argument("--prompt", default="### Instruction:\nExplain LoRA in one sentence.\n\n### Input:\nN/A\n\n### Response:\n")
    p.add_argument("--max_new_tokens", type=int, default=128)
    p.add_argument("--temperature", type=float, default=0.7)
    p.add_argument("--top_p", type=float, default=0.9)
    p.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda", "mps"])
    p.add_argument("--torch_dtype", default="auto", choices=["auto", "float16", "bfloat16", "float32"])
    p.add_argument("--offload_dir", default=None,
                   help="Optional offload directory when using device_map='auto'.")
    return p.parse_args()


def main():
    """Load the base model + LoRA adapter, run one generation, print the text."""
    args = parse_args()

    cfg = PeftConfig.from_pretrained(args.lora_dir)
    base_model = cfg.base_model_name_or_path  # base model id/path recorded in the adapter

    torch_dtype = resolve_dtype(args.torch_dtype, args.device)
    # resolve_device_map already yields "auto" for every non-CPU device,
    # including the "auto" choice itself, so no extra special-casing is needed.
    device_map = resolve_device_map(args.device)

    tokenizer = AutoTokenizer.from_pretrained(args.lora_dir, use_fast=False)
    # Many causal-LM tokenizers (e.g. LLaMA-style) ship without a pad token;
    # give generate() an explicit one so it doesn't fall back with a warning.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map=device_map,
        torch_dtype=torch_dtype,
        offload_folder=args.offload_dir,
    )
    model = PeftModel.from_pretrained(model, args.lora_dir)

    inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)

    # transformers rejects do_sample=True with temperature <= 0; treat that
    # user input as a request for greedy decoding instead of crashing.
    do_sample = args.temperature > 0
    gen_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if do_sample:
        gen_kwargs["temperature"] = args.temperature
        gen_kwargs["top_p"] = args.top_p

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    print(tokenizer.decode(out[0], skip_special_tokens=True))


if __name__ == "__main__":
    main()