TRM-coding commited on
Commit
377e9d8
·
verified ·
1 Parent(s): 9b10b2a

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +25 -0
inference.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline, set_seed
2
+ import re
3
+ from transformers import set_seed
4
+ model_ckpt = './'
5
+ generation = pipeline('text-generation', model=model_ckpt, device=0)
6
+
7
+
8
+ def first_block(string):
9
+ return re.split('\nclass|\ndef|\n#|\n@|\nprint|\nif', string)[0].rstrip()
10
+
11
+ def complete_code(pipe, prompt, max_length=64, num_completions=4, seed=1):
12
+ set_seed(seed)
13
+ gen_kwargs = {"temperature":0.4, "top_p":0.95, "top_k":0, "num_beams":1,
14
+ "do_sample":True,}
15
+ code_gens = generation(prompt, num_return_sequences=num_completions,
16
+ max_length=max_length, **gen_kwargs)
17
+ code_strings = []
18
+ for code_gen in code_gens:
19
+ generated_code = first_block(code_gen['generated_text'][len(prompt):])
20
+ code_strings.append(generated_code)
21
+ print(('\n'+'='*80 + '\n').join(code_strings))
22
+
23
+ prompt = '''def area_of_rectangle(a: float, b: float):
24
+ """Return the area of the rectangle."""'''
25
+ complete_code(generation, prompt)