| |
|
| | import base64, copy |
| | import os |
| |
|
| | import time |
| | import re |
| | from openai import OpenAI |
| | from eval_agent.prompt import prompt_naive, prompt_with_icl |
| | from .prompts import ALFWORLD_GOAL_SYSTEM, ALFWORLD_LemonTea_SYSTEM |
| | from .hf_utils import history_to_sft_sample |
| | from abc import ABC, abstractmethod |
| |
|
def remove_number(goal):
    """Strip standalone integers from *goal* and collapse extra whitespace.

    Only whole numeric tokens (e.g. the "1" in "apple 1") are removed;
    digits embedded in words (e.g. "a1b") are left intact.
    """
    without_numbers = re.sub(r'\b\d+\b', '', goal)
    return " ".join(without_numbers.split())
| |
|
class BaseJuicer(ABC):
    """Base class for relabeling strategies using an OpenAI client.

    Subclasses implement :meth:`relabel_experience` to turn an agent
    trajectory into a hindsight-relabeled training sample.
    """

    def __init__(self, llama: str = "llama-3-1-70b", api: str = "internal") -> None:
        """Configure the chat-completions client.

        Args:
            llama: Model name sent with every chat-completion request.
            api: Backend selector: ``"internal"`` (LLM proxy) or
                ``"openrouter"``.

        Raises:
            ValueError: If ``api`` is not a known option.
            KeyError: If the required API-key environment variable is unset.
        """
        if api == "internal":
            endpoint = "http://pluto-prod-hawang-llm-proxy-9qtfav-0:4000"
            # SECURITY: the proxy key was previously hard-coded here.
            # Credentials must come from the environment, never from source
            # control; rotate the old committed key.
            api_key = "Bearer " + os.environ["INTERNAL_LLM_PROXY_KEY"]
        elif api == "openrouter":
            endpoint = "https://openrouter.ai/api/v1"
            api_key = os.environ["OPENROUTER_API_KEY"]
        else:
            raise ValueError(f"Unknown API option: {api}")

        self.client = OpenAI(api_key=api_key, base_url=endpoint)
        # NOTE(review): attribute name suggests a Llama model but any model
        # id accepted by the endpoint works; kept for caller compatibility.
        self.llama = llama

    def _chat_completion(self, messages):
        """Invoke the chat-completions API, retrying up to 3 attempts.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The ``content`` string of the first returned choice.

        Raises:
            Exception: Re-raises the last client error after the third
                consecutive failure.
        """
        last_attempt = 2
        for attempt in range(3):
            try:
                resp = self.client.chat.completions.create(
                    model=self.llama, messages=messages
                )
                return resp.choices[0].message.content
            except Exception as exc:
                if attempt == last_attempt:
                    print(f"Failed to send request: {exc}")
                    raise
                # brief fixed backoff before retrying
                time.sleep(1)

    @abstractmethod
    def relabel_experience(self, state_history, obs, llm_out):
        """Return a relabeled trajectory based on ``obs`` and ``llm_out``."""
        raise NotImplementedError
| |
|
| |
|
class Juicer(BaseJuicer):
    """Relabel a trajectory with a single hindsight goal inferred by an LLM."""

    def __init__(self, task_instruction, llama="llama-3-1-70b", api="internal"):
        """Args:
            task_instruction: Task description forwarded to ``prompt_naive``
                when rebuilding the goal prompt.
            llama: Relabeling model name (see ``BaseJuicer``).
            api: Backend selector (see ``BaseJuicer``).
        """
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction

    @staticmethod
    def _format_trajectory(obs, llm_out):
        """Render (action, observation) pairs as numbered text steps.

        Steps whose observation starts with "Observation: Error Input." are
        dropped (the environment rejected the action), but the original step
        index is kept so the numbering still matches the raw trajectory.
        """
        steps = []
        for i, (observation, output) in enumerate(zip(obs, llm_out)):
            if observation.startswith("Observation: Error Input."):
                continue
            action_text = output.split("Action: ")[-1]
            obs_text = observation.split("Observation: ")[-1]
            steps.append(
                f'Step {i+1}: Action="{action_text}"; Observation="{obs_text}".\n'
            )
        return "".join(steps)

    def relabel_experience(self, state_history, obs, llm_out):
        """Ask the relabeling LLM for a hindsight goal and rewrite the prompt.

        Args:
            state_history: Chat-format trajectory; entry 0 holds the goal
                prompt that gets rewritten.
            obs: Per-step observation strings ("Observation: ...").
            llm_out: Per-step agent outputs containing "Action: ...".

        Returns:
            ``{"has_hs": True, "hs": <relabeled trajectory>}`` on success, or
            ``{"has_hs": False, "hs": <raw LLM response>}`` when no goal could
            be extracted.
        """
        act_obs_traj = self._format_trajectory(obs, llm_out)
        relabel_inp = [
            {"role": "system", "content": ALFWORLD_GOAL_SYSTEM},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp)

        # The relabeler answers "Final goal: none" when nothing was achieved.
        if ("Final goal: " not in chat_response
                or "final goal: none" in chat_response.lower()):
            return {"has_hs": False, "hs": chat_response}

        new_goal = remove_number(chat_response.split("Final goal: ")[-1])
        new_goal, _ = prompt_naive(self.task_instruction, new_goal)
        new_traj = copy.deepcopy(state_history)
        new_traj[0]['content'] = new_goal
        return {"has_hs": True, "hs": new_traj}
| |
|
class Lemonade(BaseJuicer):
    """Single-goal relabeling, masking out goal-irrelevant actions.

    ``first_only``: only include the first achieved goal.

    NOTE(review): ``relabel_experience`` is currently identical to
    ``Juicer.relabel_experience`` and ``first_only`` is stored but not yet
    used in the visible code — consider hoisting the shared logic into
    ``BaseJuicer``.
    """

    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        """Args:
            task_instruction: Task description forwarded to ``prompt_naive``.
            llama: Relabeling model name (see ``BaseJuicer``).
            first_only: If True, keep only the first achieved goal.
            api: Backend selector (see ``BaseJuicer``).
        """
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only

    def distill(self, new_trajectory):
        """Re-examine the trajectory under the relabeled goal.

        TODO: not implemented yet — currently a no-op returning ``None``.
        """

    @staticmethod
    def _format_trajectory(obs, llm_out):
        """Render (action, observation) pairs as numbered text steps.

        Steps whose observation starts with "Observation: Error Input." are
        dropped (the environment rejected the action), but the original step
        index is kept so the numbering still matches the raw trajectory.
        """
        steps = []
        for i, (observation, output) in enumerate(zip(obs, llm_out)):
            if observation.startswith("Observation: Error Input."):
                continue
            action_text = output.split("Action: ")[-1]
            obs_text = observation.split("Observation: ")[-1]
            steps.append(
                f'Step {i+1}: Action="{action_text}"; Observation="{obs_text}".\n'
            )
        return "".join(steps)

    def relabel_experience(self, state_history, obs, llm_out):
        """Ask the relabeling LLM for a hindsight goal and rewrite the prompt.

        Args:
            state_history: Chat-format trajectory; entry 0 holds the goal
                prompt that gets rewritten.
            obs: Per-step observation strings ("Observation: ...").
            llm_out: Per-step agent outputs containing "Action: ...".

        Returns:
            ``{"has_hs": True, "hs": <relabeled trajectory>}`` on success, or
            ``{"has_hs": False, "hs": <raw LLM response>}`` when no goal could
            be extracted.
        """
        act_obs_traj = self._format_trajectory(obs, llm_out)
        relabel_inp = [
            {"role": "system", "content": ALFWORLD_GOAL_SYSTEM},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp)

        # The relabeler answers "Final goal: none" when nothing was achieved.
        if ("Final goal: " not in chat_response
                or "final goal: none" in chat_response.lower()):
            return {"has_hs": False, "hs": chat_response}

        new_goal = remove_number(chat_response.split("Final goal: ")[-1])
        new_goal, _ = prompt_naive(self.task_instruction, new_goal)
        new_traj = copy.deepcopy(state_history)
        new_traj[0]['content'] = new_goal
        return {"has_hs": True, "hs": new_traj}
| | |
| |
|
class Lemontea(BaseJuicer):
    """Multi-goal relabeling for a trajectory.

    ``mask_out``: mask out the actions irrelevant to any goal.
    ``learn_explore``: include random exploration in the goal space.

    NOTE(review): despite the multi-goal docstring, ``relabel_experience`` is
    currently identical to ``Juicer.relabel_experience`` (single goal), and
    ``mask_out`` / ``learn_explore`` are stored but unused in the visible
    code. ``ALFWORLD_LemonTea_SYSTEM`` is imported by this module yet this
    class uses ``ALFWORLD_GOAL_SYSTEM`` — confirm which prompt is intended.
    """

    def __init__(self, task_instruction, llama="llama-3-1-70b", mask_out=False, learn_explore=False, api="internal"):
        """Args:
            task_instruction: Task description forwarded to ``prompt_naive``.
            llama: Relabeling model name (see ``BaseJuicer``).
            mask_out: If True, mask actions irrelevant to any goal.
            learn_explore: If True, treat exploration as a learnable goal.
            api: Backend selector (see ``BaseJuicer``).
        """
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.mask_out = mask_out
        self.learn_explore = learn_explore

    def distill(self, new_trajectory):
        """Re-examine the trajectory under the relabeled goal.

        TODO: not implemented yet — currently a no-op returning ``None``.
        """

    @staticmethod
    def _format_trajectory(obs, llm_out):
        """Render (action, observation) pairs as numbered text steps.

        Steps whose observation starts with "Observation: Error Input." are
        dropped (the environment rejected the action), but the original step
        index is kept so the numbering still matches the raw trajectory.
        """
        steps = []
        for i, (observation, output) in enumerate(zip(obs, llm_out)):
            if observation.startswith("Observation: Error Input."):
                continue
            action_text = output.split("Action: ")[-1]
            obs_text = observation.split("Observation: ")[-1]
            steps.append(
                f'Step {i+1}: Action="{action_text}"; Observation="{obs_text}".\n'
            )
        return "".join(steps)

    def relabel_experience(self, state_history, obs, llm_out):
        """Ask the relabeling LLM for a hindsight goal and rewrite the prompt.

        Args:
            state_history: Chat-format trajectory; entry 0 holds the goal
                prompt that gets rewritten.
            obs: Per-step observation strings ("Observation: ...").
            llm_out: Per-step agent outputs containing "Action: ...".

        Returns:
            ``{"has_hs": True, "hs": <relabeled trajectory>}`` on success, or
            ``{"has_hs": False, "hs": <raw LLM response>}`` when no goal could
            be extracted.
        """
        act_obs_traj = self._format_trajectory(obs, llm_out)
        relabel_inp = [
            {"role": "system", "content": ALFWORLD_GOAL_SYSTEM},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp)

        # The relabeler answers "Final goal: none" when nothing was achieved.
        if ("Final goal: " not in chat_response
                or "final goal: none" in chat_response.lower()):
            return {"has_hs": False, "hs": chat_response}

        new_goal = remove_number(chat_response.split("Final goal: ")[-1])
        new_goal, _ = prompt_naive(self.task_instruction, new_goal)
        new_traj = copy.deepcopy(state_history)
        new_traj[0]['content'] = new_goal
        return {"has_hs": True, "hs": new_traj}