| from transformers import pipeline, set_seed | |
| import re | |
| from transformers import set_seed | |
# Build a text-generation pipeline from the model checkpoint stored in the
# current directory ("./"), placed on device index 0 (typically the first GPU).
model_ckpt = "./"
generation = pipeline(
    task="text-generation",
    model=model_ckpt,
    device=0,
)
def first_block(string):
    """Return the leading chunk of *string* before the next top-level construct.

    The chunk ends at the first newline that is immediately followed by
    ``class``, ``def``, ``#``, ``@``, ``print`` or ``if``, i.e. a line that
    starts at column 0 with one of those tokens. Indented lines do not end
    the chunk. Trailing whitespace is stripped from the result. If no such
    line exists, the whole string is returned (right-stripped).
    """
    boundary = re.search(r"\n(?:class|def|#|@|print|if)", string)
    if boundary is None:
        head = string
    else:
        head = string[:boundary.start()]
    return head.rstrip()
def complete_code(pipe, prompt, max_length=64, num_completions=4, seed=1):
    """Sample code completions for *prompt*, print them, and return them.

    Each sampled continuation is trimmed with ``first_block`` so only the code
    up to the next top-level construct is kept. The trimmed completions are
    printed separated by a line of 80 ``=`` characters.

    Args:
        pipe: The text-generation pipeline to sample from. It is called with
            the prompt and generation keyword arguments, and is expected to
            return a list of dicts with a ``"generated_text"`` key.
        prompt: The code prompt to complete.
        max_length: Passed to the pipeline as ``max_length``. In transformers
            this is the total token length, prompt included.
        num_completions: Number of sequences to sample
            (``num_return_sequences``).
        seed: Seed passed to ``set_seed`` so sampling is reproducible.

    Returns:
        list of str: The trimmed completions, in the order the pipeline
        returned them.
    """
    set_seed(seed)
    # Low-temperature nucleus sampling. top_k=0 disables top-k filtering, so
    # only top_p limits the candidate tokens.
    gen_kwargs = {
        "temperature": 0.4,
        "top_p": 0.95,
        "top_k": 0,
        "num_beams": 1,
        "do_sample": True,
    }
    # Use the pipeline the caller passed in. The original ignored `pipe` and
    # always used the module-global `generation`.
    code_gens = pipe(
        prompt,
        num_return_sequences=num_completions,
        max_length=max_length,
        **gen_kwargs,
    )
    # By default the pipeline's generated_text is prompt + continuation, so
    # slice off the prompt before trimming.
    code_strings = [
        first_block(code_gen["generated_text"][len(prompt):])
        for code_gen in code_gens
    ]
    print(("\n" + "=" * 80 + "\n").join(code_strings))
    return code_strings
# Prompt: a function signature plus docstring for the model to complete.
# The body's 4-space indentation is written out explicitly so the prompt stays
# valid Python no matter how this file's whitespace is rendered or re-pasted.
prompt = (
    "def area_of_rectangle(a: float, b: float):\n"
    '    """Return the area of the rectangle."""'
)
complete_code(generation, prompt)