File size: 2,323 Bytes
0c51b93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import json
import os
from typing import Any, Dict


def select_qualifying_episodes(
    episodes: list[Dict[str, Any]]
) -> list[Dict[str, Any]]:
    """Return the episodes that pass every quality filter, in input order.

    An episode qualifies when all of the following hold:

    * ``episode["rewards"][0]["goal"]`` and ``episode["rewards"][1]["goal"]``
      are both "extreme": ``>= 8`` or ``<= 2``.
    * ``episode["social_interactions"]`` contains the substring ``"said:"``
      more than twice.
    * ``episode["experiment_model_name_pairs"][1]`` and ``[2]`` are each
      ``"gpt-4"`` or ``"gpt-3.5-turbo"``. Presumably these are the two agent
      models, with index 0 being some other role; verify against the data
      producer.

    Args:
        episodes: Parsed episode records.

    Returns:
        A new list holding the qualifying episode dicts (the same objects,
        not copies).
    """
    allowed_models = ("gpt-4", "gpt-3.5-turbo")

    def _is_extreme(goal: Any) -> bool:
        # Keep only clear-cut outcomes: near the top or the bottom of the scale.
        return goal >= 8 or goal <= 2

    qualifying_episodes = []
    for episode in episodes:
        if (
            _is_extreme(episode["rewards"][0]["goal"])
            and _is_extreme(episode["rewards"][1]["goal"])
            and episode["social_interactions"].count("said:") > 2
            and episode["experiment_model_name_pairs"][1] in allowed_models
            and episode["experiment_model_name_pairs"][2] in allowed_models
        ):
            qualifying_episodes.append(episode)
    return qualifying_episodes


def create_non_repeating_sample_episodes(
    qualifying_episodes: list[Dict[str, Any]], num_episodes: int
) -> list[Dict[str, Any]]:
    """Pick up to ``num_episodes`` episodes, keeping at most one per codename.

    Episodes are scanned in input order. The first episode seen for each
    ``episode["codename"]`` is kept and later ones are skipped.

    Args:
        qualifying_episodes: Candidate episodes, each with a ``"codename"`` key.
        num_episodes: Maximum number of episodes to return. ``-1`` returns
            ``qualifying_episodes`` itself unchanged, with no de-duplication.
            Any other negative value de-duplicates with no limit. ``0``
            returns an empty list.

    Returns:
        The selected episodes, in input order.
    """
    if num_episodes == -1:
        return qualifying_episodes

    example_episodes = []
    visited_codename = set()
    for episode in qualifying_episodes:
        # Check the quota before appending so that num_episodes == 0 yields [].
        if num_episodes >= 0 and len(example_episodes) >= num_episodes:
            break
        if episode["codename"] in visited_codename:
            continue

        example_episodes.append(episode)
        visited_codename.add(episode["codename"])
    return example_episodes


def sample_episodes(data_dir: str, num_episodes: int = 30) -> None:
    """Filter and sample episodes from a JSONL file, writing them to another.

    Reads ``<data_dir>/sotopia_episodes_v1.jsonl``, keeps the episodes
    accepted by ``select_qualifying_episodes``, and samples at most
    ``num_episodes`` of them with unique codenames via
    ``create_non_repeating_sample_episodes`` (``-1`` keeps all qualifying
    episodes). The result is written to ``<data_dir>/example_episodes.jsonl``,
    one JSON object per line, overwriting any existing file.

    Args:
        data_dir: Directory holding the input file; the output is written here.
        num_episodes: Maximum number of episodes to write (default 30).
    """
    input_path = os.path.join(data_dir, "sotopia_episodes_v1.jsonl")
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(input_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line), which json.loads rejects.
        episodes = [json.loads(line) for line in f if line.strip()]

    qualifying_episodes = select_qualifying_episodes(episodes)
    example_episodes = create_non_repeating_sample_episodes(
        qualifying_episodes, num_episodes=num_episodes
    )

    output_path = os.path.join(data_dir, "example_episodes.jsonl")
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(episode) + "\n" for episode in example_episodes)