Upload inference.py
Browse files- inference.py +25 -0
inference.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import pipeline, set_seed
|
| 2 |
+
import re
|
| 3 |
+
from transformers import set_seed
|
| 4 |
+
model_ckpt = './'
|
| 5 |
+
generation = pipeline('text-generation', model=model_ckpt, device=0)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def first_block(string):
|
| 9 |
+
return re.split('\nclass|\ndef|\n#|\n@|\nprint|\nif', string)[0].rstrip()
|
| 10 |
+
|
| 11 |
+
def complete_code(pipe, prompt, max_length=64, num_completions=4, seed=1):
|
| 12 |
+
set_seed(seed)
|
| 13 |
+
gen_kwargs = {"temperature":0.4, "top_p":0.95, "top_k":0, "num_beams":1,
|
| 14 |
+
"do_sample":True,}
|
| 15 |
+
code_gens = generation(prompt, num_return_sequences=num_completions,
|
| 16 |
+
max_length=max_length, **gen_kwargs)
|
| 17 |
+
code_strings = []
|
| 18 |
+
for code_gen in code_gens:
|
| 19 |
+
generated_code = first_block(code_gen['generated_text'][len(prompt):])
|
| 20 |
+
code_strings.append(generated_code)
|
| 21 |
+
print(('\n'+'='*80 + '\n').join(code_strings))
|
| 22 |
+
|
| 23 |
+
prompt = '''def area_of_rectangle(a: float, b: float):
|
| 24 |
+
"""Return the area of the rectangle."""'''
|
| 25 |
+
complete_code(generation, prompt)
|