# basketball_code/scripts/data_process/db_free_reverse_engineering.py
# Uploaded via huggingface_hub by youqiwong (commit 0c51b93, verified).
import json
import os
from typing import Any, Dict, List
import transformers
from episode_utils import FakeEpisodeLog, jsonl_to_episodes
# PROMPT_PREFIX = "Prompt after formatting:\n"

# Token budget for a fully formatted prompt; when exceeded, the dialogue
# history is truncated (see surpass_max_token_check / truncate_prompt_to_length).
MAX_TOKEN = 2048 # 5000

# Role-play prompt skeleton: {agent} is the acting agent, {history} is the
# scenario context plus the dialogue so far, {turn_number} the turn to predict.
PROMPT_TEMPLATE = """Imagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
You can find {agent}'s background and goal in the 'Here is the context of the interaction' field.
Note that {agent}'s secret and goal is only visible to you.
You should try your best to achieve {agent}'s goal in a way that align with their character traits.
Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).
{history}.
You are at Turn #{turn_number}."""

# PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)
# JSON-output instructions appended to the prompt when include_format=True.
# NOTE(review): the trailing "\u001b[0m" looks like an ANSI reset code captured
# from a terminal dump -- confirm it is intended to be part of the prompt.
FORMAT_TEMPLATE = """ Your available action types are
"none action speak non-verbal communication leave".
Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
Please only generate a JSON string including the action type and the argument.
Your action should follow the given format:
\nAs an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}
the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.
\nHere is the output schema:\n```\n{\"description\": \"An interface for messages.\\nThere is only one required method: to_natural_language\", \"properties\": {\"action_type\": {\"title\": \"Action Type\", \"description\": \"whether to speak at this turn or choose to not do anything\", \"enum\": [\"none\", \"speak\", \"non-verbal communication\", \"action\", \"leave\"], \"type\": \"string\"}, \"argument\": {\"title\": \"Argument\", \"description\": \"the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action\", \"type\": \"string\"}}, \"required\": [\"action_type\", \"argument\"]}\n```\u001b[0m
"""

# static
# Space-joined list of the valid action types.
ACTION_LIST = "none action speak non-verbal communication leave" # " ".join(ActionType)
# Maps the first characters of a rendered message back to its action type
# (e.g. "left the conversation" -> "leave"). NOTE(review): not referenced in
# the visible code -- presumably kept for external callers.
ACTION_REVERSE_MAP = {"left ": "leave", "did n": "none", "said:": "speak"}
MODEL_CHECKPOINT = "meta-llama/Llama-2-13b-chat-hf"
# Tokenizer used solely for counting tokens against MAX_TOKEN; padding and
# truncation are disabled so the count reflects the raw prompt.
TOKENIZER = transformers.AutoTokenizer.from_pretrained(
    MODEL_CHECKPOINT,
    padding=False,
    truncation=False,
)
def to_natural_language(self: Any) -> str:
    """Render an action object as the plain-English string used in dialogue logs."""
    kind = self.action_type
    if kind == "none":
        return "did nothing"
    if kind == "speak":
        return f'said: "{self.argument}"'
    if kind in ("non-verbal communication", "action"):
        # Both kinds render as "[<type>] <argument>".
        return f"[{kind}] {self.argument}"
    if kind == "leave":
        return "left the conversation"
    # Unknown action types fall back to the "none" rendering.
    return "did nothing"
# Episode tags selected for processing. NOTE(review): not referenced elsewhere
# in this file's visible code -- presumably consumed by a caller.
SELECTED_TAG = ["gpt-4_gpt-4_v0.0.1_clean"]
def detect_action(msg: str) -> str:
# first detect what action type is, default at none
if msg.startswith("said:"):
action = "speak"
elif msg.startswith("left"):
action = "leave"
elif msg.startswith("[non-verbal communication]"):
action = "non-verbal communication"
elif msg.startswith("[action]"):
action = "action"
else:
action = "none"
return action
def generate_result(msg: str) -> str:
action = detect_action(msg)
result = {}
result["action_type"] = action
result["argument"] = ""
# know formating argument based on action type
match action:
case "speak":
# NOTE: this assume that the speech is in quotes, not ending without punctuation
result["argument"] = msg.replace("said: ", "")[1:-1]
case "action":
result["argument"] = msg
case "non-verbal communication":
result["argument"] = msg
str_result = json.dumps(result)
return str_result
def surpass_max_token_check(string: str, max_token: int=MAX_TOKEN, tokenizer: transformers.AutoTokenizer=TOKENIZER) -> int:
prompt_tokens = len(tokenizer(string)["input_ids"])
return max(prompt_tokens - max_token, 0)
def truncate_prompt_to_length(
    dia_his: str,
    surpass_num: int,
    tokenizer: "transformers.AutoTokenizer | None" = None,
) -> str:
    """Drop leading dialogue lines until at least `surpass_num` tokens are removed.

    Args:
        dia_his: dialogue history, one utterance per line.
        surpass_num: minimum number of tokens to remove from the front.
        tokenizer: callable returning a dict with an "input_ids" list;
            defaults to the module-level TOKENIZER.

    Returns:
        The remaining newline-joined tail of the history; "" when the whole
        history had to be removed.
    """
    if tokenizer is None:
        tokenizer = TOKENIZER
    dia_sen = dia_his.split("\n")
    remove_len = 0
    i = 0
    # Bound i by the number of lines: previously a surpass_num larger than the
    # total token count raised IndexError.
    while remove_len < surpass_num and i < len(dia_sen):
        remove_len += len(tokenizer(dia_sen[i])["input_ids"])
        i += 1
    return "\n".join(dia_sen[i:])
def reverse_episode_log(
    epilog: FakeEpisodeLog, later_speak: bool=False, include_format: bool=True, max_token: int=MAX_TOKEN
) -> List[Dict[str, Any]]:
    """Reverse-engineer one episode log into (prompt, result) training instances.

    One instance is emitted per turn taken by the chosen agent: "prompt" is
    PROMPT_TEMPLATE filled with the dialogue so far (truncated to max_token),
    and "result" is the agent's action re-serialized as a JSON string via
    generate_result.

    Args:
        epilog: episode log exposing .messages, .models, .environment, .agents.
        later_speak: if True, build data for the agent who speaks second.
        include_format: append FORMAT_TEMPLATE (JSON schema instructions).
        max_token: prompt token budget; history is truncated when exceeded.

    Raises:
        Exception: when the episode records no models.
    """
    episode_msg = epilog.messages
    # per episode
    if not epilog.models:
        raise Exception("No models recorded in the episode log")
    # models[1]/models[2] are presumably the two agents' models -- TODO
    # confirm the indexing convention against episode_utils.
    agent_model = epilog.models[1] if not later_speak else epilog.models[2]
    promt_template = PROMPT_TEMPLATE
    if len(episode_msg) > 0:
        init_loop = episode_msg[0]
        # figure out who speak later, as we must use the 2nd player's data,
        # else turn 0 have nothing to predict the beginning
        if later_speak:
            speaker = init_loop[-1][0]  # this would be the agent as well
            turn_div = 1
        # figure out who speak the first
        else:
            speaker = init_loop[-2][0]
            turn_div = 0
    prompt_result_instances = []
    dial_history = ""
    history = []
    for i in range(0, len(episode_msg)):
        msg = episode_msg[i]
        # A complete turn carries 4 (sender, receiver, text) tuples; skip
        # malformed turns, except the final one which may be shorter.
        if (len(msg) != 4) and i < (len(episode_msg) - 1):
            continue
        turn_dic = {"model": agent_model, "env_id": epilog.environment, "agent_ids": epilog.agents}
        for tpl in msg:
            # Environment->speaker messages form the context / dialogue history.
            if tpl[0] == "Environment" and (tpl[1] == speaker):
                if i > 0:
                    dial_history += "\n" + tpl[2]
                else:
                    # for the first context, we don't need \n
                    # NOTE(review): `context` stays unbound if turn 0 has no
                    # Environment->speaker tuple, yet it is read below in the
                    # truncation branch -- confirm turn 0 always provides it.
                    context = tpl[2]
                    dial_history += context
            if tpl[0] == speaker and i % 2 == turn_div:
                history.append(f"Utterance {(i - 1) // 2} by {tpl[0]} " + tpl[2])
            if tpl[0] != "Environment" and tpl[0] != speaker and i % 2 != turn_div:
                history.append(f"Utterance {(i - 1) // 2} by {tpl[0]} " + tpl[2])
            if tpl[0] == speaker:  # if speaker is the agent, use what he said as result
                # NOTE(review): str_result is only bound when the speaker
                # appears in this turn -- presumably guaranteed on turns where
                # i % 2 == turn_div.
                str_result = generate_result(tpl[2])
        # check if this is the end
        if i % 2 == turn_div:
            # take alternative turns as we always want to predict one agent, not both
            next_turn = i
            prompt = promt_template.format(
                agent=speaker, history=dial_history, turn_number=next_turn
            )
            over_tokens = surpass_max_token_check(prompt, max_token)
            if over_tokens > 0:
                # Truncate only the dialogue part, preserving the initial
                # scenario context at the front of the prompt.
                all_dial = dial_history[len(context) :]
                trun_dial = truncate_prompt_to_length(all_dial, over_tokens)
                prompt = promt_template.format(
                    agent=speaker,
                    history=context + "\n" + trun_dial,
                    turn_number=next_turn,
                )
            if include_format:
                prompt += FORMAT_TEMPLATE
            turn_dic["prompt"] = prompt
            turn_dic["result"] = str_result
            turn_dic["history"] = list(history[1:])
            turn_dic["speaker"] = speaker
            prompt_result_instances.append(turn_dic)
    return prompt_result_instances
def concat_episode_msg(epilog: "FakeEpisodeLog") -> str:
    """Concatenate the Environment messages addressed to the first speaker.

    Mirrors the history-building part of reverse_episode_log: the scenario
    context (turn 0) followed by each subsequent Environment->speaker message,
    newline-separated.

    Args:
        epilog: episode log exposing .messages (a list of turns, each a list
            of (sender, receiver, text) tuples).

    Returns:
        The joined history; "" for an episode with no messages (the original
        raised NameError on an unbound local in that case).
    """
    episode_msg = epilog.messages
    dial_history = ""
    if not episode_msg:
        return dial_history
    # The first speaker is the sender of the second-to-last tuple in turn 0.
    speaker = episode_msg[0][-2][0]
    for i, msg in enumerate(episode_msg):
        # A complete turn carries 4 tuples; skip malformed turns, except the
        # final one which may legitimately be shorter.
        if len(msg) != 4 and i < len(episode_msg) - 1:
            continue
        for tpl in msg:
            if tpl[0] == "Environment" and tpl[1] == speaker:
                # The first context chunk carries no leading newline.
                dial_history += ("\n" + tpl[2]) if i > 0 else tpl[2]
    return dial_history
def parse_prompt_to_json(episode: "FakeEpisodeLog", dir: str, init_speak: bool, include_format: bool=False) -> None:
    """Reverse-engineer one episode and dump each instance as a JSON file.

    Files are written under `dir` as "{episode.pk}-{speaker}-{index}.json".

    Args:
        episode: the episode log to process.
        dir: output directory, created if missing. (Shadows the builtin
            `dir`; the name is kept for caller compatibility.)
        init_speak: forwarded to reverse_episode_log as `later_speak`.
        include_format: forwarded to reverse_episode_log.
    """
    prompt_result_instances = reverse_episode_log(episode, init_speak, include_format)
    # exist_ok avoids the race between an exists() check and the creation.
    os.makedirs(dir, exist_ok=True)
    for i, instance in enumerate(prompt_result_instances):
        path = os.path.join(dir, "{}-{}-{}.json".format(episode.pk, instance['speaker'], i))
        with open(path, "w") as f:
            f.write(json.dumps(instance, indent=4))
def run_reverse_by_pk_agent(episode_pk: str, agent_side: bool, save_dir: str, episodes_file: str = "../../data/sotopia_pi_episodes.jsonl") -> None:
    """
    Entry function if you want to reverse engineer given a pk, not a episode
    """
    # Load every episode from the JSONL dump and index them by primary key.
    episodes = jsonl_to_episodes(episodes_file)
    episodes_by_pk = {ep.pk: ep for ep in episodes}
    episode = episodes_by_pk.get(episode_pk)
    if not episode:
        raise Exception(f"Episode {episode_pk} not found")
    # Delegate the actual prompt extraction and file writing.
    parse_prompt_to_json(episode, save_dir, agent_side, False)