samaresh55 committed on
Commit
97b4331
·
1 Parent(s): ac7f519

Upload test_model.py

Browse files
Files changed (1) hide show
  1. test_model.py +75 -0
test_model.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from peft import PeftModel
import transformers

from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig

# Tokenizer that was saved alongside the fine-tuned adapter in "model/".
tokenizer = LlamaTokenizer.from_pretrained("model/")

# Base LLaMA-7B weights, loaded in 8-bit and sharded automatically across
# the available devices.
model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

# BUG FIX: PeftModel.from_pretrained takes the *base model* as its first
# argument and the adapter path as the second; the original call passed only
# the adapter path, so the LoRA weights were never attached to the base
# model. (`load_in_8bit` was also dropped here — it is not a
# PeftModel.from_pretrained parameter; quantization is already applied to
# the base model above.)
model = PeftModel.from_pretrained(
    model,
    "model/",
    torch_dtype=torch.float16,
    device_map="auto",
)
22
+
23
+ def generate_prompt(instruction, input=None):
24
+ if input:
25
+ return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
26
+
27
+ ### Instruction:
28
+ {instruction}
29
+
30
+ ### Input:
31
+ {input}
32
+
33
+ ### Response:"""
34
+ else:
35
+ return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
36
+
37
+ ### Instruction:
38
+ {instruction}
39
+
40
+ ### Response:"""
# Switch the model to evaluation/inference mode (turns off training-only
# behavior such as dropout) before generating.
model.eval()
def evaluate(
    instruction,
    input=None,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    **kwargs,
):
    """Generate a model response for *instruction* (optionally with *input* context).

    Builds an Alpaca-style prompt via ``generate_prompt``, runs beam-search
    generation on the module-level ``model``/``tokenizer``, and returns the
    text that follows the "### Response:" marker.

    Parameters
    ----------
    instruction : str
        The task description inserted under "### Instruction:".
    input : str, optional
        Extra context inserted under "### Input:" when truthy. (Name shadows
        the builtin; kept for backward compatibility with keyword callers.)
    temperature, top_p, top_k, num_beams :
        Sampling/search parameters forwarded to ``GenerationConfig``.
    **kwargs :
        Any additional ``GenerationConfig`` fields.

    Returns
    -------
    str
        The generated answer, stripped of surrounding whitespace.
    """
    prompt = generate_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")
    # BUG FIX: the original referenced an undefined name ``device`` (NameError
    # on every call). Move the ids to wherever the (device_map="auto") model
    # actually lives.
    input_ids = inputs["input_ids"].to(model.device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=2048,
        )
    sequence = generation_output.sequences[0]
    decoded = tokenizer.decode(sequence)
    # The decoded text echoes the prompt; everything after the prompt's
    # "### Response:" marker is the model's answer.
    return decoded.split("### Response:")[1].strip()