S-Dreamer committed on
Commit
edaa68a
·
verified ·
1 Parent(s): 00ae6eb

Create infer.py

Browse files
Files changed (1) hide show
  1. src/infer.py +36 -0
src/infer.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+
4
+ from peft import PeftModel
5
+
6
+
7
def load_generator(base_model: str, adapter_dir: str, device=None):
    """Load a base causal LM, attach a PEFT adapter, and prepare it for inference.

    Args:
        base_model: Hugging Face model id or local path of the base model.
        adapter_dir: Directory containing the trained PEFT adapter; the
            tokenizer is also loaded from here (presumably so any tokens
            added during fine-tuning are preserved — confirm against the
            training script).
        device: Optional explicit device string ("cuda"/"cpu"). When None
            (the default, matching the original behavior), CUDA is used if
            available, otherwise CPU.

    Returns:
        dict with keys "model" (eval-mode PeftModel), "tokenizer", and
        "device" (the device string the model was moved to).
    """
    tok = AutoTokenizer.from_pretrained(adapter_dir, use_fast=True)
    # Many causal LMs ship without a pad token; fall back to EOS so
    # generation/padding does not fail downstream.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(base_model)
    model = PeftModel.from_pretrained(model, adapter_dir)
    model.eval()

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    return {"model": model, "tokenizer": tok, "device": device}
20
+
21
+
22
@torch.no_grad()
def generate_text(gen, prompt: str, max_new_tokens: int = 80, temperature: float = 0.9) -> str:
    """Sample a continuation of *prompt* using a loaded generator bundle.

    Args:
        gen: dict as returned by ``load_generator`` — expects the keys
            "model", "tokenizer", and "device".
        prompt: Input text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The decoded sequence (this includes the prompt itself, since the
        full output ids are decoded) with special tokens stripped.
    """
    tokenizer = gen["tokenizer"]
    lm = gen["model"]

    encoded = tokenizer(prompt, return_tensors="pt").to(gen["device"])
    generated = lm.generate(
        **encoded,
        do_sample=True,
        temperature=float(temperature),
        max_new_tokens=int(max_new_tokens),
        # EOS doubles as the pad id — mirrors the tokenizer setup in load_generator.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)