File size: 2,331 Bytes
0c51b93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import sys

sys.path.append("../../")
import argparse

from sotopia_rl.prompter.attribution_prompting import (
    parallel_generate_reward_attribution,
)
from sotopia_rl.utils.preprocess import add_score


def main(data_dir: str,
        llm_name: str,
        input_file: str,
        output_file: str,
        attribution_method_name: "str | None",
        attribution_instruction_name: "str | None",
        max_concurrency: int = 1,
        *,
        scored_file: str = "sotopia_pi_episodes_with_scores.jsonl",
    ) -> None:
    """Score episodes, then generate LLM reward attributions for them.

    Runs two stages back to back:

    1. ``add_score(data_dir, input_file, scored_file)`` produces the
       intermediate scored-episode file.
    2. ``parallel_generate_reward_attribution`` reads ``scored_file`` and
       writes episodes with reward attribution to ``output_file``.

    Args:
        data_dir: Directory passed to both stages. Presumably all file names
            below are resolved relative to it -- TODO confirm in
            ``sotopia_rl.utils.preprocess`` / ``attribution_prompting``.
        llm_name: Name of the language model used for attribution.
        input_file: Episode file handed to ``add_score``.
        output_file: File that receives episodes with reward attribution.
        attribution_method_name: Attribution method name; ``None`` when the
            ``--attribution_method_name`` CLI flag is omitted.
        attribution_instruction_name: Attribution instruction name; ``None``
            when the ``--attribution_instruction_name`` CLI flag is omitted.
        max_concurrency: Forwarded unchanged to
            ``parallel_generate_reward_attribution``.
        scored_file: Name of the intermediate file written by stage 1 and read
            by stage 2. Keyword-only; the default is the name this script has
            always used, so existing callers are unaffected.
    """
    # Stage 1: attach scores to the raw episodes.
    add_score(
        data_dir,
        input_file,
        scored_file,
    )
    # Stage 2: read the scored file back (same name as stage 1's output, so
    # the two stages can never disagree) and attribute rewards via the LLM.
    parallel_generate_reward_attribution(
        data_dir,
        llm_name=llm_name,
        input_file=scored_file,
        output_file=output_file,
        attribution_method_name=attribution_method_name,
        attribution_instruction_name=attribution_instruction_name,
        max_concurrency=max_concurrency,
    )


if __name__ == "__main__":
    # Command-line entry point: parse the flags below and run `main`.
    parser = argparse.ArgumentParser(
        description=(
            "Add scores to Sotopia episodes and generate LLM-based "
            "reward attribution for them."
        ),
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Directory containing data files",
    )
    parser.add_argument(
        "--llm_name",
        type=str,
        required=True,
        help="Name of the language model",
    )
    # The two attribution flags are optional with no explicit default, so
    # argparse yields None for them when they are omitted.
    parser.add_argument(
        "--attribution_method_name",
        type=str,
        required=False,
        help="Type of attribution method",
    )
    parser.add_argument(
        "--attribution_instruction_name",
        type=str,
        required=False,
        help="Type of attribution instruction",
    )
    parser.add_argument(
        "--input_file",
        type=str,
        required=True,
        help="Input file containing episodes",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        required=True,
        help="Output file containing episodes with reward attribution",
    )
    parser.add_argument(
        "--max_concurrency",
        type=int,
        required=False,
        default=1,
        help="Maximum number of concurrent episodes",
    )

    args = parser.parse_args()

    # Pass by keyword: several adjacent parameters are plain strings, so a
    # positional call would silently accept a swapped order.
    main(
        data_dir=args.data_dir,
        llm_name=args.llm_name,
        input_file=args.input_file,
        output_file=args.output_file,
        attribution_method_name=args.attribution_method_name,
        attribution_instruction_name=args.attribution_instruction_name,
        max_concurrency=args.max_concurrency,
    )