# c2cite / tests / peft_generation.py
# Uploaded by loadingy — "first push", commit 51be264
import fire
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def main(
    base_model: str, instruction: str, lora_weights: str = None, device: str = "cuda:0"
):
    """Generate a completion for ``instruction`` with an (optionally LoRA-adapted) causal LM.

    Args:
        base_model: Hugging Face hub id or local path of the base model.
        instruction: Prompt text to generate from.
        lora_weights: Optional hub id / path of PEFT LoRA adapter weights.
            When ``None`` the bare base model is used.
        device: Torch device string for the model and input tensors.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    # Only wrap with PEFT when adapter weights were actually supplied;
    # the original called PeftModel.from_pretrained unconditionally and
    # crashed when lora_weights was left at its None default.
    if lora_weights is not None:
        model = PeftModel.from_pretrained(
            model, lora_weights, torch_dtype=torch.float32
        )
    input_ids = tokenizer(instruction, return_tensors="pt").input_ids.to(device)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=128,
        )
    # Strip the prompt in *token* space before decoding. The original sliced
    # the decoded string by the token count (tokens != characters), so the
    # wrong number of characters was removed from the output.
    generated_ids = outputs[0][input_ids.shape[-1]:]
    output = tokenizer.decode(generated_ids, skip_special_tokens=True)
    print(f"Prompt:\n{instruction}\n")
    print(f"Generated:\n{output}")
# CLI entry point: fire maps the main() signature onto command-line flags,
# e.g. `python peft_generation.py --base_model=... --instruction="..."`.
if __name__ == "__main__":
    fire.Fire(main)