File size: 9,595 Bytes
0c51b93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import json
import os
from typing import Any, Dict, List

import transformers
from sotopia.database.logs import EpisodeLog

# Kept for reference only; PROMPT_TEMPLATE below already begins with this prefix.
# PROMPT_PREFIX = "Prompt after formatting:\n"
# Token budget for a formatted prompt. reverse_episode_log passes it to
# surpass_max_token_check and, when exceeded, drops the oldest dialogue lines via
# truncate_prompt_to_length. (The trailing "5000" looks like a previously used value.)
MAX_TOKEN = 2048  # 5000

# Agent prompt, filled with str.format(agent=..., history=..., turn_number=...)
# in reverse_episode_log: `agent` is the reconstructed agent's name, `history` the
# Environment-to-agent text accumulated so far, `turn_number` the turn index.
# NOTE(review): the wording (including its grammar) is a runtime string and is
# presumably kept verbatim to match the original generation prompts — confirm
# before editing.
PROMPT_TEMPLATE = """Prompt after formatting:\nImagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
You can find {agent}'s background and goal in the 'Here is the context of the interaction' field.
Note that {agent}'s secret and goal is only visible to you.
You should try your best to achieve {agent}'s goal in a way that align with their character traits.
Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).
{history}.
You are at Turn #{turn_number}."""

# Output-format instructions appended to each prompt when include_format=True.
# The example/schema text appears to be a pre-rendered
# PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str) for the action schema
# (action_type enum + argument string).
# NOTE(review): the trailing "\u001b[0m" is an ANSI "reset attributes" escape
# sequence, presumably captured from colored log output; it is left in place so
# prompts match the originals — confirm before removing.
FORMAT_TEMPLATE = """ Your available action types are
"none action speak non-verbal communication leave".
Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.

Please only generate a JSON string including the action type and the argument.
Your action should follow the given format:
\nAs an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}
the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.
\nHere is the output schema:\n```\n{\"description\": \"An interface for messages.\\nThere is only one required method: to_natural_language\", \"properties\": {\"action_type\": {\"title\": \"Action Type\", \"description\": \"whether to speak at this turn or choose to not do anything\", \"enum\": [\"none\", \"speak\", \"non-verbal communication\", \"action\", \"leave\"], \"type\": \"string\"}, \"argument\": {\"title\": \"Argument\", \"description\": \"the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action\", \"type\": \"string\"}}, \"required\": [\"action_type\", \"argument\"]}\n```\u001b[0m
"""
# Static, space-separated list of the action types (originally " ".join(ActionType)).
# Not referenced in the code visible here.
ACTION_LIST = "none action speak non-verbal communication leave"  # " ".join(ActionType)

# Maps the first five characters of a to_natural_language rendering back to its
# action type: "left " <- "left the conversation", "did n" <- "did nothing",
# "said:" <- 'said: "..."'. Not referenced in the code visible here.
ACTION_REVERSE_MAP = {"left ": "leave", "did n": "none", "said:": "speak"}

# Hugging Face model id whose tokenizer is loaded into TOKENIZER below.
MODEL_CHECKPOINT = "meta-llama/Llama-2-13b-chat-hf"


# Tokenizer used in this file to count prompt tokens (surpass_max_token_check)
# and per-line tokens (truncate_prompt_to_length). It is loaded at import time,
# so importing this module requires MODEL_CHECKPOINT's tokenizer files to be
# available (NOTE(review): presumably via the Hugging Face Hub or a local cache —
# confirm). padding/truncation are passed straight through to from_pretrained.
TOKENIZER = transformers.AutoTokenizer.from_pretrained(
    MODEL_CHECKPOINT,
    padding=False,
    truncation=False,
)


def to_natural_language(self: Any) -> str:
    """Render an agent action as the natural-language line stored in episode logs.

    ``self`` only needs ``action_type`` and ``argument`` attributes (this appears
    to mirror the ``to_natural_language`` method of the message class).
    Unrecognised action types are rendered like "none".
    """
    kind = self.action_type
    if kind == "speak":
        return f'said: "{self.argument}"'
    if kind in ("non-verbal communication", "action"):
        # Both tagged forms share the "[<type>] <argument>" shape.
        return f"[{kind}] {self.argument}"
    if kind == "leave":
        return "left the conversation"
    # "none" and any unknown action type.
    return "did nothing"


# Episode tag(s) chosen for processing. Not referenced in the code visible here;
# presumably used by a driver elsewhere to filter EpisodeLog entries — TODO confirm.
SELECTED_TAG = ["gpt-4_gpt-4_v0.0.1_clean"]


def detect_action(msg: str) -> str:
    """Infer the action type of a natural-language action line.

    Prefixes are tested in order: 'said:' -> "speak", 'left' -> "leave",
    '[non-verbal communication]' -> "non-verbal communication",
    '[action]' -> "action". Anything else is "none".
    """
    prefix_table = (
        ("said:", "speak"),
        ("left", "leave"),
        ("[non-verbal communication]", "non-verbal communication"),
        ("[action]", "action"),
    )
    for prefix, action_type in prefix_table:
        if msg.startswith(prefix):
            return action_type
    return "none"


def generate_result(msg: str) -> str:
    action = detect_action(msg)
    result = {}
    result["action_type"] = action
    result["argument"] = ""
    # know formating argument based on action type
    match action:
        case "speak":
            # NOTE: this assume that the speech is in quotes, not ending without punctuation
            result["argument"] = msg.replace("said: ", "")[1:-1]
        case "action":
            result["argument"] = msg
        case "non-verbal communication":
            result["argument"] = msg

    str_result = json.dumps(result)

    return str_result


def surpass_max_token_check(
    string: str, max_token: int = MAX_TOKEN, tokenizer: transformers.AutoTokenizer = TOKENIZER
) -> int:
    """Return by how many tokens ``string`` exceeds ``max_token`` (0 if it fits).

    The token count is ``len(tokenizer(string)["input_ids"])``.
    """
    n_tokens = len(tokenizer(string)["input_ids"])
    excess = n_tokens - max_token
    return excess if excess > 0 else 0


def truncate_prompt_to_length(
    dia_his: str, surpass_num: int, tokenizer: transformers.AutoTokenizer = TOKENIZER
) -> str:
    """Drop the oldest lines of a dialogue history to shed about ``surpass_num`` tokens.

    ``dia_his`` is split on ``"\\n"``. Whole lines are removed from the front
    until the summed per-line token counts reach ``surpass_num``. The remaining
    lines are re-joined with ``"\\n"``.

    Per-line counts only approximate the joined text's count. Newlines are not
    counted, and the tokenizer may add special tokens on every call.

    If ``surpass_num`` is larger than the whole history's token count, every line
    is dropped and ``""`` is returned. A value of 0 or less returns the history
    unchanged.
    """
    lines = dia_his.split("\n")
    removed_tokens = 0
    start = 0
    # Bound by len(lines) so an overshoot empties the history instead of
    # raising IndexError.
    while removed_tokens < surpass_num and start < len(lines):
        removed_tokens += len(tokenizer(lines[start])["input_ids"])
        start += 1
    return "\n".join(lines[start:])


def reverse_episode_log(
    epilog: EpisodeLog, later_speak: bool=False, include_format: bool=True, max_token: int=MAX_TOKEN
) -> List[Dict[str, Any]]:
    """Rebuild (prompt, result) training instances for one agent from an episode log.

    On every turn ``i`` with ``i % 2 == turn_div``, the agent prompt is rebuilt
    from ``PROMPT_TEMPLATE``. It uses the Environment-to-agent text accumulated
    so far. When the prompt exceeds ``max_token`` tokens, the oldest non-initial
    dialogue lines are dropped. The prompt is then paired with the agent's
    message converted back to JSON by ``generate_result``.

    Args:
        epilog: The episode to reverse.
        later_speak: False reconstructs the agent that sent ``messages[0][-2]``
            (``turn_div = 0``, model ``epilog.models[1]``). True reconstructs the
            sender of ``messages[0][-1]`` (``turn_div = 1``, model
            ``epilog.models[2]``).
        include_format: If True, append ``FORMAT_TEMPLATE`` to every prompt.
        max_token: Token budget passed to ``surpass_max_token_check``.

    Returns:
        One dict per reconstructed turn. Each has the keys ``model``,
        ``env_id``, ``agent_ids``, ``prompt``, ``result`` (a JSON string),
        ``history`` and ``speaker``. ``history`` is the list of
        "Utterance ..." strings accumulated so far, without its first entry.

    Raises:
        Exception: If ``epilog.models`` is empty.
    """
    episode_msg = epilog.messages
    # per episode
    if not epilog.models:
        raise Exception("No models recorded in the episode log")

    # NOTE(review): only emptiness is checked above, but index 1 or 2 is read here;
    # presumably models is [env_model, agent1_model, agent2_model] — confirm.
    agent_model = epilog.models[1] if not later_speak else epilog.models[2]
    promt_template = PROMPT_TEMPLATE

    if len(episode_msg) > 0:
        init_loop = episode_msg[0]
        # Message entries appear to be (sender, receiver, content) tuples; the
        # agent is picked from the sender of one of the first turn's last two entries.
        # figure out who speak later, as we must use the 2nd player's data, else turn 0 have nothing to predict the beginning
        if later_speak:
            speaker = init_loop[-1][0]  # this would be the agent as well
            turn_div = 1
        # figure out who speak the first
        else:
            speaker = init_loop[-2][0]
            turn_div = 0
    # For an empty episode speaker/turn_div stay unbound, which is harmless
    # because the loop below then runs zero times.

    prompt_result_instances = []
    dial_history = ""
    history = []
    for i in range(0, len(episode_msg)):
        msg = episode_msg[i]
        # Skip turns that do not have exactly 4 entries, except the final turn.
        # NOTE(review): presumably 4 entries form a complete two-agent turn — confirm.
        if (len(msg) != 4) and i < (len(episode_msg) - 1):
            continue
        turn_dic = {"model": agent_model, "env_id": epilog.environment, "agent_ids": epilog.agents}
        for tpl in msg:
            # Accumulate everything the Environment sent to the reconstructed agent.
            if tpl[0] == "Environment" and (tpl[1] == speaker):
                if i > 0:
                    dial_history += "\n" + tpl[2]
                else:
                    # for the first context, we don't need \n
                    # `context` is reused below to keep the initial context when truncating.
                    context = tpl[2]
                    dial_history += context

            # The agent's own messages on turns whose parity matches turn_div ...
            if tpl[0] == speaker and i % 2 == turn_div:
                history.append(f"Utterance {(i - 1) // 2} by {tpl[0]} " + tpl[2])

            # ... and other non-Environment senders' messages on the other turns.
            if tpl[0] != "Environment" and tpl[0] != speaker and i % 2 != turn_div:
                history.append(f"Utterance {(i - 1) // 2} by {tpl[0]} " + tpl[2])

            if tpl[0] == speaker:  # if speaker is the agent, use what he said as result
                str_result = generate_result(tpl[2])
        if i % 2 == turn_div:
            # take alternative turns as we always want to predict one agent, not both
            next_turn = i
            prompt = promt_template.format(
                agent=speaker, history=dial_history, turn_number=next_turn
            )
            over_tokens = surpass_max_token_check(prompt, max_token)
            if over_tokens > 0:
                # Keep the initial context and drop the oldest later dialogue lines.
                # NOTE(review): `context` is unbound unless the first turn was
                # processed and contained an Environment message to the agent.
                all_dial = dial_history[len(context) :]
                trun_dial = truncate_prompt_to_length(all_dial, over_tokens)
                prompt = promt_template.format(
                    agent=speaker,
                    history=context + "\n" + trun_dial,
                    turn_number=next_turn,
                )
            if include_format:
                prompt += FORMAT_TEMPLATE
            turn_dic["prompt"] = prompt
            # NOTE(review): str_result is only refreshed when the agent sends a
            # message in this turn; otherwise the previous turn's value is reused
            # (or NameError if the agent has not spoken yet).
            turn_dic["result"] = str_result
            turn_dic["history"] = list(history[1:])
            turn_dic["speaker"] = speaker
            prompt_result_instances.append(turn_dic)

    return prompt_result_instances


def concat_episode_msg(epilog: EpisodeLog) -> str:
    """Concatenate all Environment messages addressed to one agent of the episode.

    The agent is the sender of the second-to-last entry of the first turn.
    Turns that do not have exactly 4 entries are skipped, except the final turn.
    Text from the first turn is appended as-is. Text from later turns is
    prefixed with a newline. An episode with no messages yields "".
    """
    turns = epilog.messages
    if not turns:
        return ""

    agent = turns[0][-2][0]
    last_index = len(turns) - 1
    pieces = []
    for idx, turn in enumerate(turns):
        if len(turn) != 4 and idx < last_index:
            continue
        # The first turn's context is not preceded by a newline.
        separator = "\n" if idx > 0 else ""
        pieces.extend(
            separator + entry[2]
            for entry in turn
            if entry[0] == "Environment" and entry[1] == agent
        )
    return "".join(pieces)


def parse_prompt_to_json(episode: EpisodeLog, dir: str, init_speak: bool, include_format: bool = False) -> None:
    """Reverse ``episode`` and write each prompt/result instance to its own JSON file.

    Files are named ``<episode.pk>-<speaker>-<index>.json`` and placed in
    ``dir``. The directory, including any parents, is created if missing.

    Args:
        episode: Episode to reverse with ``reverse_episode_log``.
        dir: Output directory. The name shadows the ``dir`` builtin but is kept
            for caller compatibility.
        init_speak: Forwarded to ``reverse_episode_log`` as ``later_speak``.
        include_format: Forwarded to ``reverse_episode_log`` as ``include_format``.
    """
    prompt_result_instances = reverse_episode_log(
        episode, later_speak=init_speak, include_format=include_format
    )

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(dir, exist_ok=True)

    for index, instance in enumerate(prompt_result_instances):
        filename = f"{episode.pk}-{instance['speaker']}-{index}.json"
        with open(os.path.join(dir, filename), "w", encoding="utf-8") as f:
            f.write(json.dumps(instance, indent=4))


def run_reverse_by_pk_agent(episode_pk: str, agent_side: bool, save_dir: str) -> None:
    """
    Entry function if you want to reverse engineer given a pk, not a episode.

    Looks up the EpisodeLog with primary key ``episode_pk``. Its reconstructed
    prompt/result instances are written to ``save_dir`` as JSON files, without
    the format instructions.

    Args:
        episode_pk: Primary key of the episode to load.
        agent_side: Forwarded as ``init_speak`` (i.e. ``later_speak``) to choose
            which agent is reconstructed.
        save_dir: Output directory for the JSON files.

    Raises:
        Exception: If no episode with ``episode_pk`` exists.
    """
    matches = EpisodeLog.find(EpisodeLog.pk == episode_pk).all()
    # Check before indexing: indexing an empty result would raise IndexError
    # and make the "not found" error below unreachable.
    if not matches:
        raise Exception(f"Episode {episode_pk} not found")
    episode = matches[0]
    parse_prompt_to_json(episode, save_dir, agent_side, False)