import base64
import copy
import os
import re
import time
from abc import ABC, abstractmethod

from openai import OpenAI

from eval_agent.prompt import prompt_naive, prompt_with_icl
from .prompts import ALFWORLD_GOAL_SYSTEM, ALFWORLD_LemonTea_SYSTEM
from .hf_utils import history_to_sft_sample


# Connection defaults for the internal LLM proxy (``api="internal"``).
# NOTE(review): an API key is committed in source here. It should be rotated
# and supplied only through the environment. The INTERNAL_LLM_ENDPOINT and
# INTERNAL_LLM_API_KEY environment variables override these defaults.
_INTERNAL_ENDPOINT = "http://pluto-prod-hawang-llm-proxy-9qtfav-0:4000"
_INTERNAL_API_KEY = "sk-QObXQcx0GTDciGVNhkgTLw"
_OPENROUTER_ENDPOINT = "https://openrouter.ai/api/v1"

# Marker the relabeling model is expected to emit before the inferred goal.
_FINAL_GOAL_MARKER = "Final goal: "

# Whole numbers delimited by word boundaries (e.g. the "2" in "put 2 apples").
_WHOLE_NUMBER_RE = re.compile(r"\b\d+\b")


def remove_number(goal):
    """Return ``goal`` with standalone whole numbers removed.

    Digit runs delimited by regex word boundaries are deleted. Digits glued
    to letters, such as ``"apple2"``, are kept. All whitespace runs are then
    collapsed to single spaces, and leading/trailing whitespace is stripped.
    """
    without_numbers = _WHOLE_NUMBER_RE.sub("", goal)
    return " ".join(without_numbers.split())


class BaseJuicer(ABC):
    """Base class for relabeling strategies using an OpenAI client.

    Subclasses implement :meth:`relabel_experience`. The shared helper
    :meth:`_relabel_with_final_goal` asks the LLM for the final goal a
    trajectory achieved and rewrites the trajectory's first message with it.
    That helper reads ``self.task_instruction``, which every subclass in this
    module sets in its ``__init__``.
    """

    # Total request attempts made by ``_chat_completion`` (at least one).
    max_attempts = 3
    # Seconds to sleep between failed request attempts.
    retry_delay = 1.0
    # When True, relabeling prints the model input/output and old/new trajectory.
    verbose = False

    def __init__(self, llama: str = "llama-3-1-70b", api: str = "internal") -> None:
        """Create the OpenAI-compatible chat client.

        Args:
            llama: Model name passed as ``model`` to chat completions.
            api: ``"internal"`` for the internal proxy, or ``"openrouter"``
                for OpenRouter (reads ``OPENROUTER_API_KEY``).

        Raises:
            ValueError: If ``api`` is not a known option.
            KeyError: If ``api == "openrouter"`` and ``OPENROUTER_API_KEY``
                is not set.
        """
        if api == "internal":
            endpoint = os.environ.get("INTERNAL_LLM_ENDPOINT", _INTERNAL_ENDPOINT)
            key = os.environ.get("INTERNAL_LLM_API_KEY", _INTERNAL_API_KEY)
            # NOTE(review): the OpenAI client already sends
            # "Authorization: Bearer <api_key>", so this prefix yields a
            # doubled "Bearer Bearer ..." header. Kept unchanged because the
            # proxy presumably tolerates it; confirm before removing.
            api_key = "Bearer " + key
        elif api == "openrouter":
            endpoint = _OPENROUTER_ENDPOINT
            api_key = os.environ["OPENROUTER_API_KEY"]
        else:
            raise ValueError(f"Unknown API option: {api}")
        self.client = OpenAI(api_key=api_key, base_url=endpoint)
        self.llama = llama

    def _chat_completion(self, messages):
        """Send ``messages`` to the chat-completions API and return the reply text.

        Makes up to ``max_attempts`` attempts (at least one) and sleeps
        ``retry_delay`` seconds between failures. The exception from the
        final attempt is printed and re-raised.

        Returns:
            ``choices[0].message.content`` of the response. The API type
            allows this to be ``None``.
        """
        attempts = max(1, self.max_attempts)
        for attempt in range(1, attempts + 1):
            try:
                resp = self.client.chat.completions.create(
                    model=self.llama, messages=messages
                )
                return resp.choices[0].message.content
            except Exception as exc:  # pragma: no cover - network failures
                if attempt == attempts:
                    print(f"Failed to send request: {exc}")
                    raise
                time.sleep(self.retry_delay)

    @abstractmethod
    def relabel_experience(self, state_history, obs, llm_out):
        """Return relabeled trajectory based on ``obs`` and ``llm_out``."""
        raise NotImplementedError

    @staticmethod
    def _format_trajectory(obs, llm_out):
        """Render paired observations and model outputs as numbered text steps.

        Each step becomes ``Step k: Action="..."; Observation="...".``
        followed by a newline. The action text is whatever follows the last
        ``"Action: "`` in the model output. The observation text is whatever
        follows the last ``"Observation: "``. Steps whose observation starts
        with ``"Observation: Error Input."`` are skipped, but numbering keeps
        the original positions, so skipped steps leave gaps. Pairs come from
        ``zip``, so surplus items in the longer sequence are ignored.
        """
        steps = []
        pairs = zip(obs, llm_out)
        for step_no, (observation, model_output) in enumerate(pairs, start=1):
            if observation.startswith("Observation: Error Input."):
                continue
            action_text = model_output.split("Action: ")[-1]
            obs_text = observation.split("Observation: ")[-1]
            steps.append(
                f'Step {step_no}: Action="{action_text}"; Observation="{obs_text}".\n'
            )
        return "".join(steps)

    @staticmethod
    def _extract_final_goal(chat_response):
        """Return the cleaned goal from the relabeler's reply, or ``None``.

        The goal is the text after the last ``"Final goal: "`` marker, passed
        through :func:`remove_number`. Returns ``None`` if the marker is
        missing, if the reply contains ``"final goal: none"`` in any case, or
        if nothing is left after cleaning.
        """
        if _FINAL_GOAL_MARKER not in chat_response:
            return None
        if "final goal: none" in chat_response.lower():
            return None
        goal = remove_number(chat_response.rpartition(_FINAL_GOAL_MARKER)[2])
        # An empty goal (e.g. the reply was "Final goal: 5") is not a usable
        # task, so treat it as "no goal found".
        return goal or None

    def _relabel_with_final_goal(self, state_history, obs, llm_out):
        """Ask the LLM which goal the trajectory achieved and relabel it.

        Args:
            state_history: Message list whose first element is a dict with a
                ``"content"`` key; only that content is replaced. Presumably
                it is the task prompt; confirm against callers.
            obs: Environment observations, each expected to contain
                ``"Observation: "``.
            llm_out: Model outputs aligned with ``obs``, each expected to
                contain ``"Action: "``.

        Returns:
            On success, ``{"has_hs": True, "hs": new_traj}``. ``new_traj`` is a
            deep copy of ``state_history`` whose first message content is the
            first value returned by ``prompt_naive(self.task_instruction, goal)``.
            Otherwise, ``{"has_hs": False, "hs": <raw model reply>}``.
        """
        act_obs_traj = self._format_trajectory(obs, llm_out)
        # NOTE: an empty trajectory (e.g. every step was an error) is still
        # sent to the model.
        relabel_inp = [
            {"role": "system", "content": ALFWORLD_GOAL_SYSTEM},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        # A None reply would crash the substring checks below. Treat it as an
        # empty reply so it falls through to the "no goal" result.
        chat_response = self._chat_completion(relabel_inp) or ""

        if self.verbose:
            print()
            print("-" * 100)
            print("Input:\n")
            print(relabel_inp)
            print()
            print("Output:\n")
            print(chat_response)
            print()
            print("Original goal:\n")
            print(state_history[0]["content"])
            print("-" * 100)
            print()

        new_goal = self._extract_final_goal(chat_response)
        if new_goal is None:
            return {"has_hs": False, "hs": chat_response}

        new_prompt, _ = prompt_naive(self.task_instruction, new_goal)
        new_traj = copy.deepcopy(state_history)
        new_traj[0]["content"] = new_prompt

        if self.verbose:
            print("*" * 100)
            print("OLD TRAJ:\n")
            print(state_history)
            print()
            print("NEW TRAJ:\n")
            print(new_traj[0])
            print("*" * 100)

        return {"has_hs": True, "hs": new_traj}


class Juicer(BaseJuicer):
    """Relabel a trajectory with the single final goal an LLM infers it achieved."""

    def __init__(self, task_instruction, llama="llama-3-1-70b", api="internal"):
        """Store ``task_instruction``; it is combined with each new goal by ``prompt_naive``."""
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction

    def relabel_experience(self, state_history, obs, llm_out):
        """Relabel ``state_history`` with the goal inferred from ``obs``/``llm_out``.

        Returns the ``{"has_hs": ..., "hs": ...}`` dict produced by
        ``BaseJuicer._relabel_with_final_goal``.
        """
        return self._relabel_with_final_goal(state_history, obs, llm_out)


class Lemonade(BaseJuicer):
    """One single goal for each trajectory; mask out the irrelevant actions.

    first_only: only include the first achieved goal.

    NOTE(review): neither masking nor ``first_only`` is implemented yet.
    ``relabel_experience`` currently behaves exactly like ``Juicer``'s.
    """

    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        """Store the task instruction and the (currently unused) ``first_only`` flag."""
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only

    def distill(self, new_trajectory):
        """Re-examine the trajectory with the relabeled goal.

        Not implemented yet: currently a no-op that returns ``None``.
        """

    def relabel_experience(self, state_history, obs, llm_out):
        """Relabel ``state_history`` with the goal inferred from ``obs``/``llm_out``.

        Returns the ``{"has_hs": ..., "hs": ...}`` dict produced by
        ``BaseJuicer._relabel_with_final_goal``.
        """
        return self._relabel_with_final_goal(state_history, obs, llm_out)


class Lemontea(BaseJuicer):
    """Multiple goals for each trajectory.

    mask_out: mask out the actions irrelevant to any goal.
    learn_explore: the goal space includes the random exploration.

    NOTE(review): ``mask_out`` and ``learn_explore`` are stored but unused,
    and relabeling currently extracts a single goal with
    ``ALFWORLD_GOAL_SYSTEM``. ``ALFWORLD_LemonTea_SYSTEM`` is imported but
    not used; it may be the intended prompt here. Confirm.
    """

    def __init__(self, task_instruction, llama="llama-3-1-70b", mask_out=False, learn_explore=False, api="internal"):
        """Store the task instruction and the (currently unused) strategy flags."""
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.mask_out = mask_out
        self.learn_explore = learn_explore

    def distill(self, new_trajectory):
        """Re-examine the trajectory with the relabeled goal.

        Not implemented yet: currently a no-op that returns ``None``.
        """

    def relabel_experience(self, state_history, obs, llm_out):
        """Relabel ``state_history`` with the goal inferred from ``obs``/``llm_out``.

        Returns the ``{"has_hs": ..., "hs": ...}`` dict produced by
        ``BaseJuicer._relabel_with_final_goal``.
        """
        return self._relabel_with_final_goal(state_history, obs, llm_out)