# PythonCopilot / inference.py
# Uploaded by TRM-coding ("Upload inference.py", commit 377e9d8, verified).
# File size: 994 bytes.
from transformers import pipeline, set_seed
import re
from transformers import set_seed
# Location of the model checkpoint handed to the pipeline (the current directory).
model_ckpt = './'

# Module-level text-generation pipeline built from that checkpoint on device 0.
generation = pipeline(
    task='text-generation',
    model=model_ckpt,
    device=0,
)
def first_block(string):
    """Return the leading part of *string*, cut at the first block boundary.

    A boundary is a newline immediately followed by ``class``, ``def``,
    ``#``, ``@``, ``print`` or ``if``. Everything from that newline onward is
    dropped and trailing whitespace is stripped from what remains. When no
    boundary is present the whole string (right-stripped) is returned.
    """
    boundary = re.search('\nclass|\ndef|\n#|\n@|\nprint|\nif', string)
    if boundary is None:
        head = string
    else:
        head = string[:boundary.start()]
    return head.rstrip()
def complete_code(pipe, prompt, max_length=64, num_completions=4, seed=1):
    """Sample completions for *prompt*, print them, and return them.

    Args:
        pipe: Text-generation callable. It is invoked as
            ``pipe(prompt, **kwargs)`` and must yield mappings that carry a
            ``'generated_text'`` entry; the first ``len(prompt)`` characters
            of that text are dropped before truncation.
        prompt: Text to be continued.
        max_length: Forwarded to ``pipe`` as ``max_length``.
        num_completions: Forwarded to ``pipe`` as ``num_return_sequences``.
        seed: Passed to ``set_seed`` before sampling.

    Returns:
        The list of completions, each reduced to its first block by
        ``first_block``, in the order ``pipe`` produced them.
    """
    set_seed(seed)
    gen_kwargs = {
        "temperature": 0.4,
        "top_p": 0.95,
        "top_k": 0,
        "num_beams": 1,
        "do_sample": True,
    }
    # Call the pipeline that was passed in rather than a module-level global,
    # so the function honours whichever pipeline the caller supplies.
    code_gens = pipe(
        prompt,
        num_return_sequences=num_completions,
        max_length=max_length,
        **gen_kwargs,
    )
    code_strings = [
        first_block(code_gen['generated_text'][len(prompt):])
        for code_gen in code_gens
    ]
    # Completions are separated by an 80-character rule for readability.
    print(('\n'+'='*80 + '\n').join(code_strings))
    return code_strings
# Example prompt: a function signature followed by a docstring, left for the
# model to continue.
# NOTE(review): the docstring line inside the literal starts at column 0 here;
# confirm whether leading indentation was intended, since whitespace inside
# the literal is part of the prompt text.
prompt = '''def area_of_rectangle(a: float, b: float):
"""Return the area of the rectangle."""'''
# Run the demo: generate completions for the example prompt and print them.
complete_code(generation, prompt)