# NOTE: this file was exported from a Hugging Face repository
# (uploaded via huggingface_hub, revision d1c897a); the viewer header
# lines were removed because they are not valid Python.
import base64, copy
import os
import time
import re
from openai import OpenAI
from eval_agent.prompt import prompt_naive, prompt_with_icl
from .prompts import ALFWORLD_GOAL_SYSTEM, ALFWORLD_LemonTea_SYSTEM
from .hf_utils import history_to_sft_sample
from abc import ABC, abstractmethod
def remove_number(goal):
    """Strip standalone integers from *goal* and normalize whitespace.

    Only whole-word digit runs are removed (e.g. ``"apple 1"`` becomes
    ``"apple"``); digits embedded inside words (``"a1b"``) are kept.
    """
    without_numbers = re.sub(r'\b\d+\b', '', goal)
    return " ".join(without_numbers.split())
class BaseJuicer(ABC):
    """Base class for relabeling strategies using an OpenAI client.

    Owns client construction and retry logic around chat-completion
    calls; subclasses implement :meth:`relabel_experience`.
    """

    # Number of attempts before a chat-completion request is abandoned.
    _MAX_ATTEMPTS = 3

    def __init__(self, llama: str = "llama-3-1-70b", api: str = "internal") -> None:
        """Build the OpenAI-compatible client.

        Args:
            llama: model name passed to the chat-completions endpoint.
            api: ``"internal"`` (company LLM proxy) or ``"openrouter"``.

        Raises:
            ValueError: if ``api`` is not a known option.
            KeyError: if ``api == "openrouter"`` and ``OPENROUTER_API_KEY``
                is not set in the environment.
        """
        if api == "internal":
            endpoint = "http://pluto-prod-hawang-llm-proxy-9qtfav-0:4000"
            # SECURITY: a secret key was hard-coded here. Prefer the
            # environment variable; the literal remains only as a
            # backward-compatible fallback — rotate the exposed key and
            # delete the fallback as soon as possible.
            key = os.environ.get("INTERNAL_LLM_PROXY_KEY", "sk-QObXQcx0GTDciGVNhkgTLw")
            api_key = "Bearer " + key
        elif api == "openrouter":
            endpoint = "https://openrouter.ai/api/v1"
            api_key = os.environ["OPENROUTER_API_KEY"]
        else:
            raise ValueError(f"Unknown API option: {api}")
        self.client = OpenAI(api_key=api_key, base_url=endpoint)
        self.llama = llama

    def _chat_completion(self, messages):
        """Invoke the chat endpoint with retry logic.

        Retries transient failures up to ``_MAX_ATTEMPTS`` times with a
        one-second pause between attempts; re-raises the last exception
        after the final attempt.

        Args:
            messages: chat-format message list for the completion call.

        Returns:
            The content string of the first choice in the response.
        """
        attempt = 0
        while True:
            try:
                resp = self.client.chat.completions.create(
                    model=self.llama, messages=messages
                )
                return resp.choices[0].message.content
            except Exception as exc:  # pragma: no cover - network failures
                attempt += 1
                if attempt >= self._MAX_ATTEMPTS:
                    print(f"Failed to send request: {exc}")
                    raise
                time.sleep(1)

    @abstractmethod
    def relabel_experience(self, state_history, obs, llm_out):
        """Return relabeled trajectory based on ``obs`` and ``llm_out``."""
        raise NotImplementedError
class Juicer(BaseJuicer):
    """Relabel a trajectory with the single goal the LLM says was achieved."""

    def __init__(self, task_instruction, llama="llama-3-1-70b", api="internal"):
        """Store the task instruction and build the client via the base class."""
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction

    def relabel_experience(self, state_history, obs, llm_out):
        """Infer the achieved goal from the trajectory and rewrite the prompt.

        Args:
            state_history: chat-format trajectory; index 0 holds the task
                prompt whose goal text gets replaced.
            obs: per-step observation strings ("Observation: ...").
            llm_out: per-step model outputs containing "Action: ...".

        Returns:
            dict with ``has_hs`` (True when a concrete goal was extracted)
            and ``hs`` (the relabeled trajectory on success, otherwise the
            raw LLM reply for inspection).
        """
        act_obs_traj = ''
        for i, (x, y) in enumerate(zip(obs, llm_out)):
            # Skip steps whose action the environment rejected as unparsable.
            if x.startswith("Observation: Error Input."):
                continue
            action = y.split("Action: ")[-1]
            # Use a distinct local name: the original rebound the ``obs``
            # parameter inside its own loop, which is latent-bug territory.
            observation = x.split("Observation: ")[-1]
            act_obs_traj += (
                f'Step {i+1}: Action="{action}"; Observation="{observation}".\n'
            )
        relabel_inp = [
            {"role": "system", "content": ALFWORLD_GOAL_SYSTEM},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp)
        # Accept only a concrete goal; "Final goal: none" means the LLM
        # judged that nothing was achieved.
        if "Final goal: " in chat_response and "final goal: none" not in chat_response.lower():
            new_goal = remove_number(chat_response.split("Final goal: ")[-1])
        else:
            return {"has_hs": False, "hs": chat_response}
        new_goal, _ = prompt_naive(self.task_instruction, new_goal)
        new_traj = copy.deepcopy(state_history)
        new_traj[0]['content'] = new_goal
        return {"has_hs": True, "hs": new_traj}
class Lemonade(BaseJuicer):
    """Relabel each trajectory with one single goal, masking irrelevant actions.

    Attributes:
        task_instruction: task prompt template passed to ``prompt_naive``.
        first_only: when True, only the first achieved goal is considered.
            NOTE(review): the flag is stored but not yet consulted anywhere
            in this class — presumably consumed by ``distill`` once it is
            implemented; confirm before relying on it.
    """

    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        """Store configuration and build the client via the base class."""
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only

    def distill(self, new_trajectory):
        """Re-examine the trajectory with the relabeled goal.

        TODO: unimplemented stub — currently returns None.
        """

    def relabel_experience(self, state_history, obs, llm_out):
        """Infer the achieved goal from the trajectory and rewrite the prompt.

        Args:
            state_history: chat-format trajectory; index 0 holds the task
                prompt whose goal text gets replaced.
            obs: per-step observation strings ("Observation: ...").
            llm_out: per-step model outputs containing "Action: ...".

        Returns:
            dict with ``has_hs`` (True when a concrete goal was extracted)
            and ``hs`` (the relabeled trajectory on success, otherwise the
            raw LLM reply for inspection).
        """
        act_obs_traj = ''
        for i, (x, y) in enumerate(zip(obs, llm_out)):
            # Skip steps whose action the environment rejected as unparsable.
            if x.startswith("Observation: Error Input."):
                continue
            action = y.split("Action: ")[-1]
            # Distinct local name — the original rebound the ``obs``
            # parameter inside its own loop.
            observation = x.split("Observation: ")[-1]
            act_obs_traj += (
                f'Step {i+1}: Action="{action}"; Observation="{observation}".\n'
            )
        relabel_inp = [
            {"role": "system", "content": ALFWORLD_GOAL_SYSTEM},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp)
        # Accept only a concrete goal; "Final goal: none" means the LLM
        # judged that nothing was achieved.
        if "Final goal: " in chat_response and "final goal: none" not in chat_response.lower():
            new_goal = remove_number(chat_response.split("Final goal: ")[-1])
        else:
            return {"has_hs": False, "hs": chat_response}
        new_goal, _ = prompt_naive(self.task_instruction, new_goal)
        new_traj = copy.deepcopy(state_history)
        new_traj[0]['content'] = new_goal
        return {"has_hs": True, "hs": new_traj}
class Lemontea(BaseJuicer):
    """Relabel trajectories with multiple goals.

    Attributes:
        task_instruction: task prompt template passed to ``prompt_naive``.
        mask_out: mask out the actions irrelevant to any goal.
        learn_explore: include random exploration in the goal space.
            NOTE(review): both flags are stored but not yet consulted in
            this class — presumably consumed by ``distill`` once it is
            implemented; confirm before relying on them.

    NOTE(review): this class uses ``ALFWORLD_GOAL_SYSTEM`` although the file
    imports ``ALFWORLD_LemonTea_SYSTEM``; looks like the dedicated prompt was
    intended here — confirm with the prompt author before switching.
    """

    def __init__(self, task_instruction, llama="llama-3-1-70b", mask_out=False, learn_explore=False, api="internal"):
        """Store configuration and build the client via the base class."""
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.mask_out = mask_out
        self.learn_explore = learn_explore

    def distill(self, new_trajectory):
        """Re-examine the trajectory with the relabeled goal.

        TODO: unimplemented stub — currently returns None.
        """

    def relabel_experience(self, state_history, obs, llm_out):
        """Infer the achieved goal from the trajectory and rewrite the prompt.

        Args:
            state_history: chat-format trajectory; index 0 holds the task
                prompt whose goal text gets replaced.
            obs: per-step observation strings ("Observation: ...").
            llm_out: per-step model outputs containing "Action: ...".

        Returns:
            dict with ``has_hs`` (True when a concrete goal was extracted)
            and ``hs`` (the relabeled trajectory on success, otherwise the
            raw LLM reply for inspection).
        """
        act_obs_traj = ''
        for i, (x, y) in enumerate(zip(obs, llm_out)):
            # Skip steps whose action the environment rejected as unparsable.
            if x.startswith("Observation: Error Input."):
                continue
            action = y.split("Action: ")[-1]
            # Distinct local name — the original rebound the ``obs``
            # parameter inside its own loop.
            observation = x.split("Observation: ")[-1]
            act_obs_traj += (
                f'Step {i+1}: Action="{action}"; Observation="{observation}".\n'
            )
        relabel_inp = [
            {"role": "system", "content": ALFWORLD_GOAL_SYSTEM},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp)
        # Accept only a concrete goal; "Final goal: none" means the LLM
        # judged that nothing was achieved.
        if "Final goal: " in chat_response and "final goal: none" not in chat_response.lower():
            new_goal = remove_number(chat_response.split("Final goal: ")[-1])
        else:
            return {"has_hs": False, "hs": chat_response}
        new_goal, _ = prompt_naive(self.task_instruction, new_goal)
        new_traj = copy.deepcopy(state_history)
        new_traj[0]['content'] = new_goal
        return {"has_hs": True, "hs": new_traj}