Spaces:
Runtime error
Runtime error
import functools

import gradio as gr
import requests
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Prompt template handed to the code model. ``code_from_prompts`` strips the
# surrounding newlines, replaces CONN with a connection string and PROMPT with
# the user's comment, and can add type hints to the signature. The final
# ``cur.execute(`` line appears to be left unterminated on purpose so the model
# completes the SQL statement itself.
# The function body is indented so the prompt is a syntactically well-formed
# function definition; an unindented body would hand the model invalid Python.
header = """
import psycopg2
conn = psycopg2.connect("CONN")
cur = conn.cursor()
def set_customer_name(id, new_name):
    # PROMPT
    cur.execute("UPDATE customer SET name
"""
# UI display name -> model identifier passed to ``from_pretrained``.
# The insertion order of these pairs is the order in which the choices are
# listed in the model selector.
modelPath = dict(
    [
        ("GPT2-Medium", "gpt2-medium"),
        ("CodeParrot-mini", "codeparrot/codeparrot-small"),
        ("CodeGen-350-Mono", "Salesforce/codegen-350M-mono"),
        ("GPT-J", "EleutherAI/gpt-j-6B"),
        ("CodeParrot", "codeparrot/codeparrot"),
        ("CodeGen-2B-Mono", "Salesforce/codegen-2B-mono"),
    ]
)
def generation(tokenizer, model, content, decoder='Standard', max_length=120):
    """Continue ``content`` with ``model`` and return the decoded text.

    Args:
        tokenizer: Tokenizer used to encode ``content`` and decode the result
            (for example one returned by ``AutoTokenizer.from_pretrained``).
        model: Causal language model exposing ``generate``.
        content: Prompt text to continue.
        decoder: Decoding strategy. ``'Standard'`` (default) is greedy
            decoding; ``'Beam'`` is 2-beam search with sampling;
            ``'Typical'`` is typical sampling with ``typical_p=0.8``;
            ``'Sample'`` is plain multinomial sampling.
        max_length: Maximum total length in tokens, prompt included.

    Returns:
        The prompt followed by the generated continuation, decoded with
        special tokens removed.

    Raises:
        ValueError: If ``decoder`` is not one of the strategies above.
    """
    if decoder not in ('Standard', 'Beam', 'Typical', 'Sample'):
        raise ValueError(
            f"unknown decoder {decoder!r}; expected 'Standard', 'Beam', "
            "'Typical' or 'Sample'"
        )
    input_ids = tokenizer.encode(content, return_tensors='pt')
    use_beams = decoder == 'Beam'
    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        num_beams=2 if use_beams else None,
        # early_stopping only affects beam search; passing it in other modes
        # just produces a configuration warning.
        early_stopping=use_beams,
        do_sample=decoder != 'Standard',
        typical_p=0.8 if decoder == 'Typical' else None,
        repetition_penalty=4.0,
        # These tokenizers define no pad token; reuse EOS explicitly instead of
        # relying on generate()'s fallback, which logs a warning per call.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
@functools.lru_cache(maxsize=1)
def _load_model(model_name):
    """Return a cached ``(tokenizer, model)`` pair for a ``modelPath`` key.

    Loading a checkpoint is expensive, so it is done once per model instead of
    once per request. ``maxsize=1`` keeps at most one checkpoint referenced by
    the cache, so switching models drops the previous pair rather than
    accumulating several large models.

    Raises:
        KeyError: If ``model_name`` is not a key of ``modelPath``.
    """
    path = modelPath[model_name]
    return (
        AutoTokenizer.from_pretrained(path),
        AutoModelForCausalLM.from_pretrained(path),
    )


def code_from_prompts(prompt, model, type_hints):
    """Build the SQL-update prompt, run the selected model and return outputs.

    Args:
        prompt: Free-text comment substituted for ``PROMPT`` in ``header``.
        model: Display name of the model; must be a key of ``modelPath``.
        type_hints: If true, annotate the ``set_customer_name`` signature
            (``id: int``, ``new_name: str``, ``-> None``) in the prompt.

    Returns:
        ``[generated_code, probability]`` for the two output boxes.
        NOTE(review): ``probability`` is a hard-coded placeholder (0.5), not a
        value derived from the model.

    Raises:
        KeyError: If ``model`` is not one of the ``modelPath`` keys.
    """
    template = header.strip().replace('CONN', "dbname='store'")
    if type_hints:
        # Annotate the signature *before* inserting the user's comment, so text
        # such as "id," or "new_name)" typed in the prompt is left untouched.
        template = template.replace('id,', 'id: int,')
        template = template.replace('new_name)', 'new_name: str) -> None')
    code = template.replace('PROMPT', prompt)
    tokenizer, lm = _load_model(model)
    return [generation(tokenizer, lm, code), 0.5]
# Components are taken from the top-level ``gr`` namespace: the ``gr.inputs`` /
# ``gr.outputs`` modules were deprecated in Gradio 3 and removed in Gradio 4,
# where referencing them fails with AttributeError before the app starts.
iface = gr.Interface(
    fn=code_from_prompts,
    inputs=[
        gr.Textbox(label="Insert comment"),
        # Pre-select the first model so the handler never receives None
        # (which is not a key of modelPath) when the user skips this choice.
        gr.Radio(
            list(modelPath.keys()),
            value=next(iter(modelPath)),
            label="Code Model",
        ),
        gr.Checkbox(label="Include type hints"),
    ],
    outputs=[
        gr.Textbox(label="Generated code"),
        gr.Textbox(label="Probability"),
    ],
    description="Prompt the code model to write a SQL query with string concatenation.",
)
iface.launch()