| from prompt import prompt |
| import re |
| from utils import generate_response, run_code |
|
|
|
|
def post_process_code(code, question):
    """Pair the parameters of ``solution`` in *code* with numbers from *question*.

    The first ``def solution(...)`` signature found in *code* is parsed for
    parameter names; type annotations and default values are dropped, and
    bare ``*`` / ``*args`` / ``**kwargs`` entries are skipped. Numbers are
    then taken from *question* in order of appearance and matched to the
    parameters positionally.

    Args:
        code: Python source expected to contain ``def solution(...)``.
        question: Natural-language text to scan for numeric literals.

    Returns:
        A list of ``(name, value)`` tuples where ``value`` is an ``int`` for
        integral literals and a ``float`` for literals containing a ``.``.
        Its length is the smaller of the number of parameters and the number
        of literals found; it is empty when *code* has no ``solution``
        signature.
    """
    signature = re.search(r"def\s+solution\s*\(([^)]*)\)", code)
    if signature is None:
        return []

    parameters = []
    for raw in signature.group(1).split(","):
        # "a: int = 3" -> "a"
        name = raw.split("=")[0].split(":")[0].strip()
        # Skip empty slots ("()" or a trailing comma) and star parameters,
        # which cannot be bound with a plain keyword argument.
        if name and not name.startswith("*"):
            parameters.append(name)

    # Group the alternation so the optional sign applies to integers too, and
    # only accept a sign that is not glued to a preceding number, so that
    # "10-3" yields 10 and 3 rather than 10 and -3.
    numbers = re.findall(r"(?<![\d.])[-+]?(?:\d*\.\d+|\d+)", question)
    values = [float(n) if "." in n else int(n) for n in numbers[:len(parameters)]]
    return list(zip(parameters, values))
|
|
|
|
def solve_pal(question, token, temperature=0.9):
    """Answer *question* with Program-Aided Language (PAL) prompting.

    The question is formatted into ``prompt`` and sent to
    ``generate_response``. The returned text is normalised into a
    ``def solution():`` function, and a ``print(solution(...))`` call is
    appended, with keyword arguments taken from ``post_process_code``. The
    resulting program is passed to ``run_code``.

    Args:
        question: The natural-language question to solve.
        token: Forwarded unchanged as the third argument of
            ``generate_response``.
        temperature: Forwarded as the second argument of
            ``generate_response`` (presumably the sampling temperature).
            Defaults to the previously hard-coded 0.9.

    Returns:
        A ``(prediction, code)`` tuple. ``prediction`` is the value returned
        by ``run_code``, or ``None`` if the program contains ``input(`` or
        ``run_code`` raises. ``code`` is the assembled program.
    """
    question = question.strip()
    query = prompt.format(question=question).strip()
    response = generate_response(query, temperature, token)

    # Keep only the text after the last "def solution():" header. A plain
    # strip() here would remove the indentation of the first body line only,
    # turning well-formed output into an IndentationError, so drop just the
    # leading blank lines and the trailing whitespace.
    body_lines = response.split("def solution():")[-1].rstrip().split("\n")
    while body_lines and not body_lines[0].strip():
        del body_lines[0]
    code = "def solution():\n" + "\n".join(body_lines)

    arguments = post_process_code(code, question)
    arg_string = ",".join(f"{param}={val}" for param, val in arguments)
    code += f"\nprint(solution({arg_string}))"

    # Programs that call input() are rejected without being run
    # (presumably they would wait on stdin).
    if "input(" in code:
        return None, code
    try:
        # NOTE(review): this executes model-generated code; whether run_code
        # isolates it is not visible from here -- confirm before using this
        # with untrusted input.
        pred = run_code(code)
    except Exception:
        # Deliberate boundary: any failure of the generated program is
        # reported as "no prediction" instead of propagating.
        return None, code
    return pred, code
|
|
|
|
if __name__ == "__main__":
    import os
    import sys

    # solve_pal() requires the token forwarded to generate_response(); the
    # previous demo called solve_pal(q) without it and always raised
    # TypeError. Read it from the environment so it is not exposed on the
    # command line.
    token = os.environ.get("PAL_TOKEN")
    if not token:
        sys.exit("Set PAL_TOKEN to the token expected by generate_response().")

    # Optional question from the command line; falls back to the old example.
    q = " ".join(sys.argv[1:]) or "What is the 7th Fibonacci number?"
    print(solve_pal(q, token))
|
|
|
|