# dee-tulu-train/evaluation/simple_inference.py
import argparse
from typing import Optional

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def resolve_dtype(name: str, device: str) -> Optional[torch.dtype]:
    """Map the --torch_dtype flag to a torch dtype, or None to let transformers choose."""
    if name == "auto":
        # On CPU, default to fp32; otherwise let transformers pick.
        return torch.float32 if device == "cpu" else None
    return {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[name]


def resolve_device_map(device: str):
    """Pin the model to CPU when requested; otherwise let accelerate place it."""
    return None if device == "cpu" else "auto"


def parse_args():
    p = argparse.ArgumentParser(description="Run a quick LoRA inference.")
    p.add_argument("--lora_dir", default="outputs/tinyllama-lora", help="Path to LoRA adapter folder.")
    p.add_argument(
        "--prompt",
        default="### Instruction:\nExplain LoRA in one sentence.\n\n### Input:\nN/A\n\n### Response:\n",
    )
    p.add_argument("--max_new_tokens", type=int, default=128)
    p.add_argument("--temperature", type=float, default=0.7)
    p.add_argument("--top_p", type=float, default=0.9)
    p.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda", "mps"])
    p.add_argument("--torch_dtype", default="auto", choices=["auto", "float16", "bfloat16", "float32"])
    p.add_argument("--offload_dir", default=None, help="Optional offload directory when using device_map='auto'.")
    return p.parse_args()


def main():
    args = parse_args()

    # The adapter config records which base model it was trained on.
    cfg = PeftConfig.from_pretrained(args.lora_dir)
    base_model = cfg.base_model_name_or_path

    torch_dtype = resolve_dtype(args.torch_dtype, args.device)
    # resolve_device_map already returns "auto" for anything other than explicit CPU.
    device_map = resolve_device_map(args.device)

    # The tokenizer is loaded from the adapter folder so any added tokens are picked up.
    tokenizer = AutoTokenizer.from_pretrained(args.lora_dir, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map=device_map,
        torch_dtype=torch_dtype,
        offload_folder=args.offload_dir,
    )
    model = PeftModel.from_pretrained(model, args.lora_dir)
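    # Optional step (not in the original script): fold the LoRA weights into the base
    # model with PEFT's merge_and_unload() for slightly faster generation, at the cost
    # of no longer being able to swap adapters.
    # model = model.merge_and_unload()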

    prompt = args.prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            temperature=args.temperature,
            top_p=args.top_p,
        )
    print(tokenizer.decode(out[0], skip_special_tokens=True))


if __name__ == "__main__":
    main()
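

# Example usage (assuming an adapter was saved to the default outputs/tinyllama-lora):
#   python evaluation/simple_inference.py --device cpu --torch_dtype float32 --max_new_tokens 64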