|
|
import json |
|
|
import os |
|
|
from typing import Any, Dict, List |
|
|
|
|
|
import transformers |
|
|
from episode_utils import FakeEpisodeLog, jsonl_to_episodes |
|
|
|
|
|
|
|
|
# Hard cap on prompt length in tokens; prompts that exceed it get their
# dialogue history truncated (see truncate_prompt_to_length).
MAX_TOKEN = 2048

# Role-play prompt skeleton. Placeholders: {agent} (speaker name),
# {history} (context + dialogue so far), {turn_number}.
PROMPT_TEMPLATE = """Imagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
You can find {agent}'s background and goal in the 'Here is the context of the interaction' field.
Note that {agent}'s secret and goal is only visible to you.
You should try your best to achieve {agent}'s goal in a way that align with their character traits.
Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).
{history}.
You are at Turn #{turn_number}."""


# Instruction block appended to prompts when include_format=True: lists the
# allowed action types and spells out the JSON output schema the model must
# follow. NOTE(review): the trailing \u001b[0m looks like a leftover ANSI
# reset code from a captured prompt -- confirm it is intentional.
FORMAT_TEMPLATE = """ Your available action types are
"none action speak non-verbal communication leave".
Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.

Please only generate a JSON string including the action type and the argument.
Your action should follow the given format:
\nAs an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}
the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.
\nHere is the output schema:\n```\n{\"description\": \"An interface for messages.\\nThere is only one required method: to_natural_language\", \"properties\": {\"action_type\": {\"title\": \"Action Type\", \"description\": \"whether to speak at this turn or choose to not do anything\", \"enum\": [\"none\", \"speak\", \"non-verbal communication\", \"action\", \"leave\"], \"type\": \"string\"}, \"argument\": {\"title\": \"Argument\", \"description\": \"the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action\", \"type\": \"string\"}}, \"required\": [\"action_type\", \"argument\"]}\n```\u001b[0m
"""

# Space-separated list of the action types an agent may emit.
ACTION_LIST = "none action speak non-verbal communication leave"

# Maps the leading characters of a rendered action message ("left ...",
# "did nothing", "said: ...") back to its action type -- presumably the
# inverse of to_natural_language; unused in this chunk, verify callers.
ACTION_REVERSE_MAP = {"left ": "leave", "did n": "none", "said:": "speak"}

# Checkpoint whose tokenizer defines the token counts used for truncation.
MODEL_CHECKPOINT = "meta-llama/Llama-2-13b-chat-hf"

# NOTE(review): executed at import time -- reads a data file from a relative
# path, so importing this module from another working directory will fail.
EPISODES = jsonl_to_episodes("../../data/sotopia_pi_episodes.jsonl")

# Index episodes by primary key for O(1) lookup (see run_reverse_by_pk_agent).
EPISODE_DICT = {ep.pk: ep for ep in EPISODES}

# Tokenizer is only used to count tokens, hence padding/truncation disabled.
# NOTE(review): may download model files on first import -- consider lazy init.
TOKENIZER = transformers.AutoTokenizer.from_pretrained(
    MODEL_CHECKPOINT,
    padding=False,
    truncation=False,
)
|
|
|
|
|
|
|
|
def to_natural_language(self: Any) -> str:
    """Render an action object as the natural-language string logged in episodes.

    Expects *self* to carry ``action_type`` and ``argument`` attributes.
    Unrecognized action types fall back to "did nothing".
    """
    kind = self.action_type
    if kind == "speak":
        return f'said: "{self.argument}"'
    if kind in ("non-verbal communication", "action"):
        # Both are rendered with the action type as a bracketed tag.
        return f"[{kind}] {self.argument}"
    if kind == "leave":
        return "left the conversation"
    # "none" and any unknown type.
    return "did nothing"
|
|
|
|
|
|
|
|
# Episode tags to process -- presumably filters EPISODES by tag; unused in
# this chunk, verify against callers.
SELECTED_TAG = ["gpt-4_gpt-4_v0.0.1_clean"]
|
|
|
|
|
|
|
|
def detect_action(msg: str) -> str:
    """Classify a rendered action message back into its action type.

    Inverse of the rendering done by to_natural_language: the leading prefix
    of *msg* determines the type; anything unrecognized maps to "none".
    """
    prefix_to_action = (
        ("said:", "speak"),
        ("left", "leave"),
        ("[non-verbal communication]", "non-verbal communication"),
        ("[action]", "action"),
    )
    for prefix, action_type in prefix_to_action:
        if msg.startswith(prefix):
            return action_type
    return "none"
|
|
|
|
|
|
|
|
def generate_result(msg: str) -> str:
    """Serialize a rendered action message into the target JSON action string.

    Returns a JSON object string with "action_type" and "argument" keys,
    matching the schema described in FORMAT_TEMPLATE. The action type is
    detected from the message prefix (helper logic inlined here); "none"
    and "leave" carry an empty argument.
    """
    if msg.startswith("said:"):
        action_type = "speak"
        # Strip the 'said: ' marker and the surrounding quote characters.
        argument = msg.replace("said: ", "")[1:-1]
    elif msg.startswith("left"):
        action_type = "leave"
        argument = ""
    elif msg.startswith("[non-verbal communication]"):
        action_type = "non-verbal communication"
        argument = msg
    elif msg.startswith("[action]"):
        action_type = "action"
        argument = msg
    else:
        action_type = "none"
        argument = ""

    return json.dumps({"action_type": action_type, "argument": argument})
|
|
|
|
|
|
|
|
def surpass_max_token_check(string: str, max_token: int=MAX_TOKEN, tokenizer: transformers.AutoTokenizer=TOKENIZER) -> int:
    """Return how many tokens *string* exceeds *max_token* by (0 if it fits)."""
    token_count = len(tokenizer(string)["input_ids"])
    overflow = token_count - max_token
    return overflow if overflow > 0 else 0
|
|
|
|
|
|
|
|
def truncate_prompt_to_length(dia_his: str, surpass_num: int, tokenizer: transformers.AutoTokenizer=TOKENIZER) -> str:
    """Drop whole lines from the front of *dia_his* until at least
    *surpass_num* tokens have been removed, and return the remainder.

    NOTE(review): assumes the history always holds enough lines to cover the
    overflow -- if not, lines[cut] raises IndexError; confirm with callers.
    """
    lines = dia_his.split("\n")
    removed_tokens = 0
    cut = 0
    while removed_tokens < surpass_num:
        removed_tokens += len(tokenizer(lines[cut])["input_ids"])
        cut += 1
    return "\n".join(lines[cut:])
|
|
|
|
|
|
|
|
def reverse_episode_log(
    epilog: FakeEpisodeLog, later_speak: bool=False, include_format: bool=True, max_token: int=MAX_TOKEN
) -> List[Dict[str, Any]]:
    """Reverse-engineer an episode log into per-turn prompt/result instances.

    For every turn on which the chosen agent acted, reconstructs the prompt
    that agent would have seen (scenario context + dialogue history + turn
    number) and pairs it with the agent's actual action serialized as a JSON
    string (via generate_result).

    Args:
        epilog: episode log exposing .messages, .models, .environment,
            .agents and .pk.
        later_speak: if True target the second speaker (models[2]); otherwise
            the first speaker (models[1]).
        include_format: append FORMAT_TEMPLATE (output-schema instructions)
            to each prompt.
        max_token: token budget; oversized prompts get their oldest dialogue
            lines truncated while the scenario context is kept.

    Returns:
        List of dicts with keys: model, env_id, agent_ids, prompt, result,
        history, speaker, episode_id.

    Raises:
        Exception: if the episode log records no models.
    """
    episode_msg = epilog.messages

    if not epilog.models:
        raise Exception("No models recorded in the episode log")

    # models[1]/models[2] are the two agent models -- presumably models[0]
    # belongs to the environment; TODO confirm against episode_utils.
    agent_model = epilog.models[1] if not later_speak else epilog.models[2]
    promt_template = PROMPT_TEMPLATE

    if len(episode_msg) > 0:
        init_loop = episode_msg[0]

        if later_speak:
            # Second agent: named by the last tuple of the opening round and
            # acting on odd turn indices.
            speaker = init_loop[-1][0]
            turn_div = 1
        else:
            # First agent: second-to-last tuple, acting on even turn indices.
            speaker = init_loop[-2][0]
            turn_div = 0

    # NOTE(review): with an empty episode_msg, `speaker`/`turn_div` stay
    # unbound -- harmless today because the loop body below never runs then.
    prompt_result_instances = []
    dial_history = ""
    history = []
    for i in range(0, len(episode_msg)):
        msg = episode_msg[i]
        # Skip malformed rounds (a complete round has 4 tuples), except the
        # final round which may legitimately be shorter.
        if (len(msg) != 4) and i < (len(episode_msg) - 1):
            continue
        turn_dic = {"model": agent_model, "env_id": epilog.environment, "agent_ids": epilog.agents}
        for tpl in msg:
            # tpl looks like (sender, receiver, text) -- TODO confirm shape.
            if tpl[0] == "Environment" and (tpl[1] == speaker):
                if i > 0:
                    dial_history += "\n" + tpl[2]
                else:
                    # The first environment message is the scenario context;
                    # kept separately so truncation can preserve it.
                    context = tpl[2]
                    dial_history += context

            # Utterance by the target speaker on one of their own turns.
            if tpl[0] == speaker and i % 2 == turn_div:
                history.append(f"Utterance {(i - 1) // 2} by {tpl[0]} " + tpl[2])

            # Utterance by the other agent on the opposite turns.
            if tpl[0] != "Environment" and tpl[0] != speaker and i % 2 != turn_div:
                history.append(f"Utterance {(i - 1) // 2} by {tpl[0]} " + tpl[2])

            if tpl[0] == speaker:
                # Serialize what the speaker actually did; becomes the target
                # output of the training instance.
                str_result = generate_result(tpl[2])

                if i % 2 == turn_div:

                    next_turn = i
                    prompt = promt_template.format(
                        agent=speaker, history=dial_history, turn_number=next_turn
                    )
                    over_tokens = surpass_max_token_check(prompt, max_token)
                    if over_tokens > 0:
                        # Too long: drop the oldest dialogue lines but keep
                        # the scenario context at the front.
                        # NOTE(review): `context` is only bound by an i == 0
                        # environment message addressed to the speaker --
                        # confirm every episode starts with one, otherwise
                        # this raises UnboundLocalError.
                        all_dial = dial_history[len(context) :]
                        trun_dial = truncate_prompt_to_length(all_dial, over_tokens)
                        prompt = promt_template.format(
                            agent=speaker,
                            history=context + "\n" + trun_dial,
                            turn_number=next_turn,
                        )
                    if include_format:
                        prompt += FORMAT_TEMPLATE
                    turn_dic["prompt"] = prompt
                    turn_dic["result"] = str_result
                    # Copy without the first entry -- presumably it duplicates
                    # material already in the prompt; verify with consumers.
                    turn_dic["history"] = list(history[1:])
                    turn_dic["speaker"] = speaker
                    turn_dic["episode_id"] = epilog.pk
                    prompt_result_instances.append(turn_dic)

    return prompt_result_instances
|
|
|
|
|
|
|
|
def concat_episode_msg(epilog: FakeEpisodeLog) -> str:
    """Concatenate the environment messages addressed to the first speaker
    into a single newline-separated dialogue string.

    Mirrors the history-building part of reverse_episode_log: malformed
    rounds (not 4 tuples) are skipped except the final one, and only tuples
    sent by "Environment" to the first speaker are kept.

    Args:
        epilog: episode log exposing .messages.

    Returns:
        The accumulated dialogue history, or "" for an episode with no
        messages (previously this raised NameError because dial_history was
        only bound inside the non-empty branch).
    """
    episode_msg = epilog.messages
    # Fix: bind up front so an empty episode returns "" instead of crashing.
    dial_history = ""

    if len(episode_msg) > 0:
        init_loop = episode_msg[0]
        # First speaker is named by the second-to-last tuple of round 0.
        speaker = init_loop[-2][0]

        for i in range(0, len(episode_msg)):
            msg = episode_msg[i]
            # Skip incomplete rounds, except the final one.
            if (len(msg) != 4) and i < (len(episode_msg) - 1):
                continue
            for tpl in msg:
                if tpl[0] == "Environment" and (tpl[1] == speaker):
                    if i > 0:
                        dial_history += "\n" + tpl[2]
                    else:
                        # Opening message (scenario context) gets no newline.
                        dial_history += tpl[2]

    return dial_history
|
|
|
|
|
|
|
|
def parse_prompt_to_json(episode: FakeEpisodeLog, dir: str, init_speak: bool, include_format: bool=False) -> None:
    """Reverse-engineer one episode and dump each prompt/result instance to a
    JSON file named "<episode_pk>-<speaker>-<index>.json" under *dir*.

    Args:
        episode: the episode log to reverse-engineer.
        dir: output directory, created if missing. (Parameter name shadows
            the builtin `dir`; kept for interface compatibility.)
        init_speak: forwarded as reverse_episode_log's later_speak flag.
        include_format: append the JSON-schema format block to each prompt.
    """
    prompt_result_instances = reverse_episode_log(episode, init_speak, include_format)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(dir, exist_ok=True)

    for i, instance in enumerate(prompt_result_instances):
        out_path = os.path.join(dir, "{}-{}-{}.json".format(episode.pk, instance['speaker'], i))
        with open(out_path, "w") as f:
            f.write(json.dumps(instance, indent=4))
|
|
|
|
|
|
|
|
def run_reverse_by_pk_agent(episode_pk: str, agent_side: bool, save_dir: str) -> None:
    """Reverse-engineer a single episode identified by primary key.

    Looks the episode up in EPISODE_DICT and dumps its prompt/result
    instances to *save_dir* (format block always included).

    Raises:
        Exception: if no episode with *episode_pk* exists.
    """
    # Guard clause: fail loudly on an unknown (or missing) pk.
    if not (episode := EPISODE_DICT.get(episode_pk, None)):
        raise Exception(f"Episode {episode_pk} not found")
    parse_prompt_to_json(episode, save_dir, agent_side, True)
|
|
|