| from prompt import TA_prompt |
| import re |
| from utils import generate_response, run_code |
|
|
|
|
# Numeric literals in a question. A leading sign counts as part of the number
# only when it is not directly preceded by a digit or a dot, so "5-3" still
# yields 5 and 3, while "x = -7" yields -7.
_NUMBER_RE = re.compile(r"(?:(?<![\d.])[-+])?(?:\d*\.\d+|\d+)")


def _parse_number(text):
    """Return *text* as an int when it has no decimal point, else as a float."""
    return float(text) if "." in text else int(text)


def post_process_code(code, question):
    """Append a call to the generated function, fed with numbers from *question*.

    The first top-level ``def`` in *code* is located with :mod:`ast`. Its
    parameters are bound, in order, to the numeric literals found in
    *question*. Extra numbers are ignored; if there are fewer numbers than
    parameters, only the leading parameters are bound. A line
    ``print(func(p1=v1,p2=v2,...))`` is appended so that running the code
    prints the function's result.

    Args:
        code: Python source produced by the model.
        question: The natural-language question the code should answer.

    Returns:
        *code* with the call appended. *code* is returned unchanged when it
        cannot be parsed or contains no top-level function definition.
    """
    import ast

    try:
        tree = ast.parse(code)
    except SyntaxError:
        # Unparsable code will fail when executed anyway; don't guess a call.
        return code

    func = next(
        (node for node in tree.body if isinstance(node, ast.FunctionDef)),
        None,
    )
    if func is None:
        return code

    # Positional-only parameters cannot be passed as name=value.
    positional_only = [a.arg for a in func.args.posonlyargs]
    by_keyword = [a.arg for a in func.args.args + func.args.kwonlyargs]
    params = positional_only + by_keyword

    values = [
        _parse_number(text)
        for text in _NUMBER_RE.findall(question)[: len(params)]
    ]
    args = [
        str(value) if name in positional_only else f"{name}={value}"
        for name, value in zip(params, values)
    ]
    return code + f"\nprint({func.name}({','.join(args)}))"
|
|
|
|
def solve_ta(question, token, temperature=0.9):
    """Answer *question* using the few-shot ``TA_prompt``, running code if given.

    The model is prompted with ``TA_prompt`` followed by ``"Human: <question>"``.
    If the reply contains a fenced code block, that code is completed by
    ``post_process_code`` with a call to its function, printed, and executed
    via ``run_code``. Otherwise, the text after the first ``Assistant:``
    marker is returned as the answer.

    Args:
        question: The user's question.
        token: Credential forwarded unchanged to ``generate_response``.
        temperature: Sampling temperature forwarded to ``generate_response``
            (defaults to the previously hard-coded 0.9).

    Returns:
        A ``(prediction, code)`` tuple:

        * code path: ``(run_code(code), code)``, or ``(None, code)`` when the
          code contains ``input(`` or raises while running;
        * text path: ``(answer, "")`` with *answer* stripped of surrounding
          whitespace, or ``(None, "")`` when no ``Assistant:`` line exists.
    """
    question = "Human: " + question.strip()
    query = (TA_prompt + question).strip() + "\n"
    response = generate_response(query, temperature, token)

    # NOTE(review): slicing off len(TA_prompt.strip()) characters assumes
    # generate_response echoes the prompt before the completion -- confirm.
    # What remains is the current "Human: ..." turn plus the model's reply,
    # truncated at the first "-----" separator.
    completion = response[len(TA_prompt.strip()):].strip().split("-----")[0]

    if "```" not in completion:
        # Plain-text reply: the answer follows the first "Assistant:" marker.
        match = re.search(r"Assistant:(.*)", completion)
        if match is None:
            return None, ""
        return match.group(1).strip(), ""

    # Prefer an explicit ```python fence; otherwise use the first bare fence.
    fence = "```python" if "```python" in completion else "```"
    code = completion.split(fence)[1].split("```")[0].strip()
    code = post_process_code(code, question)
    print(code)

    # Skip code that calls input(); presumably it would block waiting on stdin.
    if "input(" in code:
        return None, code

    # NOTE: run_code executes model-generated (untrusted) code. Any failure
    # inside it is reported as "no prediction" rather than propagated.
    try:
        pred = run_code(code)
    except Exception:
        return None, code
    return pred, code
|
|
|
|
if __name__ == "__main__":
    import os

    # The access token used to be hard-coded here. Read it from the
    # environment so credentials never land in source control.
    # NOTE(review): the token previously committed in this file should be
    # considered leaked and revoked.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise SystemExit("Set the HF_TOKEN environment variable to run this example.")

    q = "What is the smallest even prime number?"
    print(solve_ta(q, hf_token))
|
|
|
|