|
|
| import base64, copy |
| import os |
|
|
| import time |
| import re |
| from openai import OpenAI |
| from eval_agent.prompt import prompt_naive, prompt_with_icl |
| from .prompts import ALFWORLD_GOAL_SYSTEM, ALFWORLD_ALLGOAL_SYSTEM, ALFWORLD_ATOMICGOAL_SYSTEM, ALFWORLD_RELEVANT_SYSTEM, ALFWORLD_ATOMICGOAL_NOSUGAR_SYSTEM, ALFWORLD_RELEVANT_STATE_SYSTEM, ALFWORLD_ALLGOAL_STATE_SYSTEM |
| from .prompts_webshop import WEBSHOP_GOAL_SYSTEM, WEBSHOP_QUERY_SYSTEM, WEBSHOP_QUERY_FREE_SYSTEM, WEBSHOP_RELEVANT_SYSTEM, WEBSHOP_GOAL_BREW_SYSTEM |
| from .hf_utils import history_to_sft_sample |
| from abc import ABC, abstractmethod |
|
|
| import json, random |
| from typing import List, Dict, Any |
| from .utils import map_to_random_milestone |
|
|
def extract_json_list(text):
    """Return the first bracketed JSON list found in ``text``, or None.

    The search is non-greedy and spans newlines; anything that matches
    ``[...]`` but fails to parse as JSON yields None as well.
    """
    found = re.search(r'\[.*?\]', text, re.DOTALL)
    if not found:
        return None
    try:
        return json.loads(found.group(0))
    except json.JSONDecodeError:
        return None
|
|
|
|
def extract_webshop_query_list(text):
    """
    Extract JSON list from WebShop query response, handling measurement notations.
    Specifically handles patterns like 3'3" x 5'3" in product descriptions.

    Returns the parsed list, or None when no ``[...]`` span is found or the
    cleaned-up span still fails to parse.
    """
    pattern = r'\[.*?\]'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        json_str = match.group(0)

        # Raw newlines inside quoted strings are invalid JSON; flatten them.
        def replace_newlines_in_strings(m):
            # m matches one complete double-quoted string (escapes included).
            return m.group(0).replace('\n', ' ').replace('\r', ' ')

        json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', replace_newlines_in_strings, json_str, flags=re.DOTALL)

        # Re-insert a closing quote the model dropped before a comma,
        # e.g. `"foo, "bar"` -> `"foo", "bar"`.
        json_str = re.sub(r'([^"\\]),\s*"', r'\1", "', json_str)

        # Escape inch marks in feet-and-inches measurements (e.g. 3'3") so
        # they are not read as string terminators; repeat until no
        # unescaped occurrence remains.
        if re.search(r"\d+'\d+\"", json_str):
            json_str = re.sub(r'"([^"]*?)(\d+\'[\d]+)"([^"]*?)"', r'"\1\2\\"\3"', json_str)

            while re.search(r'"([^"]*?)(\d+\'[\d]+)"([^"]*?)"', json_str):
                json_str = re.sub(r'"([^"]*?)(\d+\'[\d]+)"([^"]*?)"', r'"\1\2\\"\3"', json_str)

        # Same treatment for bare inch measurements like 26".
        if re.search(r'\d+"', json_str):
            json_str = re.sub(r'"([^"]*?)(\d+)"([^"]*?)"', r'"\1\2\\"\3"', json_str)

            while re.search(r'"([^"]*?)(\d+)"([^"]*?)"', json_str):
                json_str = re.sub(r'"([^"]*?)(\d+)"([^"]*?)"', r'"\1\2\\"\3"', json_str)

        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON: {e}")
            print(f"Problematic JSON string: {json_str[:500]}...")
            return None
    return None
|
|
|
|
|
|
def parse_vitamin_output(text: str, *, return_json_safe: bool = False) -> List[Dict[str, Any]]:
    """
    Robustly extract & parse the JSON list from messy LLM output.
    Handles:
    - Any backtick code fences (>=3), with/without language tag (```json / ````json / etc.)
    - Leading prose
    - // and /* ... */ comments
    - Trailing commas before ] or }
    - Tuple-like pairs in placement: [("a","b")] -> [["a","b"]]
    - Normalizes tool_status 'watersource_on' -> 'faucet_on'
    If return_json_safe=True: converts placement tuples back to lists for json.dumps.

    Raises ValueError when no balanced list can be located; json.loads
    errors propagate if the cleaned blob is still invalid.
    """

    def _remove_comments(src: str) -> str:
        # Character-level scan that strips // and /* */ comments while
        # leaving comment-like sequences inside string literals intact.
        out, i, n = [], 0, len(src)
        in_str = False; quote = ''; line_cmt = False; block_cmt = False
        while i < n:
            c = src[i]
            if line_cmt:
                # Line comments end at the newline (the newline is kept).
                if c == '\n': line_cmt = False; out.append(c)
                i += 1; continue
            if block_cmt:
                if c == '*' and i+1 < n and src[i+1] == '/': block_cmt = False; i += 2
                else: i += 1
                continue
            if not in_str:
                if c == '/' and i+1 < n:
                    if src[i+1] == '/': line_cmt = True; i += 2; continue
                    if src[i+1] == '*': block_cmt = True; i += 2; continue
                if c in ('"', "'"): in_str = True; quote = c
                out.append(c)
            else:
                out.append(c)
                if c == '\\' and i+1 < n:
                    # Copy the escaped character verbatim so \" does not
                    # close the string.
                    i += 1; out.append(src[i])
                elif c == quote:
                    in_str = False
            i += 1
        return ''.join(out)

    def _remove_trailing_commas(s: str) -> str:
        # Drop commas immediately followed (modulo whitespace) by } or ],
        # again skipping over string literals.
        out, i, n = [], 0, len(s)
        in_str = False; quote = ''
        while i < n:
            c = s[i]
            if not in_str:
                if c in ('"', "'"):
                    in_str = True; quote = c; out.append(c)
                elif c == ',':
                    j = i+1
                    while j < n and s[j] in ' \t\r\n': j += 1
                    if j < n and s[j] in '}]': i += 1; continue
                    out.append(c)
                else:
                    out.append(c)
            else:
                out.append(c)
                if c == '\\' and i+1 < n:
                    i += 1; out.append(s[i])
                elif c == quote:
                    in_str = False
            i += 1
        return ''.join(out)

    def _extract_code_fence(t: str) -> str:
        # Prefer the first fenced block that appears to contain a list;
        # fall back to the whole text when no fence qualifies.
        fence_pattern = re.compile(r"(`{3,})(?:[a-zA-Z]+)?\s*(.*?)\s*\1", re.DOTALL)
        blocks = fence_pattern.findall(t)
        for _, block in blocks:
            if '[' in block and ']' in block:
                return block
        return t

    def _extract_json_list(t: str) -> str:
        # Take the first balanced [...] span by bracket-depth counting.
        # NOTE(review): brackets inside string values would confuse this
        # scan — acceptable for the expected model output format.
        t = _extract_code_fence(t)
        start = t.find('[')
        if start == -1:
            raise ValueError("No JSON list found.")
        depth = 0; end = None
        for i in range(start, len(t)):
            if t[i] == '[': depth += 1
            elif t[i] == ']':
                depth -= 1
                if depth == 0: end = i; break
        if end is None:
            raise ValueError("Unbalanced brackets.")
        return t[start:end+1]

    def _fix_tuple_pairs(s: str) -> str:
        # ("a", "b") -> ["a", "b"] so json.loads accepts placement pairs.
        return re.sub(r'\(\s*("([^"\\]|\\.)*")\s*,\s*("([^"\\]|\\.)*")\s*\)', r'[\1, \3]', s)

    blob = _extract_json_list(text)
    blob = _remove_comments(blob)
    blob = _fix_tuple_pairs(blob)
    blob = _remove_trailing_commas(blob)
    data = json.loads(blob)

    # Post-process each step object: normalize tool_status keys/defaults
    # and convert well-formed placement pairs to tuples.
    for obj in data:
        if isinstance(obj, dict) and "tool_status" in obj and isinstance(obj["tool_status"], dict):
            ts = obj["tool_status"]
            if "watersource_on" in ts and "faucet_on" not in ts:
                ts["faucet_on"] = bool(ts.pop("watersource_on"))
            ts.setdefault("lamp_on", False)
            ts.setdefault("faucet_on", False)
            ts.setdefault("microwave_on", False)
            ts.setdefault("fridge_closed", False)
        if isinstance(obj, dict) and "placement" in obj and isinstance(obj["placement"], list):
            obj["placement"] = [tuple(p) for p in obj["placement"] if isinstance(p, (list, tuple)) and len(p) == 2]

    # Optionally convert tuples back to lists so the result round-trips
    # through json.dumps.
    if return_json_safe:
        for obj in data:
            if isinstance(obj, dict) and "placement" in obj:
                obj["placement"] = [list(p) if isinstance(p, tuple) else p for p in obj["placement"]]
    return data
|
|
|
|
| |
|
|
def steps_to_hf_history(steps: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Convert parsed step dictionaries into an alternating chat history.

    Every step contributes two messages: an assistant message holding the
    step's action (prefixed with "Action:" and carrying a ``useful`` flag
    derived from the step's ``is_relevant_to_goal`` field), followed by a
    user message holding the step's observation (prefixed with
    "Observation:"). Returns [] (after printing a warning) when ``steps``
    is empty or not a list.
    """
    if not isinstance(steps, list) or not steps:
        print("WARNING: the steps are empty!!!!!!!!!!!!!!!!!!")
        print("Steps")
        print(steps)
        return []

    def _as_text(value):
        # Flatten list-valued fields into one space-joined string; fall
        # back to JSON if an element resists str().
        if isinstance(value, list):
            try:
                return " ".join(str(item) for item in value)
            except Exception:
                return json.dumps(value)
        return str(value)

    messages: List[Dict[str, str]] = []
    for step in steps:
        action = _as_text(step.get("action", ""))
        if not action.startswith("Action:"):
            action = "Action: " + action

        # Non-string relevance values are treated as "no".
        relevance = step.get("is_relevant_to_goal", "no")
        relevance = relevance.lower() if isinstance(relevance, str) else "no"

        messages.append({
            "role": "assistant",
            "content": action,
            "useful": "yes" in relevance,
        })

        observation = _as_text(step.get("observation", ""))
        if not observation.startswith("Observation:"):
            observation = "Observation: " + observation

        messages.append({
            "role": "user",
            "content": observation,
        })

    return messages
|
|
|
|
|
|
def extract_env_desc(original_task):
    """Return the environment description with the trailing task line removed.

    ALFWorld task prompts end with a line of the form
    "Your task is to: <goal>"; drop that line so only the environment
    description remains. Inputs without such a line are returned unchanged.

    BUGFIX: the previous version tested ``"Your task is to:" in lines``
    against the *list* of lines — an exact whole-line match that never
    fires for a real task line — so the task line was never stripped.
    We now check the last line for the marker as a substring.
    """
    lines = original_task.split("\n")
    if lines and "Your task is to:" in lines[-1]:
        lines = lines[:-1]
    return "\n".join(lines)
|
|
def generate_trajgoal_pairs(steps: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Generate training samples (trajectory-goal pairs) from step dictionaries.

    For each distinct goal first achieved at step k, the sample's trajectory
    is ``steps[0..k]`` (inclusive) and its goal is the goal string. Pairs are
    emitted in the order the goals were first achieved. If no goal was ever
    reached this returns an empty list.

    The previous implementation made two redundant passes over ``steps``
    maintaining duplicate bookkeeping (`seen_goals_order` plus
    `achieved_goals`/`goal_first_index`); a single pass with an
    insertion-ordered dict yields the same result.

    Args:
        steps: A list of dictionaries representing the parsed JSON output.
            Each step dictionary may carry a 'reached_goals' list naming the
            goals satisfied at that step; a trailing summary dict without
            'reached_goals' simply contributes nothing.

    Returns:
        A list of {"trajectory": <step prefix>, "goal": <goal string>} dicts.

    Raises:
        ValueError: If the input is not a non-empty list.
    """
    if not isinstance(steps, list) or not steps:
        raise ValueError("Input must be a non-empty list of step dictionaries.")

    # Single pass: dict insertion order records first-achievement order,
    # and setdefault keeps only the earliest index per goal.
    goal_first_index: Dict[str, int] = {}
    for idx, step in enumerate(steps):
        for goal in step.get('reached_goals', []):
            goal_first_index.setdefault(goal, idx)

    return [
        {"trajectory": steps[: idx + 1], "goal": goal}
        for goal, idx in goal_first_index.items()
    ]
|
|
|
|
| """ |
| def extract_json_objects_from_output(output: str) -> List[Dict[str, Any]]: |
| # Regex captures the first '[' … ']' sequence containing at least one object. |
| match = re.search(r'\[\s*{.*?}\s*\]', output, re.DOTALL) |
| if not match: |
| print(output) |
| print("ValueError: No JSON array detected in the output string. Please check the LLM output format.") |
| return [] |
| json_str = match.group(0) |
| try: |
| return json.loads(json_str) |
| except json.JSONDecodeError as exc: |
| print(output) |
| print(f"ValueError: Failed to decode JSON array: {exc}") |
| return [] |
| """ |
|
|
|
|
def extract_json_objects_from_output(output: str) -> List[Dict[str, Any]]:
    """
    Extract the first well-balanced JSON array from a blob of text, clean
    out common artifacts (ellipsis lines, // comments, trailing commas),
    and parse it. Raises ValueError when no array is found, the brackets
    never balance, or the cleaned text still fails to parse.
    """
    start = output.find('[')
    if start == -1:
        raise ValueError("No '[' found in output; can't extract JSON array.")

    # Walk forward from the opening bracket, tracking string/escape state
    # so brackets inside quoted strings do not affect the depth count.
    end = None
    depth = 0
    in_str = False
    escape = False
    pos = start
    while pos < len(output):
        ch = output[pos]
        if in_str:
            if escape:
                escape = False
            elif ch == '\\':
                escape = True
            elif ch == '"':
                in_str = False
        elif ch == '"':
            in_str = True
        elif ch == '[':
            depth += 1
        elif ch == ']':
            depth -= 1
            if depth == 0:
                end = pos + 1
                break
        pos += 1

    if end is None:
        raise ValueError("Unbalanced brackets: never found matching ']'.")

    raw = output[start:end]

    # Strip // line comments, lone "..." filler lines, and trailing commas.
    raw = re.sub(r'//.*$', '', raw, flags=re.MULTILINE)
    raw = re.sub(r'^\s*\.\.\..*$\n?', '', raw, flags=re.MULTILINE)
    raw = re.sub(r',\s*(?=[}\]])', '', raw)

    try:
        return json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Failed to decode JSON.\nError: {exc}")
|
|
|
|
def extract_webshop_json_objects(output: str) -> List[Dict[str, Any]]:
    """
    WebShop-specific JSON extraction that handles common issues with WebShop data.
    Specifically handles:
    - Apostrophes in product names like "salt n' vinegar"
    - Measurement notations like 26", 16" (inches) in search actions
    - Extra text after the JSON array ends
    """
    start = output.find('[')
    if start == -1:
        raise ValueError("No '[' found in output; can't extract JSON array.")

    # Scan forward for the matching ']' while tracking string/escape state
    # so brackets inside quoted values are ignored.
    bracket_depth = 0
    in_string = False
    escape_next = False
    end = -1

    for i in range(start, len(output)):
        char = output[i]

        if escape_next:
            escape_next = False
            continue

        if char == '\\':
            escape_next = True
            continue

        # NOTE(review): escape_next is always False by this point (the two
        # branches above continue), so the extra check is redundant but
        # harmless.
        if char == '"' and not escape_next:
            in_string = not in_string
            continue

        if not in_string:
            if char == '[':
                bracket_depth += 1
            elif char == ']':
                bracket_depth -= 1
                if bracket_depth == 0:
                    end = i + 1
                    break

    if end == -1:
        raise ValueError("No matching ']' found in output; can't extract JSON array.")

    raw = output[start:end]

    # Protect already-escaped quotes with a placeholder so the inch-mark
    # pass below does not double-escape them.
    raw = raw.replace('\\\\"', '__ESCAPED_QUOTE__')

    # Escape a quote that follows a number and is itself followed by more
    # text, e.g. `26" monitor` -> `26\" monitor`.
    raw = re.sub(r'(\d+(?:\.\d+)?)"(\s+\w|\w)', r'\1\\"\2', raw)

    # Restore the protected escaped quotes.
    raw = raw.replace('__ESCAPED_QUOTE__', '\\"')

    # Strip // line comments, lone "..." filler lines, and trailing commas.
    raw = re.sub(r'//.*$', '', raw, flags=re.MULTILINE)
    raw = re.sub(r'^\s*\.\.\..*$\n?', '', raw, flags=re.MULTILINE)
    raw = re.sub(r',\s*(?=[}\]])', '', raw)

    # Normalize escaped apostrophes (\' and the \u0027 escape) to plain
    # single quotes, which are legal inside JSON strings.
    raw = raw.replace(r"\'", "'")
    raw = re.sub(r'\\u0027', "'", raw)

    try:
        return json.loads(raw)
    except json.JSONDecodeError as exc:
        # Last resort: drop backslashes before quotes and retry once.
        try:
            cleaned = re.sub(r'\\([\'"])', r'\1', raw)
            return json.loads(cleaned)
        except:
            # Both attempts failed; surface a truncated view of the text.
            print(f"Raw JSON that failed to parse:\n{raw[:500]}...")
            raise ValueError(f"Failed to decode JSON.\nError: {exc}")
|
|
|
|
def extract_shopbrew_dict_from_output(output: str) -> Dict[str, Any]:
    """
    Extract the first well-balanced JSON dictionary from a blob of text, clean
    out common artifacts (ellipsis, trailing commas, comments), and parse it.
    Similar to extract_json_objects_from_output but for dictionaries instead of lists.
    Handles both single-quoted Python dicts and double-quoted JSON.
    """
    start = output.find('{')
    if start == -1:
        raise ValueError("No '{' found in output; can't extract JSON dictionary.")

    # Scan for the matching '}' tracking string state; either quote style
    # (single or double) opens a string here, matching Python dict syntax.
    in_str = False
    escape = False
    depth = 0
    end = None
    quote_char = None

    for idx, ch in enumerate(output[start:], start):
        if in_str:
            if escape:
                escape = False
            elif ch == '\\':
                escape = True
            elif ch == quote_char:
                in_str = False
                quote_char = None
        else:
            if ch in ('"', "'"):
                in_str = True
                quote_char = ch
            elif ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    end = idx + 1
                    break

    if end is None:
        raise ValueError("Unbalanced braces: never found matching '}'.")

    raw = output[start:end]

    # Strip // line comments, lone "..." filler lines, and trailing commas.
    raw = re.sub(r'//.*$', '', raw, flags=re.MULTILINE)
    raw = re.sub(r'^\s*\.\.\..*$\n?', '', raw, flags=re.MULTILINE)
    raw = re.sub(r',\s*(?=[}\]])', '', raw)

    # First attempt: strict JSON.
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # Second attempt: treat the text as a Python dict literal —
        # convert JSON literals to their Python spellings first.
        fixed_bools = raw

        fixed_bools = re.sub(r':\s*true\b', ': True', fixed_bools)
        fixed_bools = re.sub(r':\s*false\b', ': False', fixed_bools)
        fixed_bools = re.sub(r',\s*true\b', ', True', fixed_bools)
        fixed_bools = re.sub(r',\s*false\b', ', False', fixed_bools)
        fixed_bools = re.sub(r':\s*null\b', ': None', fixed_bools)
        fixed_bools = re.sub(r',\s*null\b', ', None', fixed_bools)

        try:
            # ast.literal_eval safely evaluates Python literals (no code
            # execution), so single-quoted dicts parse here.
            import ast
            result = ast.literal_eval(fixed_bools)
            if isinstance(result, dict):
                return result
            else:
                raise ValueError(f"Parsed result is not a dictionary: {type(result)}")
        except (ValueError, SyntaxError) as exc:
            # Third attempt: naive quote swap back to JSON spellings.
            # NOTE(review): replacing every single quote can corrupt
            # apostrophes inside values — best-effort only.
            try:
                json_str = fixed_bools.replace("'", '"')
                json_str = re.sub(r'\bTrue\b', 'true', json_str)
                json_str = re.sub(r'\bFalse\b', 'false', json_str)
                json_str = re.sub(r'\bNone\b', 'null', json_str)
                return json.loads(json_str)
            except json.JSONDecodeError as exc2:
                raise ValueError(f"Failed to decode JSON or Python dict.\nJSON Error: {exc2}\nAST Error: {exc}")
|
|
|
|
def remove_number(goal):
    """Strip standalone integers from a goal string and collapse whitespace.

    Only whole, word-bounded numbers are removed ("take 2 apple" ->
    "take apple"); digits embedded in words (e.g. "abc123") are kept.
    """
    without_counts = re.sub(r'\b\d+\b', '', goal)
    return " ".join(without_counts.split())
|
|
class BaseJuicer(ABC):
    """Base class for relabeling strategies using an OpenAI client.

    Owns the chat client, retry logic, and optional request/response
    logging shared by all Juicer variants. Subclasses implement
    ``relabel_experience``.
    """

    def __init__(self, llama: str = "llama-3-1-70b", api: str = "internal") -> None:
        """Configure endpoint, model name, and generation params for ``api``.

        Raises:
            ValueError: If ``api`` is not one of the known options.
        """
        # SECURITY: endpoints and API keys are hardcoded below; they should
        # be moved to environment variables or a secrets manager rather
        # than committed to source.
        if api == "internal":
            endpoint = "http://pluto-prod-jackwa-llama-70b-infere-0:8000/v1"
            key = "sk-QObXQcx0GTDciGVNhkgTLw"
            api_key = "Bearer " + key
            self.llama = "meta-llama/Llama-3.3-70B-Instruct"
        elif api == "second":
            endpoint = "http://pluto-prod-jkil-job2-0:8000/v1"
            key = "sk-QObXQcx0GTDciGVNhkgTLw"
            api_key = "Bearer " + key
            self.llama = "meta-llama/Llama-3.3-70B-Instruct"
        elif api == "legacy":
            endpoint = "http://pluto-prod-hawang-llm-proxy-9qtfav-0:4000"
            key = "sk--Tz-ELdSOZ-MFBRRomTuGg"
            api_key = "Bearer " + key
            self.llama = "gpt-4.1-mini"
        elif api == "openrouter":
            endpoint = "https://openrouter.ai/api/v1"
            api_key = "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"
            self.llama = "openai/gpt-4.1-mini"
        elif api == "gpt":
            endpoint = "https://api.openai.com/v1"
            api_key = "sk-uuhu_uog_hb7OYM3xrzT_fFO-1KPUxPu2VOySrIXZGT3BlbkFJwPSRtJU5lEuyoWrf1Ayh8cEW265ABrvuCm6fwEAxUA"
            self.llama = "gpt-4.1-mini-2025-04-14"
        else:
            raise ValueError(f"Unknown API option: {api}")

        # The legacy proxy is not sent sampling parameters.
        if api != "legacy":
            self.llama_gen_params = {"temperature": 0}
        else:
            self.llama_gen_params = {}

        self.client = OpenAI(api_key=api_key, base_url=endpoint)

        # Logging configuration; callers may set these after construction.
        # Logging is skipped entirely until step_getter and the relevant
        # log file are provided.
        self.relabel_log_file: str | None = None
        self.mask_log_file: str | None = None
        # Log one exchange every `log_interval` steps.
        self.log_interval: int = 30
        # Callable returning the current training step; None disables logging.
        self.step_getter = None

    def _chat_completion(self, messages, mode):
        """Invoke the chat client with retry logic and optional logging.

        Args:
            messages: Chat messages to send.
            mode: 'relabel' or 'mask' — selects which log file receives
                the periodic request/response samples.

        Raises:
            Exception: Re-raises the client error after 3 failed attempts.
        """
        attempt = 0
        while True:
            try:
                resp = self.client.chat.completions.create(
                    model=self.llama, messages=messages, **self.llama_gen_params
                )
                resp_text = resp.choices[0].message.content
                break
            except Exception as exc:
                attempt += 1
                if attempt >= 3:
                    print(f"Failed to send request: {exc}")
                    raise
                time.sleep(1)
                continue

        # Periodic logging; guarded so an unconfigured step_getter or log
        # file no longer crashes the call (previously open(None) / calling
        # None would raise).
        if self.step_getter is not None:
            step = self.step_getter()
            if step % self.log_interval == 0:
                fn = self.relabel_log_file if mode == 'relabel' else self.mask_log_file
                if fn:
                    with open(fn, "a", encoding="utf-8") as f:
                        json.dump({
                            "step": step,
                            "messages": messages,
                            "response": resp_text,
                        }, f)
                        f.write("\n")
        return resp_text

    def _rectify_chat_completion(self, original_messages, llm_response, error_msg, mode="relabel"):
        """
        Request a corrected completion from the model.

        Parameters
        ----------
        original_messages : List[dict]
            Messages originally sent to the LLM.
        llm_response : str
            The model output that failed to parse.
        error_msg : str
            Explanation of the parsing failure, forwarded to the model.
        mode : str
            Log-file selector passed through to ``_chat_completion``
            (new keyword, defaults to 'relabel' for compatibility).

        Returns
        -------
        str
            The model's new response after being asked to rectify the output.
        """
        # BUGFIX: previously the caller's error_msg was overwritten with a
        # fixed string (so the model never saw the actual error), and
        # _chat_completion was called without its required `mode` argument,
        # raising TypeError. The error detail is now appended.
        feedback = (
            "Your original output is in a WRONG format, and here is the error message "
            + str(error_msg)
        )
        messages = (
            original_messages
            + [{"role": "assistant", "content": llm_response}]
            + [{"role": "user", "content": feedback}]
        )
        return self._chat_completion(messages, mode)

    @abstractmethod
    def relabel_experience(self, state_history, obs, llm_out, **kwargs):
        """Return relabeled trajectory based on ``obs`` and ``llm_out``."""
        raise NotImplementedError
|
|
|
|
class Juicer(BaseJuicer):
    """Relabels a whole trajectory with a single model-proposed final goal.

    The LLM is shown an action/observation transcript and asked what goal
    was actually achieved; the task line of the first history message is
    then rewritten to that goal.
    """

    def __init__(self, task_instruction, llama="llama-3-1-70b", api="internal"):
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction

    def relabel_experience(self, state_history, obs, llm_out, **kwargs):
        """Ask the LLM for the achieved goal and rewrite the task line.

        Returns a dict: {"has_hs": True, "hs": rewritten history} on
        success, or {"has_hs": False, "hs": raw model response} when the
        model reports no achieved goal.

        Cleanup vs. previous version: the loop no longer rebinds the `obs`
        parameter (shadowing), and the disabled `if False:` debug dumps
        were removed.
        """
        # Build a compact "Step i: Action=..; Observation=.." transcript,
        # skipping steps whose observation flagged invalid input.
        act_obs_traj = ''
        for i, (obs_text, act_text) in enumerate(zip(obs, llm_out)):
            if obs_text.startswith("Observation: Error Input."):
                continue
            action_part = act_text.split("Action: ")[-1]
            obs_part = obs_text.split("Observation: ")[-1]
            act_obs_traj += f'Step {i+1}: Action="{action_part}"; Observation="{obs_part}".\n'

        relabel_inp = [
            {"role": "system", "content": ALFWORLD_GOAL_SYSTEM},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp, "relabel")

        # "Final goal: None" (any case) means no goal was achieved.
        if "Final goal: " in chat_response and "final goal: none" not in chat_response.lower():
            new_goal = remove_number(chat_response.split("Final goal: ")[-1])
        else:
            return {
                "has_hs": False,
                "hs": chat_response
            }

        # Replace the original task line (the last line of the first
        # message) with the relabeled goal.
        new_traj = copy.deepcopy(state_history)
        new_traj[0]['content'] = (
            "\n".join(new_traj[0]['content'].split("\n")[:-1])
            + f"\nYour task is to: {new_goal}"
        )
        return {
            'has_hs': True,
            "hs": new_traj
        }
|
|
class Lemonade(BaseJuicer):
    '''
    one single goal for each trajectory, mask out the irrelevant actions
    first_only: only include the first achieved goal
    '''
    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only
        # System prompts; subclasses (e.g. Vitamin) swap these out.
        self._relabel_prompt = ALFWORLD_ALLGOAL_SYSTEM
        self._relevant_prompt = ALFWORLD_RELEVANT_SYSTEM

    def relabel_experience(self, state_history, obs, llm_out, original_task):
        """Relabel a trajectory into one training sample per achieved goal.

        Builds an action/observation transcript, asks the LLM which goals
        were achieved, then for each goal runs a second LLM pass that marks
        per-step relevance; the task line of the history is rewritten to
        the relabeled goal. Returns {"has_hs": bool, "hs": list of
        rewritten trajectories, or an error/raw-response string}.
        """
        # Compact "Step i: Action=..; Observation=.." transcript; steps
        # whose observation flags invalid input are skipped.
        # NOTE(review): the loop rebinds `obs`, shadowing the parameter
        # (harmless here — zip already captured the original list).
        act_obs_traj = ''
        for i, (x, y) in enumerate(zip(obs, llm_out)):
            if x.startswith("Observation: Error Input."):
                continue
            action = y.split("Action: ")[-1]
            action = f'Action="{action}"'
            obs = x.split("Observation: ")[-1]
            obs = f'Observation="{obs}"'
            act_obs_traj += f"Step {i+1}: {action}; {obs}.\n"

        relabel_inp = [
            {"role": "system", "content": self._relabel_prompt},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]

        chat_response = self._chat_completion(relabel_inp, "relabel")

        try:
            trajs = extract_json_objects_from_output(chat_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the relabel response:\nError: {exc}\n{chat_response}")
            trajs = []

        if not trajs:
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
            }
        # The model appends a summary object carrying 'final_goals'; strip
        # it from the step list before pairing.
        if "final_goals" in trajs[-1].keys():
            trajs_final = trajs[-1]
            trajs = trajs[:-1]
        else:
            print("WARNING: No finals goal from relabel model!!!!!")
            print(trajs)
            trajs_final = None

        if False:  # debug dump, disabled
            print("Trajs")
            print(trajs)
            assert False

        traj_goal_pairs = generate_trajgoal_pairs(trajs)
        if len(traj_goal_pairs) == 0:
            return {
                "has_hs": False,
                "hs": chat_response
            }

        # Optionally keep only the earliest achieved goal.
        if self.first_only:
            traj_goal_pairs = [traj_goal_pairs[0]]
        if False:  # debug dump, disabled
            print("-"*100)
            print("Input:\n")
            print(relabel_inp)
            print()
            print("Output:\n")
            print(chat_response)
            print()
            print("Original goal:\n")
            print(state_history[0]['content'])
            print("-"*100)
            print()
        env_desc = extract_env_desc(original_task)
        hs = []

        # Second pass: per goal, ask the LLM which steps were relevant,
        # then rebuild the chat history with per-step "useful" flags.
        for step in traj_goal_pairs:
            traj = step['trajectory']
            goal = remove_number(step['goal'])
            goal_relevant_inp = [
                {"role": "system", "content": self._relevant_prompt},
                {
                    "role": "user",
                    "content": f"Environment description:\n{env_desc}\n\nGoal:{goal}\n\nHere is the tracjtory:\n{traj}\n\nNow, please judge the relevance of actions at each step."
                },
            ]
            goal_relevant_response = self._chat_completion(goal_relevant_inp, "mask")
            try:
                traj_goal_pairs_with_relevance = extract_json_objects_from_output(goal_relevant_response)
            except Exception as exc:
                # Parse failure for one goal skips that goal only.
                print(f"ERROR: Failed to parse the RELEVANCE response:\nError: {exc}\n{goal_relevant_inp}")
                continue
            if False:  # debug dump, disabled
                print("<"*100)
                print("GOAL Relevance Input:\n")
                print(goal_relevant_inp)
                print()
                print("GOAL Relevance Output:\n")
                print(goal_relevant_response)
                print()
                print("Extracted GOAL Relevance Output:\n")
                print(traj_goal_pairs_with_relevance)
                print(">"*100)
                print()
            # Swap the relabeled goal into the task line of the first
            # message, then append the relevance-annotated history.
            new_traj = copy.deepcopy(state_history)
            new_traj[0]['content'] = "\n".join(new_traj[0]['content'].split("\n")[:-1])+f"\nYour task is to: {goal}"
            new_traj = [new_traj[0]] + steps_to_hf_history(traj_goal_pairs_with_relevance)
            hs.append(new_traj)
            if False:  # debug dump, disabled
                print("|"*100)
                print("OLD TRAJ:\n")
                print(state_history)
                print()
                print("NEW TRAJ:\n")
                print(new_traj)
                print("|"*100)
        if hs:
            return {
                "has_hs": True,
                "hs": hs
            }
        else:
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
            }
|
|
|
|
class Vitamin(Lemonade):
    """
    Same flow as `Lemonade` but with state-aware system prompts and the
    robust `parse_vitamin_output` parser for the relabel response (the
    relevance response is still parsed with
    `extract_json_objects_from_output`).
    """

    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        super().__init__(task_instruction, llama=llama, first_only=first_only, api=api)
        # Only the prompts differ from Lemonade at construction time.
        self._relabel_prompt = ALFWORLD_ALLGOAL_STATE_SYSTEM
        self._relevant_prompt = ALFWORLD_RELEVANT_STATE_SYSTEM

    def relabel_experience(self, state_history, obs, llm_out, original_task):
        """Relabel a trajectory into one training sample per achieved goal.

        Returns {"has_hs": bool, "hs": list of rewritten trajectories, or
        an error/raw-response string}.

        BUGFIX: the relevance-parse failure branch previously executed a
        stray `assert False` (debug leftover) that crashed the entire run
        on a single parse failure, making the `continue` after it
        unreachable; it now skips the bad goal exactly like Lemonade.
        Disabled `if False:` debug dumps and the unused `trajs_final`
        local were also removed.
        """
        # Build a compact "Step i: Action=..; Observation=.." transcript,
        # skipping steps whose observation flagged invalid input.
        act_obs_traj = ''
        for i, (obs_text, act_text) in enumerate(zip(obs, llm_out)):
            if obs_text.startswith("Observation: Error Input."):
                continue
            action_part = act_text.split("Action: ")[-1]
            obs_part = obs_text.split("Observation: ")[-1]
            act_obs_traj += f'Step {i+1}: Action="{action_part}"; Observation="{obs_part}".\n'

        relabel_inp = [
            {"role": "system", "content": self._relabel_prompt},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]

        chat_response = self._chat_completion(relabel_inp, "relabel")

        try:
            trajs = parse_vitamin_output(chat_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the relabel response:\nError: {exc}\n{chat_response}")
            trajs = []

        if not trajs:
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
            }
        # The model appends a summary object carrying 'final_goals'; strip
        # it from the step list before pairing.
        if "final_goals" in trajs[-1].keys():
            trajs = trajs[:-1]
        else:
            print("WARNING: No finals goal from relabel model!!!!!")
            print(trajs)

        traj_goal_pairs = generate_trajgoal_pairs(trajs)
        if len(traj_goal_pairs) == 0:
            return {
                "has_hs": False,
                "hs": chat_response
            }

        # Optionally keep only the earliest achieved goal.
        if self.first_only:
            traj_goal_pairs = [traj_goal_pairs[0]]

        env_desc = extract_env_desc(original_task)
        hs = []

        # Second pass: per goal, ask the LLM which steps were relevant,
        # then rebuild the chat history with per-step "useful" flags.
        for pair in traj_goal_pairs:
            traj = pair['trajectory']
            goal = remove_number(pair['goal'])
            goal_relevant_inp = [
                {"role": "system", "content": self._relevant_prompt},
                {
                    "role": "user",
                    "content": f"Environment description:\n{env_desc}\n\nGoal:{goal}\n\nHere is the tracjtory:\n{traj}\n\nNow, please judge the relevance of actions at each step."
                },
            ]
            goal_relevant_response = self._chat_completion(goal_relevant_inp, "mask")
            try:
                traj_with_relevance = extract_json_objects_from_output(goal_relevant_response)
            except Exception as exc:
                # Skip this goal; do not abort the whole batch.
                print(f"ERROR: Failed to parse the RELEVANCE response:\nError: {exc}\n{goal_relevant_inp}")
                continue

            # Swap the relabeled goal into the task line of the first
            # message, then append the relevance-annotated history.
            new_traj = copy.deepcopy(state_history)
            new_traj[0]['content'] = (
                "\n".join(new_traj[0]['content'].split("\n")[:-1])
                + f"\nYour task is to: {goal}"
            )
            new_traj = [new_traj[0]] + steps_to_hf_history(traj_with_relevance)
            hs.append(new_traj)

        if hs:
            return {
                "has_hs": True,
                "hs": hs
            }
        return {
            "has_hs": False,
            "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
        }
|
|
|
|
|
|
class OTea(BaseJuicer):
    '''
    one single goal for each trajectory, mask out the irrelevant actions
    first_only: only include the first achieved goal
    '''
    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only

    def distill(self, new_trajectory):
        '''
        re-examine the trajectory with relabled goal
        '''
        # Placeholder: not implemented yet.
        return

    def relabel_experience(self, state_history, obs, llm_out, original_task):
        """Relabel a trajectory into one training sample per achieved goal,
        using the atomic-goal ("no sugar") relabel prompt.

        Returns {"has_hs": bool, "hs": list of rewritten trajectories, or
        an error/raw-response string}.

        BUGFIX: the first-goal filter was disabled with
        ``if False and self.first_only:``, so the documented `first_only`
        constructor flag was silently ignored; the dead condition is
        removed and the flag now behaves as documented (matching
        Lemonade). Disabled `if False:` debug dumps, the unused
        `trajs_final` local, and the `obs` parameter shadowing were also
        cleaned up.
        """
        # Build a compact "Step i: Action=..; Observation=.." transcript,
        # skipping steps whose observation flagged invalid input.
        act_obs_traj = ''
        for i, (obs_text, act_text) in enumerate(zip(obs, llm_out)):
            if obs_text.startswith("Observation: Error Input."):
                continue
            action_part = act_text.split("Action: ")[-1]
            obs_part = obs_text.split("Observation: ")[-1]
            act_obs_traj += f'Step {i+1}: Action="{action_part}"; Observation="{obs_part}".\n'

        relabel_inp = [
            {"role": "system", "content": ALFWORLD_ATOMICGOAL_NOSUGAR_SYSTEM},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp, "relabel")
        try:
            trajs = extract_json_objects_from_output(chat_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the relabel response:\nError: {exc}\n{chat_response}")
            trajs = []
        if not trajs:
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
            }
        # The model appends a summary object carrying 'final_goals'; strip
        # it from the step list before pairing.
        if "final_goals" in trajs[-1].keys():
            trajs = trajs[:-1]
        else:
            print("WARNING: No finals goal from relabel model!!!!!")
            print(trajs)

        traj_goal_pairs = generate_trajgoal_pairs(trajs)
        if len(traj_goal_pairs) == 0:
            return {
                "has_hs": False,
                "hs": chat_response
            }

        # Optionally keep only the earliest achieved goal (re-enabled; see
        # docstring).
        if self.first_only:
            traj_goal_pairs = [traj_goal_pairs[0]]

        env_desc = extract_env_desc(original_task)
        hs = []

        # Second pass: per goal, ask the LLM which steps were relevant,
        # then rebuild the chat history with per-step "useful" flags.
        for pair in traj_goal_pairs:
            traj = pair['trajectory']
            goal = remove_number(pair['goal'])
            goal_relevant_inp = [
                {"role": "system", "content": ALFWORLD_RELEVANT_SYSTEM},
                {
                    "role": "user",
                    "content": f"Environment description:\n{env_desc}\n\nGoal:{goal}\n\nHere is the tracjtory:\n{traj}\n\nNow, please judge the relevance of actions at each step."
                },
            ]
            goal_relevant_response = self._chat_completion(goal_relevant_inp, "mask")
            try:
                traj_with_relevance = extract_json_objects_from_output(goal_relevant_response)
            except Exception as exc:
                # Skip this goal; do not abort the whole batch.
                print(f"ERROR: Failed to parse the RELEVANCE response:\nError: {exc}\n{goal_relevant_inp}")
                continue

            # Swap the relabeled goal into the task line of the first
            # message, then append the relevance-annotated history.
            new_traj = copy.deepcopy(state_history)
            new_traj[0]['content'] = (
                "\n".join(new_traj[0]['content'].split("\n")[:-1])
                + f"\nYour task is to: {goal}"
            )
            new_traj = [new_traj[0]] + steps_to_hf_history(traj_with_relevance)
            hs.append(new_traj)

        if hs:
            return {
                "has_hs": True,
                "hs": hs
            }
        return {
            "has_hs": False,
            "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
        }
|
|
|
|
def build_webshop_trajectory(obs, llm_out):
    """
    Render the paired (observation, action) sequence as a numbered WebShop
    trajectory string, keeping only the website content of each page
    (instruction headers and reward suffixes are stripped).
    """
    import re

    rendered = []
    for step_no, (raw_obs, raw_act) in enumerate(zip(obs, llm_out), start=1):
        # Drop steps whose observation reports an error; the step number
        # still advances so numbering matches the raw trajectory.
        if "Error" in raw_obs or "error" in raw_obs:
            continue

        # Pull out the action text, with or without the "Action: " prefix.
        act = raw_act.split("Action: ")[-1].strip() if "Action: " in raw_act else raw_act.strip()

        # Pull out the page text, with or without the "Observation:" prefix.
        page = raw_obs.split("Observation:")[-1].strip() if "Observation:" in raw_obs else raw_obs.strip()
        # Strip the leading instruction header and the trailing reward block.
        page = re.sub(r'^Instruction:\s*\[SEP\].*?\[SEP\]\s*', '', page)
        page = re.sub(r'\[SEP\]\s*Reward\s*\[SEP\].*$', '', page)

        rendered.append(f'Step {step_no}: Action="{act}"; Observation="{page}".')

    if not rendered:
        return ''
    return "\n".join(rendered) + "\n"
|
|
|
|
def build_webshop_trajectory_and_intention(obs, llm_out, intentions_traj):
    """
    Render the (observation, action, intention) triples as a numbered WebShop
    trajectory string. Page text is cleaned the same way as in
    build_webshop_trajectory; each step additionally records the model's
    current shopping intention.
    """
    import re

    rendered = []
    for step_no, (raw_obs, raw_act, intent) in enumerate(zip(obs, llm_out, intentions_traj), start=1):
        # Error steps are skipped but still consume a step number.
        if "Error" in raw_obs or "error" in raw_obs:
            continue

        act = raw_act.split("Action: ")[-1].strip() if "Action: " in raw_act else raw_act.strip()

        page = raw_obs.split("Observation:")[-1].strip() if "Observation:" in raw_obs else raw_obs.strip()
        # Strip the leading instruction header and the trailing reward block.
        page = re.sub(r'^Instruction:\s*\[SEP\].*?\[SEP\]\s*', '', page)
        page = re.sub(r'\[SEP\]\s*Reward\s*\[SEP\].*$', '', page)

        rendered.append(
            f'Step {step_no}: Action="{act}"; Observation="{page}"; current_intention={intent}.'
        )

    if not rendered:
        return ""
    return "\n".join(rendered) + "\n"
|
|
|
|
|
|
def _flatten_step_field(value) -> str:
    """Coerce a step field to a string; list values are space-joined."""
    if isinstance(value, list):
        try:
            return " ".join(str(x) for x in value)
        except Exception:
            # Extremely defensive fallback: serialize whatever we got.
            return json.dumps(value)
    return str(value)


def _webshop_action_message(step: Dict[str, Any]) -> Dict[str, Any]:
    """Build the assistant message for one step, tagged with a 'useful' flag."""
    action_text = _flatten_step_field(step.get("action", ""))
    if not action_text.startswith("Action:"):
        action_text = "Action: " + action_text
    # Relevance defaults to "no"; any non-string value is treated as "no".
    is_rel = step.get("is_relevant_to_goal", "no")
    if not isinstance(is_rel, str):
        is_rel = "no"
    return {
        "role": "assistant",
        "content": action_text,
        # True iff the judge marked this action relevant ("yes"/"Yes"/...).
        "useful": "yes" in is_rel.lower()
    }


def _webshop_observation_message(step: Dict[str, Any]) -> Dict[str, str]:
    """Build the user message carrying one step's observation."""
    obs_text = _flatten_step_field(step.get("observation", ""))
    if not obs_text.startswith("Observation:"):
        obs_text = "Observation: " + obs_text
    return {"role": "user", "content": obs_text}


def webshop_steps_to_hf_history(steps: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """
    Convert a sequence of step dictionaries into a chat history for use with
    history_to_sft_sample.

    Each step contributes two messages: its "action" becomes an assistant
    message (prefixed with "Action: " and carrying a boolean "useful" flag
    derived from "is_relevant_to_goal"), immediately followed by its
    "observation" as a user message (prefixed with "Observation: ").

    Returns an empty list (with a warning) when `steps` is empty or not a list.
    """
    if not isinstance(steps, list) or not steps:
        print("WARNING: the steps are empty!!!!!!!!!!!!!!!!!!")
        print("Steps")
        print(steps)
        return []

    history: List[Dict[str, str]] = []
    for step in steps:
        history.append(_webshop_action_message(step))
        history.append(_webshop_observation_message(step))
    return history
|
|
class ShopLemonade(BaseJuicer):
    """
    WebShop-specific implementation of Lemonade relabeling.
    Uses WebShop-specific prompts for e-commerce navigation tasks.
    """
    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only
        # System prompts: hindsight-goal relabeling and query generation.
        self._relabel_prompt = WEBSHOP_GOAL_SYSTEM
        self._query_prompt = WEBSHOP_QUERY_SYSTEM

    def relabel_experience(self, state_history, obs, llm_out, original_task):
        """
        Relabel WebShop trajectories with hindsight goals.
        Adapted from Lemonade but handles WebShop's action/observation format.

        Returns:
            dict with "has_hs" (bool) and "hs": a list of relabeled chat
            histories on success, otherwise an error string.
        """
        act_obs_traj = build_webshop_trajectory(obs, llm_out)
        if not act_obs_traj:
            return {
                "has_hs": False,
                "hs": "No valid action-observation pairs found"
            }

        relabel_inp = [
            {"role": "system", "content": self._relabel_prompt},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp, "relabel")
        try:
            trajs = extract_json_objects_from_output(chat_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the relabel response:\nError: {exc}\n{chat_response}")
            trajs = []
        if not trajs:
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
            }

        # The relabel model is expected to emit per-step objects followed by a
        # summary object carrying "purchase_success"; the object before the
        # summary describes the purchased product.
        # BUG FIX: `purchase_success` used to be assigned only in the happy
        # branch, raising NameError when the summary object was missing, and
        # `trajs[-2]` raised IndexError on single-object responses.
        if "purchase_success" in trajs[-1].keys() and len(trajs) >= 2:
            purchase_success = trajs[-1]['purchase_success']
            trajs_final = trajs[-2]
            trajs = trajs[:-1]
        else:
            print("WARNING: No final goals from relabel model!")
            purchase_success = False
            trajs_final = None

        if not purchase_success:
            return {
                "has_hs": False,
                "hs": "Failed to buy anything!!!"
            }

        # Keep only the product-attribute keys; drop per-step bookkeeping.
        trajs_final = {k: v for k, v in trajs_final.items() if "step" not in k.lower()}
        query_inp = [
            {"role": "system", "content": self._query_prompt},
            {
                "role": "user",
                "content": str(trajs_final)
            },
        ]
        query_raw = self._chat_completion(query_inp, "mask")

        query = extract_json_list(query_raw)
        if not query:
            # Covers both a parse failure (None) and an empty candidate list,
            # which would otherwise crash random.choice below.
            print(f"ERROR: Failed to parse JSON list from query response: {query_raw}")
            return {
                "has_hs": False,
                "hs": "Failed to parse query!!!"
            }

        # Sample one of the generated hindsight instructions.
        query = random.choice(query)

        # Rewrite every repeated WebShop instruction with the hindsight goal.
        new_traj = copy.deepcopy(state_history)
        new_traj = self._update_webshop_goal(new_traj, query)
        return {
            "has_hs": True,
            "hs": [new_traj]
        }

    def _generate_webshop_trajgoal_pairs(self, trajs):
        """
        Generate trajectory-goal pairs for WebShop.
        Similar to generate_trajgoal_pairs but adapted for WebShop format.
        NOTE(review): not called by relabel_experience in this class.
        """
        traj_goal_pairs = []
        current_goals = []
        for step in trajs:
            if not step.get('reached_goals'):
                continue
            # Goals newly reached at this step.
            new_goals = [g for g in step['reached_goals'] if g not in current_goals]
            for goal in new_goals:
                # Find the earliest step at which this goal appears reached.
                goal_start_idx = 0
                for i, s in enumerate(trajs[:step['step']]):
                    if 'reached_goals' in s and goal in s['reached_goals']:
                        goal_start_idx = i
                        break
                # Render the sub-trajectory covering this goal.
                traj_text = ""
                for s in trajs[goal_start_idx:step['step']]:
                    traj_text += f"Step {s['step']}: Action=\"{s['action']}\"; Observation=\"{s['observation']}\"\n"
                    if 'page_type' in s:
                        traj_text += f"  Page: {s['page_type']}"
                    if 'products_viewed' in s and s['products_viewed']:
                        traj_text += f", Products viewed: {len(s['products_viewed'])}"
                    traj_text += "\n"
                traj_goal_pairs.append({
                    'trajectory': traj_text,
                    'goal': goal
                })
            current_goals.extend(new_goals)
        return traj_goal_pairs

    def _extract_webshop_env_desc(self, original_task):
        """
        Extract the WebShop task description (product search query).
        """
        if isinstance(original_task, str):
            # Keep only the first line after "Instruction:" when present.
            if "Instruction:" in original_task:
                return original_task.split("Instruction:")[-1].split("\n")[0].strip()
            return original_task.strip()
        return "Shop for products online"

    def _update_webshop_goal(self, trajectory, new_goal):
        """
        Update ALL WebShop instructions in the trajectory with the hindsight goal.
        WebShop repeats the instruction in every human message.
        """
        updated_trajectory = []
        for turn in trajectory:
            updated_turn = copy.deepcopy(turn)
            if turn.get("role") == "user" and "Instruction: [SEP]" in turn.get("content", ""):
                content = turn["content"]
                # Everything up to and including the "Instruction: [SEP]" marker.
                inst_start = content.find("Instruction: [SEP]")
                prefix = content[:inst_start + len("Instruction: [SEP]")]
                # The old instruction runs until the next "[SEP]" (if any).
                remaining = content[len(prefix):]
                next_sep = remaining.find("[SEP]")
                if next_sep != -1:
                    suffix = remaining[next_sep:]
                    updated_turn["content"] = f"{prefix} {new_goal} {suffix}"
                else:
                    updated_turn["content"] = f"{prefix} {new_goal}"
            updated_trajectory.append(updated_turn)
        return updated_trajectory
|
|
|
|
|
|
|
|
class ShopOTea(BaseJuicer):
    """
    WebShop-specific implementation of Lemonade relabeling.
    Uses WebShop-specific prompts for e-commerce navigation tasks, judges
    per-step relevance, and tags assistant turns with a 'useful' flag.
    """
    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only
        # System prompts: per-step intention relabeling, free-form query
        # generation, and per-step relevance judging.
        self._relabel_prompt = WEBSHOP_GOAL_SYSTEM
        self._query_prompt = WEBSHOP_QUERY_FREE_SYSTEM
        self._relevant_prompt = WEBSHOP_RELEVANT_SYSTEM

    def relabel_experience(self, state_history, obs, llm_out, original_task):
        """
        Relabel WebShop trajectories with hindsight goals.
        Adapted from Lemonade but handles WebShop's action/observation format.

        Returns:
            dict with "has_hs" (bool) and "hs": a list of relabeled chat
            histories on success, otherwise an error string.
        """
        act_obs_traj = build_webshop_trajectory(obs, llm_out)
        if not act_obs_traj:
            return {
                "has_hs": False,
                "hs": "No valid action-observation pairs found"
            }

        relabel_inp = [
            {"role": "system", "content": self._relabel_prompt},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp, "relabel")
        try:
            intentions = extract_json_objects_from_output(chat_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the relabel response:\nError: {exc}\n{chat_response}")
            intentions = []
        if not intentions:
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
            }

        # The last object should summarize with "purchase_success"; the object
        # before it holds the final shopping intention.
        # BUG FIX: `purchase_success` used to be assigned only in the happy
        # branch, raising NameError when the summary object was missing, and
        # `intentions[-2]` raised IndexError on single-object responses.
        if "purchase_success" in intentions[-1].keys() and len(intentions) >= 2:
            purchase_success = intentions[-1]['purchase_success']
            intentions_final = intentions[-2]
            intentions = intentions[:-1]
        else:
            print("WARNING: No final goals from relabel model!")
            purchase_success = False
            intentions_final = None

        if not purchase_success:
            print("WARNING: Failed to buy anything!!!")
            return {
                "has_hs": False,
                "hs": "Failed to buy anything!!!"
            }

        # Keep only the intention-attribute keys; drop per-step bookkeeping.
        intentions_final = {k: v for k, v in intentions_final.items() if "step" not in k.lower()}

        # Judge the relevance of each step against the per-step intentions.
        act_obs_intent_traj = build_webshop_trajectory_and_intention(obs, llm_out, intentions)
        goal_relevant_inp = [
            {"role": "system", "content": self._relevant_prompt},
            {
                "role": "user",
                "content": f"Shopping intention:{intentions_final}\n\nHere is the tracjtory:\n{act_obs_intent_traj}\n\nNow, please judge the relevance of actions at each step."
            },
        ]
        goal_relevant_response = self._chat_completion(goal_relevant_inp, "mask")
        try:
            traj_goal_pairs_with_relevance = extract_webshop_json_objects(goal_relevant_response)
        except Exception as exc:
            # BUG FIX: this used to `assert False`, crashing the caller;
            # report the failure like the other error paths (and like ShopBrew).
            print(f"ERROR: Failed to parse the RELEVANCE response:\nError: {exc}\n{goal_relevant_response}")
            return {
                "has_hs": False,
                "hs": "Failed to parse the relevance response!!!"
            }

        # Turn the final intention into candidate natural-language queries.
        query_inp = [
            {"role": "system", "content": self._query_prompt},
            {
                "role": "user",
                "content": str(intentions_final)
            },
        ]
        query_raw = self._chat_completion(query_inp, "mask")
        query = extract_webshop_query_list(query_raw)
        if not query:
            # Covers both a parse failure (None) and an empty candidate list,
            # which would otherwise crash min() below.
            print(f"ERROR: Failed to parse JSON list from query response: {query_raw}")
            return {
                "has_hs": False,
                "hs": "Failed to parse query!!!"
            }

        # Pick the shortest candidate query (lower-cased once).
        query = min(query, key=len).lower()

        # Rewrite every repeated WebShop instruction with the hindsight goal.
        new_traj = copy.deepcopy(state_history)
        new_traj = self._update_webshop_goal(new_traj, query)

        # Tag each assistant turn with a boolean 'useful' relevance flag.
        for i, step in enumerate(new_traj):
            if step['role'] != "assistant":
                continue
            # Turns alternate user/assistant after the initial user message,
            # so assistant turn k sits at index 2k+1 -> relevance index i//2.
            relevance_idx = i // 2
            if relevance_idx >= len(traj_goal_pairs_with_relevance):
                # BUG FIX: this used to `assert False`; fail gracefully instead.
                return {
                    "has_hs": False,
                    "hs": f"ERROR: relevance_idx {relevance_idx} >= len(traj_goal_pairs_with_relevance) {len(traj_goal_pairs_with_relevance)}"
                }
            relevance_value = traj_goal_pairs_with_relevance[relevance_idx].get('relevance', '')
            step['useful'] = 'yes' in str(relevance_value).lower()

        return {
            "has_hs": True,
            "hs": [new_traj]
        }

    def _generate_webshop_trajgoal_pairs(self, trajs):
        """
        Generate trajectory-goal pairs for WebShop.
        Similar to generate_trajgoal_pairs but adapted for WebShop format.
        NOTE(review): not called by relabel_experience in this class.
        """
        traj_goal_pairs = []
        current_goals = []
        for step in trajs:
            if not step.get('reached_goals'):
                continue
            # Goals newly reached at this step.
            new_goals = [g for g in step['reached_goals'] if g not in current_goals]
            for goal in new_goals:
                # Find the earliest step at which this goal appears reached.
                goal_start_idx = 0
                for i, s in enumerate(trajs[:step['step']]):
                    if 'reached_goals' in s and goal in s['reached_goals']:
                        goal_start_idx = i
                        break
                # Render the sub-trajectory covering this goal.
                traj_text = ""
                for s in trajs[goal_start_idx:step['step']]:
                    traj_text += f"Step {s['step']}: Action=\"{s['action']}\"; Observation=\"{s['observation']}\"\n"
                    if 'page_type' in s:
                        traj_text += f"  Page: {s['page_type']}"
                    if 'products_viewed' in s and s['products_viewed']:
                        traj_text += f", Products viewed: {len(s['products_viewed'])}"
                    traj_text += "\n"
                traj_goal_pairs.append({
                    'trajectory': traj_text,
                    'goal': goal
                })
            current_goals.extend(new_goals)
        return traj_goal_pairs

    def _extract_webshop_env_desc(self, original_task):
        """
        Extract the WebShop task description (product search query).
        """
        if isinstance(original_task, str):
            # Keep only the first line after "Instruction:" when present.
            if "Instruction:" in original_task:
                return original_task.split("Instruction:")[-1].split("\n")[0].strip()
            return original_task.strip()
        return "Shop for products online"

    def _update_webshop_goal(self, trajectory, new_goal):
        """
        Update ALL WebShop instructions in the trajectory with the hindsight goal.
        WebShop repeats the instruction in every human message.
        """
        updated_trajectory = []
        for turn in trajectory:
            updated_turn = copy.deepcopy(turn)
            if turn.get("role") == "user" and "Instruction: [SEP]" in turn.get("content", ""):
                content = turn["content"]
                # Everything up to and including the "Instruction: [SEP]" marker.
                inst_start = content.find("Instruction: [SEP]")
                prefix = content[:inst_start + len("Instruction: [SEP]")]
                # The old instruction runs until the next "[SEP]" (if any).
                remaining = content[len(prefix):]
                next_sep = remaining.find("[SEP]")
                if next_sep != -1:
                    suffix = remaining[next_sep:]
                    updated_turn["content"] = f"{prefix} {new_goal} {suffix}"
                else:
                    updated_turn["content"] = f"{prefix} {new_goal}"
            updated_trajectory.append(updated_turn)
        return updated_trajectory
|
|
|
|
|
|
|
|
class ShopBrew(BaseJuicer):
    """
    WebShop-specific implementation of Lemonade relabeling.
    Uses WebShop-specific prompts for e-commerce navigation tasks; the
    hindsight goal combines the purchased product's query with a randomized
    price-limit milestone derived from the paid price.
    """
    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only
        # System prompts: Brew-style relabeling (dict output), free-form query
        # generation, and per-step relevance judging.
        self._relabel_prompt = WEBSHOP_GOAL_BREW_SYSTEM
        self._query_prompt = WEBSHOP_QUERY_FREE_SYSTEM
        self._relevant_prompt = WEBSHOP_RELEVANT_SYSTEM

    def relabel_experience(self, state_history, obs, llm_out, original_task):
        """
        Relabel WebShop trajectories with hindsight goals.
        Adapted from Lemonade but handles WebShop's action/observation format.

        Returns:
            dict with "has_hs" (bool) and "hs": a list of relabeled chat
            histories on success, otherwise an error string.
        """
        act_obs_traj = build_webshop_trajectory(obs, llm_out)
        if not act_obs_traj:
            return {
                "has_hs": False,
                "hs": "No valid action-observation pairs found"
            }

        relabel_inp = [
            {"role": "system", "content": self._relabel_prompt},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp, "relabel")
        try:
            intentions = extract_shopbrew_dict_from_output(chat_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the relabel response:\nError: {exc}\n{chat_response}")
            intentions = {}
        if not intentions:
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
            }

        # Success requires both a completed purchase and query satisfaction.
        # BUG FIX: direct indexing used to raise KeyError on malformed
        # relabel output; .get defaults keep the error paths graceful.
        purchase_success = intentions.get('purchase_success', False) and intentions.get('query_satisfaction', False)
        intentions_final = intentions.get('selected')
        if intentions_final is None:
            print(f"Warning: failed to extract price {intentions}")
            return {
                "has_hs": False,
                "hs": "Failed to extract price!!!"
            }
        # BUG FIX: pop with a default so a missing 'price' key is rejected by
        # the isinstance check below instead of raising KeyError.
        intentions_final['price_limit'] = intentions_final.pop('price', None)
        if not isinstance(intentions_final['price_limit'], (int, float)):
            print("WARNING: Failed extract price!!!")
            return {
                "has_hs": False,
                "hs": "Failed to extract price!!!"
            }

        # Relax the paid price to a randomized milestone budget.
        intentions_final['price_limit'] = map_to_random_milestone(intentions_final['price_limit'])

        if not purchase_success:
            print("WARNING: Failed to buy anything!!!")
            return {
                "has_hs": False,
                "hs": "Failed to buy anything!!!"
            }

        # Judge per-step relevance against the final shopping intention
        # (Brew uses the plain trajectory, without per-step intentions).
        act_obs_intent_traj = build_webshop_trajectory(obs, llm_out)
        goal_relevant_inp = [
            {"role": "system", "content": self._relevant_prompt},
            {
                "role": "user",
                "content": f"Shopping intention:{intentions_final}\n\nHere is the tracjtory:\n{act_obs_intent_traj}\n\nNow, please judge the relevance of actions at each step."
            },
        ]
        goal_relevant_response = self._chat_completion(goal_relevant_inp, "mask")
        try:
            traj_goal_pairs_with_relevance = extract_webshop_json_objects(goal_relevant_response)
        except Exception as exc:
            # BUG FIX: this used to `assert False`, crashing the caller;
            # report the failure like the other error paths instead.
            print(f"ERROR: Failed to parse the RELEVANCE response:\nError: {exc}\n{goal_relevant_response}")
            return {
                "has_hs": False,
                "hs": "Failed to parse the relevance response!!!"
            }

        # Combine the product query with the relaxed price limit and ask for
        # natural-language paraphrases.
        base_query = intentions.get('query')
        if not base_query:
            # BUG FIX: direct indexing used to raise KeyError here.
            print(f"WARNING: no 'query' field in relabel output {intentions}")
            return {
                "has_hs": False,
                "hs": "Failed to extract query!!!"
            }
        query = f"{base_query}, price lower than {intentions_final['price_limit']} dollars"
        query_inp = [
            {"role": "system", "content": self._query_prompt},
            {
                "role": "user",
                "content": f"Here is the search query: {query}. Now transfrom it to three diverse and complete sentences."
            },
        ]
        query_raw = self._chat_completion(query_inp, "mask")
        query = extract_webshop_query_list(query_raw)
        if not query:
            # Covers both a parse failure (None) and an empty candidate list,
            # which would otherwise crash random.choice below.
            print(f"ERROR: Failed to parse JSON list from query response: {query_raw}")
            return {
                "has_hs": False,
                "hs": "Failed to parse query!!!"
            }

        # Sample one of the paraphrased hindsight instructions.
        query = random.choice(query)

        # Rewrite every repeated WebShop instruction with the hindsight goal.
        new_traj = copy.deepcopy(state_history)
        new_traj = self._update_webshop_goal(new_traj, query.lower())

        # Tag each assistant turn with a boolean 'useful' relevance flag.
        for i, step in enumerate(new_traj):
            if step['role'] != "assistant":
                continue
            # Turns alternate user/assistant after the initial user message,
            # so assistant turn k sits at index 2k+1 -> relevance index i//2.
            relevance_idx = i // 2
            if relevance_idx >= len(traj_goal_pairs_with_relevance):
                return {
                    "has_hs": False,
                    "hs": f"ERROR: relevance_idx {relevance_idx} >= len(traj_goal_pairs_with_relevance) {len(traj_goal_pairs_with_relevance)}"
                }
            relevance_value = traj_goal_pairs_with_relevance[relevance_idx].get('relevance', '')
            step['useful'] = 'yes' in str(relevance_value).lower()

        return {
            "has_hs": True,
            "hs": [new_traj]
        }

    def _generate_webshop_trajgoal_pairs(self, trajs):
        """
        Generate trajectory-goal pairs for WebShop.
        Similar to generate_trajgoal_pairs but adapted for WebShop format.
        NOTE(review): not called by relabel_experience in this class.
        """
        traj_goal_pairs = []
        current_goals = []
        for step in trajs:
            if not step.get('reached_goals'):
                continue
            # Goals newly reached at this step.
            new_goals = [g for g in step['reached_goals'] if g not in current_goals]
            for goal in new_goals:
                # Find the earliest step at which this goal appears reached.
                goal_start_idx = 0
                for i, s in enumerate(trajs[:step['step']]):
                    if 'reached_goals' in s and goal in s['reached_goals']:
                        goal_start_idx = i
                        break
                # Render the sub-trajectory covering this goal.
                traj_text = ""
                for s in trajs[goal_start_idx:step['step']]:
                    traj_text += f"Step {s['step']}: Action=\"{s['action']}\"; Observation=\"{s['observation']}\"\n"
                    if 'page_type' in s:
                        traj_text += f"  Page: {s['page_type']}"
                    if 'products_viewed' in s and s['products_viewed']:
                        traj_text += f", Products viewed: {len(s['products_viewed'])}"
                    traj_text += "\n"
                traj_goal_pairs.append({
                    'trajectory': traj_text,
                    'goal': goal
                })
            current_goals.extend(new_goals)
        return traj_goal_pairs

    def _extract_webshop_env_desc(self, original_task):
        """
        Extract the WebShop task description (product search query).
        """
        if isinstance(original_task, str):
            # Keep only the first line after "Instruction:" when present.
            if "Instruction:" in original_task:
                return original_task.split("Instruction:")[-1].split("\n")[0].strip()
            return original_task.strip()
        return "Shop for products online"

    def _update_webshop_goal(self, trajectory, new_goal):
        """
        Update ALL WebShop instructions in the trajectory with the hindsight goal.
        WebShop repeats the instruction in every human message.
        """
        updated_trajectory = []
        for turn in trajectory:
            updated_turn = copy.deepcopy(turn)
            if turn.get("role") == "user" and "Instruction: [SEP]" in turn.get("content", ""):
                content = turn["content"]
                # Everything up to and including the "Instruction: [SEP]" marker.
                inst_start = content.find("Instruction: [SEP]")
                prefix = content[:inst_start + len("Instruction: [SEP]")]
                # The old instruction runs until the next "[SEP]" (if any).
                remaining = content[len(prefix):]
                next_sep = remaining.find("[SEP]")
                if next_sep != -1:
                    suffix = remaining[next_sep:]
                    updated_turn["content"] = f"{prefix} {new_goal} {suffix}"
                else:
                    updated_turn["content"] = f"{prefix} {new_goal}"
            updated_trajectory.append(updated_turn)
        return updated_trajectory
|
|