File size: 4,066 Bytes
0c51b93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
import json
import re


def overwrite_eval_bash(
    eval_script: str,
    tag: str,
    env_ids: list,
    batch_size: int = 2,
    agent1_model: str = "gpt-3.5-turbo",
    agent2_model: str = "gpt-3.5-turbo",
    push_to_db: bool = True,
) -> None:
    """Rewrite the ``--gin.*`` overrides in an evaluation bash script in place.

    Every line of ``eval_script`` is checked for the gin flags below, in this
    order; only the first flag found on a line is rewritten, and the rest of
    the line is kept as-is:

    - ``--gin.TAG_TO_CHECK_EXISTING_EPISODES="..."`` -> ``tag``
    - ``--gin.TAG="..."`` -> ``tag``
    - ``--gin.ENV_IDS=...`` (up to the trailing ``\\`` line continuation)
      -> ``json.dumps(env_ids)`` followed by a single quote. The old value,
      including any closing quote, is consumed, so the template presumably
      writes the flag as ``'--gin.ENV_IDS=[...]' \\`` -- confirm against the
      template in use.
    - ``--gin.BATCH_SIZE=<digits>`` -> ``batch_size``
    - ``--gin.AGENT1_MODEL="..."`` -> ``agent1_model``
    - ``--gin.AGENT2_MODEL="..."`` -> ``agent2_model``
    - ``--gin.PUSH_TO_DB=True|False`` -> ``push_to_db``

    New values are inserted literally. They are never interpreted as regex
    replacement templates, so tags or model names that start with a digit or
    contain backslashes are written unchanged.

    Args:
        eval_script: Path to the bash script; it is read, then overwritten.
        tag: Value for both ``TAG`` and ``TAG_TO_CHECK_EXISTING_EPISODES``.
        env_ids: JSON-serializable list of environment ids.
        batch_size: Value for ``BATCH_SIZE``.
        agent1_model: Value for ``AGENT1_MODEL``.
        agent2_model: Value for ``AGENT2_MODEL``.
        push_to_db: Value for ``PUSH_TO_DB`` (written via ``str()``).
    """
    # (marker, pattern, new value). Each pattern captures the text to keep
    # before ("pre") and after ("post") the old value. Order matters:
    # "--gin.TAG" is a prefix of "--gin.TAG_TO_CHECK_EXISTING_EPISODES", so
    # the longer flag must be tested first.
    rules = (
        (
            "--gin.TAG_TO_CHECK_EXISTING_EPISODES",
            re.compile(
                r'(?P<pre>--gin\.TAG_TO_CHECK_EXISTING_EPISODES=")[^"]*(?P<post>")'
            ),
            tag,
        ),
        (
            "--gin.TAG",
            re.compile(r'(?P<pre>--gin\.TAG=")[^"]*(?P<post>")'),
            tag,
        ),
        (
            "--gin.ENV_IDS",
            re.compile(r"(?P<pre>--gin\.ENV_IDS=).*?(?P<post>\s*\\)"),
            json.dumps(env_ids) + "'",
        ),
        (
            "--gin.BATCH_SIZE",
            re.compile(r"(?P<pre>--gin\.BATCH_SIZE=)\d+(?P<post>)"),
            str(batch_size),
        ),
        (
            "--gin.AGENT1_MODEL",
            re.compile(r'(?P<pre>--gin\.AGENT1_MODEL=")[^"]*(?P<post>")'),
            agent1_model,
        ),
        (
            "--gin.AGENT2_MODEL",
            re.compile(r'(?P<pre>--gin\.AGENT2_MODEL=")[^"]*(?P<post>")'),
            agent2_model,
        ),
        (
            "--gin.PUSH_TO_DB",
            re.compile(r"(?P<pre>--gin\.PUSH_TO_DB=)(?:True|False)(?P<post>)"),
            str(push_to_db),
        ),
    )

    with open(eval_script, "r") as f:
        lines = f.readlines()

    rewritten = []
    for line in lines:
        for marker, pattern, value in rules:
            if marker in line:
                # A callable replacement inserts ``value`` verbatim. A string
                # template such as r"\1" + "2024..." would be parsed as an
                # octal escape or group reference and corrupt the line.
                line = pattern.sub(
                    lambda m, v=value: m.group("pre") + v + m.group("post"), line
                )
                break  # like an if/elif chain: at most one flag per line
        rewritten.append(line)

    with open(eval_script, "w") as f:
        f.writelines(rewritten)


def main():
    """Parse CLI options and rewrite the evaluation bash template in place.

    The env ids are read from the JSON ``--env-file``, taken from the entry
    keyed by ``--experiment-name``, and passed with the remaining options to
    ``overwrite_eval_bash``. The rewritten script is not executed here.
    Exits with a usage error if the experiment name is not in the env file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--eval-script",
        type=str,
        required=True,
        help="Required. Provide template bash file for sotopia evaluation.",
    )
    parser.add_argument(
        "--env-file",
        type=str,
        default="env_files/used_env.json",
        help="Default: env_files/used_env.json. Provide the json file of env ids for conversation generation.",
    )
    parser.add_argument(
        "--experiment-name",
        type=str,
        required=True,
        help="Required. Need the experiment_name, which is the key of the env_file.",
    )
    parser.add_argument(
        "--tag",
        type=str,
        required=True,
        help="Required. Provide a unique tag that will be pushed to REDIS database.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Default: 2. Provide the batch size of calling APIs.",
    )
    parser.add_argument(
        "--agent1-model",
        type=str,
        default="gpt-3.5-turbo",
        help="Default: gpt-3.5-turbo. Provide the name of OPENAI model.",
    )
    parser.add_argument(
        "--agent2-model",
        type=str,
        default="gpt-3.5-turbo",
        help="Default: gpt-3.5-turbo. Provide the name of OPENAI model.",
    )
    parser.add_argument(
        "--push-to-db",
        type=str,
        # Must be a string: argparse applies ``type`` only to string defaults,
        # and a bool default would never equal "true" in the check below,
        # silently turning the documented default into False.
        default="True",
        help="Default: True. If you choose False, then the conversations will not be pushed to REDIS database.",
    )
    args = parser.parse_args()

    with open(args.env_file, "r") as f:
        env_ids_by_experiment = json.load(f)
    try:
        env_ids = env_ids_by_experiment[args.experiment_name]
    except KeyError:
        # parser.error prints usage and exits with status 2.
        parser.error(
            f"experiment name {args.experiment_name!r} not found in {args.env_file}"
        )

    overwrite_eval_bash(
        eval_script=args.eval_script,
        tag=args.tag,
        env_ids=env_ids,
        batch_size=args.batch_size,
        agent1_model=args.agent1_model,
        agent2_model=args.agent2_model,
        # Case-insensitive, so "true"/"TRUE" also enable pushing.
        push_to_db=args.push_to_db.strip().lower() == "true",
    )


# Run the command-line entry point only when executed as a script, not on import.
if "__main__" == __name__:
    main()