# infer.py — inference helpers for a PEFT-adapted causal LM (hub commit edaa68a)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
def load_generator(base_model: str, adapter_dir: str):
    """Build an inference-ready generator bundle.

    Loads the tokenizer from the adapter directory, loads the base causal LM,
    attaches the PEFT adapter on top of it, switches to eval mode, and moves
    the model to GPU when one is available.

    Args:
        base_model: Hub ID or local path of the base causal language model.
        adapter_dir: Directory containing the PEFT adapter (and tokenizer files).

    Returns:
        Dict with keys "model" (the PEFT-wrapped model), "tokenizer", and
        "device" ("cuda" or "cpu").
    """
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir, use_fast=True)
    # Some base models ship without a pad token; reuse EOS so batching works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(base_model)
    peft_model = PeftModel.from_pretrained(base, adapter_dir)
    peft_model.eval()

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    peft_model.to(target_device)
    return {"model": peft_model, "tokenizer": tokenizer, "device": target_device}
@torch.no_grad()
def generate_text(gen, prompt: str, max_new_tokens: int = 80, temperature: float = 0.9) -> str:
    """Generate a text continuation for *prompt* using a loaded generator.

    Args:
        gen: Dict produced by ``load_generator`` with keys "model",
            "tokenizer", and "device".
        prompt: Input text to condition the generation on.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature. Values <= 0 switch to greedy
            decoding instead of sampling (HF ``generate`` raises a
            ``ValueError`` if ``do_sample=True`` with temperature <= 0).

    Returns:
        The decoded text, including the prompt (special tokens stripped).
    """
    model = gen["model"]
    tok = gen["tokenizer"]
    device = gen["device"]

    inputs = tok(prompt, return_tensors="pt").to(device)

    temperature = float(temperature)
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        # Explicit pad id silences the "Setting pad_token_id" warning for
        # models whose config has no pad token.
        "pad_token_id": tok.eos_token_id,
    }
    if temperature > 0.0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature
    else:
        # Greedy fallback: sampling with temperature <= 0 is invalid in HF.
        gen_kwargs["do_sample"] = False

    out = model.generate(**inputs, **gen_kwargs)
    return tok.decode(out[0], skip_special_tokens=True)