File size: 2,154 Bytes
24c2665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from typing import List


# Script templates for running a snippet once.
# ``{code}`` is placed verbatim at the top of the script and ``{inputs}`` is
# spliced in as the argument list of a call to ``f`` (so ``{code}`` is
# expected to define ``f``).
#
# Plain variant: the script's final line is the bare expression
# ``repr(f(...))``.
RUN_CODE_TEMPLATE = (
    "{code}\n"
    "repr(f({inputs}))"
)

# Printing variant: the final line prints the literal marker
# '<FINAL_REPR_SYMBOL>' followed by the repr of the call result.
# NOTE(review): presumably the marker lets the caller locate the result in
# captured stdout -- confirm against the executor that consumes this.
RUN_CODE_TEMPLATE_REPR = (
    "{code}\n"
    "print('<FINAL_REPR_SYMBOL>', repr(f({inputs})))"
)

# Script templates for validating a snippet. Their text is the same as the
# run templates: ``{code}`` first, then a call ``f({inputs})`` wrapped in
# ``repr``.
#
# Plain variant: the script ends with the bare ``repr(f(...))`` expression.
VALIDATE_CODE_TEMPLATE = (
    "{code}\n"
    "repr(f({inputs}))"
)

# Printing variant: the repr is printed after the '<FINAL_REPR_SYMBOL>' marker.
VALIDATE_CODE_TEMPLATE_REPR = (
    "{code}\n"
    "print('<FINAL_REPR_SYMBOL>', repr(f({inputs})))"
)

# Script templates for scoring a predicted *input*: after ``{code}``, the
# script compares ``{gold_output}`` (spliced in as source text) with the
# result of calling ``f`` on ``{agent_input}``.
#
# Plain variant: the final line is the bare comparison expression.
EVAL_INPUT_PREDICTION_TEMPLATE = (
    "{code}\n"
    "{gold_output} == f({agent_input})"
)

# Printing variant: the repr of the comparison result is printed after the
# '<FINAL_REPR_SYMBOL>' marker.
EVAL_INPUT_PREDICTION_TEMPLATE_REPR = (
    "{code}\n"
    "print('<FINAL_REPR_SYMBOL>', repr({gold_output} == f({agent_input})))"
)

# Script templates for scoring a predicted *output*: after ``{code}``, the
# script compares ``eval({gold_output})`` with ``eval({agent_output})``.
# Both placeholders are spliced in as source text inside the ``eval`` calls.
# NOTE(review): the generated script calls ``eval`` on the substituted
# values; presumably it only ever runs inside a sandboxed executor -- confirm.
#
# Plain variant: the final line is the bare comparison expression.
EVAL_OUTPUT_PREDICTION_TEMPLATE = (
    "{code}\n"
    "eval({gold_output}) == eval({agent_output})"
)

# Printing variant: the repr of the comparison result is printed after the
# '<FINAL_REPR_SYMBOL>' marker.
EVAL_OUTPUT_PREDICTION_TEMPLATE_REPR = (
    "{code}\n"
    "print('<FINAL_REPR_SYMBOL>', repr(eval({gold_output}) == eval({agent_output})))"
)

# Script templates for a determinism check: the generated script calls
# ``f({inputs})`` twice and raises ``Exception('Non-deterministic code')``
# when the two results compare unequal with ``!=``.
#
# Plain variant: the script ends with the bare expression ``repr(returns)``.
CHECK_DETERMINISM_TEMPLATE = "\n".join((
    "{code}",
    "returns = f({inputs})",
    "if returns != f({inputs}):",
    "    raise Exception('Non-deterministic code')",
    "repr(returns)",
))

# Printing variant: the repr of the first result is printed after the
# '<FINAL_REPR_SYMBOL>' marker.
CHECK_DETERMINISM_TEMPLATE_REPR = "\n".join((
    "{code}",
    "returns = f({inputs})",
    "if returns != f({inputs}):",
    "    raise Exception('Non-deterministic code')",
    "print('<FINAL_REPR_SYMBOL>', repr(returns))",
))

def EVAL_K_INPUT_PREDICTION_TEMPLATE(code: str, gold_output: str, k_agent_inputs: List[str], repr_output: bool = False):
    """Build a script that scores several predicted inputs against one output.

    The generated script starts with ``code``, creates an empty ``acc_list``
    and then, for every candidate in ``k_agent_inputs``, appends the result of
    ``gold_output == f(candidate)``. Each comparison sits in its own
    ``try``/bare-``except`` block, so a candidate that raises is recorded as
    ``False`` instead of aborting the script.

    The script yields the per-candidate list itself; no mean is computed here.

    Args:
        code: Source text placed at the top of the script; it is expected to
            define ``f``.
        gold_output: Source text of the expected output expression.
        k_agent_inputs: Source text of each candidate argument list for ``f``.
        repr_output: When true, the script ends by printing
            ``'<FINAL_REPR_SYMBOL>'`` followed by ``repr(acc_list)``;
            otherwise it ends with the bare expression ``acc_list``.

    Returns:
        The generated script as a single string.
    """
    parts = [f"{code}\nacc_list = []"]
    for candidate in k_agent_inputs:
        parts.append(
            "try:\n"
            f"    acc_list.append({gold_output} == f({candidate}))\n"
            "except:\n"
            "    acc_list.append(False)"
        )
    # Last line: either print the marked repr or leave the list as the
    # script's final expression.
    if repr_output:
        parts.append("print('<FINAL_REPR_SYMBOL>', repr(acc_list))")
    else:
        parts.append("acc_list")
    return "\n".join(parts)

def EVAL_K_OUTPUT_PREDICTION_TEMPLATE(code: str, gold_output: str, k_agent_outputs: List[str], repr_output: bool = False):
    """Build a script that scores several predicted outputs against one output.

    The generated script starts with ``code``, creates an empty ``acc_list``
    and then, for every candidate in ``k_agent_outputs``, appends the result
    of ``gold_output == candidate``. Each comparison sits in its own
    ``try``/bare-``except`` block, so a candidate expression that raises is
    recorded as ``False`` instead of aborting the script.

    The script yields the per-candidate list itself; no mean is computed here.

    Args:
        code: Source text placed at the top of the script.
        gold_output: Source text of the expected output expression.
        k_agent_outputs: Source text of each predicted output expression.
        repr_output: When true, the script ends by printing
            ``'<FINAL_REPR_SYMBOL>'`` followed by ``repr(acc_list)``;
            otherwise it ends with the bare expression ``acc_list``.

    Returns:
        The generated script as a single string.
    """
    parts = [f"{code}\nacc_list = []"]
    for prediction in k_agent_outputs:
        parts.append(
            "try:\n"
            f"    acc_list.append({gold_output} == {prediction})\n"
            "except:\n"
            "    acc_list.append(False)"
        )
    # Last line: either print the marked repr or leave the list as the
    # script's final expression.
    if repr_output:
        parts.append("print('<FINAL_REPR_SYMBOL>', repr(acc_list))")
    else:
        parts.append("acc_list")
    return "\n".join(parts)