heendung's picture
Upload folder using huggingface_hub
d1c897a verified
import base64, copy
import os
import time
import re
from openai import OpenAI
from eval_agent.prompt import prompt_naive, prompt_with_icl
from .prompts import ALFWORLD_GOAL_SYSTEM, ALFWORLD_ALLGOAL_SYSTEM, ALFWORLD_ATOMICGOAL_SYSTEM, ALFWORLD_RELEVANT_SYSTEM, ALFWORLD_ATOMICGOAL_NOSUGAR_SYSTEM, ALFWORLD_RELEVANT_STATE_SYSTEM, ALFWORLD_ALLGOAL_STATE_SYSTEM
from .prompts_webshop import WEBSHOP_GOAL_SYSTEM, WEBSHOP_QUERY_SYSTEM, WEBSHOP_QUERY_FREE_SYSTEM, WEBSHOP_RELEVANT_SYSTEM, WEBSHOP_GOAL_BREW_SYSTEM
from .hf_utils import history_to_sft_sample
from abc import ABC, abstractmethod
import json, random
from typing import List, Dict, Any
from .utils import map_to_random_milestone
def extract_json_list(text):
    """Extract and parse the first JSON list embedded in ``text``.

    Decoding starts at the first ``[`` and uses ``json.JSONDecoder.raw_decode``,
    which consumes exactly one JSON value and ignores whatever follows it, so
    nested lists/objects and trailing prose are both handled.

    Returns:
        The parsed list, or ``None`` when ``text`` has no ``[`` or the text
        starting at that ``[`` is not valid JSON.
    """
    start = text.find('[')
    if start == -1:
        return None
    try:
        value, _end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return None
    return value
def extract_webshop_query_list(text):
    """
    Extract a JSON list of search queries from a WebShop LLM response.

    The first ``[...]`` span (non-greedy, so nested lists are not supported) is
    taken from ``text``. Literal newlines inside quoted strings are folded into
    spaces, then the span is parsed as strict JSON. Only if that fails are the
    lossy repair heuristics applied (missing closing quotes, unescaped
    foot/inch marks such as ``3'3"`` or ``63"``) before a second parse attempt.
    Running the heuristics only as a fallback matters: they rewrite any
    digit-followed-by-quote, which corrupts perfectly valid arrays such as
    ``["item 1", "item 2"]``.

    Returns:
        The parsed list, or ``None`` if no bracketed span exists or it cannot
        be parsed even after repair (the failure is printed).
    """
    match = re.search(r'\[.*?\]', text, re.DOTALL)
    if not match:
        return None
    json_str = match.group(0)

    # JSON forbids raw control characters inside strings; the LLM sometimes
    # breaks array elements across lines, so fold those newlines into spaces.
    def replace_newlines_in_strings(m):
        return m.group(0).replace('\n', ' ').replace('\r', ' ')

    json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', replace_newlines_in_strings, json_str, flags=re.DOTALL)

    # Well-formed output needs no repair.
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        pass

    # --- Lossy repair heuristics (only reached for malformed output) ---
    # Missing closing quote before a comma: ...text, "next  ->  ...text", "next
    json_str = re.sub(r'([^"\\]),\s*"', r'\1", "', json_str)
    # Foot-inch marks like 3'3": escape the inch quote. Each substitution turns
    # one digit-quote pair into digit-backslash-quote, so the loop terminates.
    if re.search(r"\d+'\d+\"", json_str):
        json_str = re.sub(r'"([^"]*?)(\d+\'[\d]+)"([^"]*?)"', r'"\1\2\\"\3"', json_str)
        while re.search(r'"([^"]*?)(\d+\'[\d]+)"([^"]*?)"', json_str):
            json_str = re.sub(r'"([^"]*?)(\d+\'[\d]+)"([^"]*?)"', r'"\1\2\\"\3"', json_str)
    # Inch-only marks like 63": same idea.
    if re.search(r'\d+"', json_str):
        json_str = re.sub(r'"([^"]*?)(\d+)"([^"]*?)"', r'"\1\2\\"\3"', json_str)
        while re.search(r'"([^"]*?)(\d+)"([^"]*?)"', json_str):
            json_str = re.sub(r'"([^"]*?)(\d+)"([^"]*?)"', r'"\1\2\\"\3"', json_str)

    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"Failed to parse JSON: {e}")
        print(f"Problematic JSON string: {json_str[:500]}...")
        return None
def parse_vitamin_output(text: str, *, return_json_safe: bool = False) -> List[Dict[str, Any]]:
    """
    Robustly extract & parse the JSON list from messy LLM output.

    Handles:
    - Any backtick code fences (>=3), with/without language tag (```json / ````json / etc.);
      the first fenced block containing both '[' and ']' is used
    - Leading prose
    - Brackets inside double-quoted string values (ignored when locating the
      end of the list)
    - // and /* ... */ comments
    - Trailing commas before ] or }
    - Tuple-like pairs in placement: [("a","b")] -> [["a","b"]]
    - Normalizes tool_status 'watersource_on' -> 'faucet_on' and fills missing
      lamp_on / faucet_on / microwave_on / fridge_closed with False

    Placement entries are returned as 2-tuples (entries that are not 2-element
    lists/tuples are dropped). If return_json_safe=True, placement tuples are
    converted back to lists so the result can go through json.dumps.

    Raises:
        ValueError: if no list is found or its brackets never balance
            (json.JSONDecodeError, a ValueError subclass, if the cleaned
            text is still invalid JSON).
    """
    def _remove_comments(src: str) -> str:
        # Strip // and /* */ comments while leaving quoted strings untouched.
        out, i, n = [], 0, len(src)
        in_str = False
        quote = ''
        line_cmt = False
        block_cmt = False
        while i < n:
            c = src[i]
            if line_cmt:
                if c == '\n':
                    line_cmt = False
                    out.append(c)
                i += 1
                continue
            if block_cmt:
                if c == '*' and i + 1 < n and src[i + 1] == '/':
                    block_cmt = False
                    i += 2
                else:
                    i += 1
                continue
            if not in_str:
                if c == '/' and i + 1 < n:
                    if src[i + 1] == '/':
                        line_cmt = True
                        i += 2
                        continue
                    if src[i + 1] == '*':
                        block_cmt = True
                        i += 2
                        continue
                if c in ('"', "'"):
                    in_str = True
                    quote = c
                out.append(c)
            else:
                out.append(c)
                if c == '\\' and i + 1 < n:
                    i += 1
                    out.append(src[i])
                elif c == quote:
                    in_str = False
            i += 1
        return ''.join(out)

    def _remove_trailing_commas(s: str) -> str:
        # Drop a comma when the next non-whitespace char closes an object/array.
        out, i, n = [], 0, len(s)
        in_str = False
        quote = ''
        while i < n:
            c = s[i]
            if not in_str:
                if c in ('"', "'"):
                    in_str = True
                    quote = c
                    out.append(c)
                elif c == ',':
                    j = i + 1
                    while j < n and s[j] in ' \t\r\n':
                        j += 1
                    if j < n and s[j] in '}]':
                        i += 1
                        continue
                    out.append(c)
                else:
                    out.append(c)
            else:
                out.append(c)
                if c == '\\' and i + 1 < n:
                    i += 1
                    out.append(s[i])
                elif c == quote:
                    in_str = False
            i += 1
        return ''.join(out)

    def _extract_code_fence(t: str) -> str:
        # Support backtick fences of length >=3, optional language
        fence_pattern = re.compile(r"(`{3,})(?:[a-zA-Z]+)?\s*(.*?)\s*\1", re.DOTALL)
        for _, block in fence_pattern.findall(t):
            if '[' in block and ']' in block:
                return block
        return t

    def _extract_json_list(t: str) -> str:
        # Return the first bracket-balanced [...] span. Brackets inside
        # double-quoted strings (with backslash escapes) are not counted, so
        # values such as "x]" do not end the list early.
        t = _extract_code_fence(t)
        start = t.find('[')
        if start == -1:
            raise ValueError("No JSON list found.")
        depth = 0
        in_str = False
        escaped = False
        for i in range(start, len(t)):
            c = t[i]
            if in_str:
                if escaped:
                    escaped = False
                elif c == '\\':
                    escaped = True
                elif c == '"':
                    in_str = False
                continue
            if c == '"':
                in_str = True
            elif c == '[':
                depth += 1
            elif c == ']':
                depth -= 1
                if depth == 0:
                    return t[start:i + 1]
        raise ValueError("Unbalanced brackets.")

    def _fix_tuple_pairs(s: str) -> str:
        # ("x", "y") -> ["x", "y"] (quoted string pairs only)
        return re.sub(r'\(\s*("([^"\\]|\\.)*")\s*,\s*("([^"\\]|\\.)*")\s*\)', r'[\1, \3]', s)

    blob = _extract_json_list(text)
    blob = _remove_comments(blob)
    blob = _fix_tuple_pairs(blob)
    blob = _remove_trailing_commas(blob)
    data = json.loads(blob)

    # Normalize fields
    for obj in data:
        if isinstance(obj, dict) and "tool_status" in obj and isinstance(obj["tool_status"], dict):
            ts = obj["tool_status"]
            if "watersource_on" in ts and "faucet_on" not in ts:
                ts["faucet_on"] = bool(ts.pop("watersource_on"))
            ts.setdefault("lamp_on", False)
            ts.setdefault("faucet_on", False)
            ts.setdefault("microwave_on", False)
            ts.setdefault("fridge_closed", False)
        if isinstance(obj, dict) and "placement" in obj and isinstance(obj["placement"], list):
            obj["placement"] = [tuple(p) for p in obj["placement"] if isinstance(p, (list, tuple)) and len(p) == 2]

    if return_json_safe:
        for obj in data:
            if isinstance(obj, dict) and "placement" in obj:
                obj["placement"] = [list(p) if isinstance(p, tuple) else p for p in obj["placement"]]
    return data
# learn/hsl/relabel.py
def _to_prefixed_text(value: Any, prefix: str) -> str:
    """Render ``value`` as text (lists space-joined) and ensure it starts with ``prefix``."""
    if isinstance(value, list):
        try:
            text = " ".join(str(x) for x in value)
        except Exception:
            # Fall back to JSON if some element's __str__ fails.
            text = json.dumps(value)
    else:
        text = str(value)
    if not text.startswith(prefix):
        text = f"{prefix} {text}"
    return text


def steps_to_hf_history(steps: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """
    Convert step dictionaries into an alternating assistant/user chat history
    for use with history_to_sft_sample.

    Each step emits two messages, in order:
      * ``{"role": "assistant", "content": "Action: <action>", "useful": bool}``;
        ``useful`` is True only when ``is_relevant_to_goal`` is a string that
        contains "yes" (case-insensitive). A missing or non-string value counts
        as "no".
      * ``{"role": "user", "content": "Observation: <observation>"}``.
    The "Action:" / "Observation:" prefixes are added only if absent, and
    list-valued actions/observations are joined with single spaces.

    Returns:
        The history list, or ``[]`` (after printing a warning) when ``steps``
        is not a non-empty list.
    """
    if not isinstance(steps, list) or not steps:
        print("WARNING: the steps are empty!!!!!!!!!!!!!!!!!!")
        print("Steps")
        print(steps)
        return []

    history: List[Dict[str, Any]] = []
    for step in steps:
        relevance = step.get("is_relevant_to_goal", "no")
        if not isinstance(relevance, str):
            relevance = "no"
        history.append({
            "role": "assistant",
            "content": _to_prefixed_text(step.get("action", ""), "Action:"),
            "useful": "yes" in relevance.lower(),
        })
        history.append({
            "role": "user",
            "content": _to_prefixed_text(step.get("observation", ""), "Observation:"),
        })
    return history
def extract_env_desc(original_task):
    """
    Return the environment description part of a task prompt.

    The prompt's final line is expected to be ``"Your task is to: <goal>"``
    (the same assumption the relabelers make when they replace the last line
    of the prompt). That line is dropped so the original goal does not leak
    into prompts that should only describe the environment. If the last line
    does not contain the marker, the text is returned unchanged.
    """
    lines = original_task.split("\n")
    # Substring match: the task line carries the goal text after the marker,
    # so comparing whole lines for equality would never match.
    if "Your task is to:" in lines[-1]:
        lines = lines[:-1]
    return "\n".join(lines)
def generate_trajgoal_pairs(steps: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Build (trajectory prefix, goal) training pairs from parsed relabel steps.

    Every distinct goal that appears in any step's ``reached_goals`` yields one
    pair. For a goal first reached at step index k, the pair's trajectory is
    ``steps[:k + 1]`` (all steps up to and including k). Pairs are ordered by
    the step where their goal first appears; goals first reached in the same
    step keep their order within that step's ``reached_goals``.

    Args:
        steps: Non-empty list of step dictionaries. Only the optional
            ``reached_goals`` key is read (an iterable of hashable goal values;
            a missing or ``None`` value is treated as empty).

    Returns:
        A list of ``{"trajectory": <list of step dicts>, "goal": <goal>}``
        dictionaries, or an empty list when no step reports a reached goal.

    Raises:
        ValueError: If ``steps`` is not a non-empty list.
    """
    if not isinstance(steps, list) or not steps:
        raise ValueError("Input must be a non-empty list of step dictionaries.")

    # dicts keep insertion order, so iteration order == order of first appearance
    first_index: Dict[Any, int] = {}
    for idx, step in enumerate(steps):
        for goal in step.get('reached_goals') or []:
            first_index.setdefault(goal, idx)

    return [
        {"trajectory": steps[: idx + 1], "goal": goal}
        for goal, idx in first_index.items()
    ]


# NOTE: disabled legacy implementation of extract_json_objects_from_output,
# kept only as an unused module-level string literal (it is never executed).
# The active implementation is the function definition that follows.
"""
def extract_json_objects_from_output(output: str) -> List[Dict[str, Any]]:
# Regex captures the first '[' … ']' sequence containing at least one object.
match = re.search(r'\[\s*{.*?}\s*\]', output, re.DOTALL)
if not match:
print(output)
print("ValueError: No JSON array detected in the output string. Please check the LLM output format.")
return []
json_str = match.group(0)
try:
return json.loads(json_str)
except json.JSONDecodeError as exc:
print(output)
print(f"ValueError: Failed to decode JSON array: {exc}")
return []
"""
def extract_json_objects_from_output(output: str) -> List[Dict[str, Any]]:
    """
    Extract the first bracket-balanced JSON array from a blob of text and parse it.

    Brackets inside double-quoted strings (with backslash escapes) are ignored
    while locating the end of the array. The extracted text is parsed as-is
    first; only if that fails are common LLM artifacts removed and parsing
    retried. The artifacts are ``//`` comments, lines starting with ``...``
    and trailing commas before ``}``/``]``. Parsing strictly first keeps valid
    string contents such as ``"http://..."`` or ``", ]"`` from being mangled by
    the regex-based cleanup, which is not string-aware.

    Raises:
        ValueError: if there is no ``[``, the brackets never balance, or the
            text is not valid JSON even after cleanup.
    """
    # 1) locate the first '['
    start = output.find('[')
    if start == -1:
        raise ValueError("No '[' found in output; can't extract JSON array.")

    # 2) walk forward to match brackets, ignoring ones inside strings
    in_str = False
    escape = False
    depth = 0
    end = None
    for idx, ch in enumerate(output[start:], start):
        if in_str:
            if escape:
                escape = False
            elif ch == '\\':
                escape = True
            elif ch == '"':
                in_str = False
        else:
            if ch == '"':
                in_str = True
            elif ch == '[':
                depth += 1
            elif ch == ']':
                depth -= 1
                if depth == 0:
                    end = idx + 1
                    break
    if end is None:
        raise ValueError("Unbalanced brackets: never found matching ']'.")
    raw = output[start:end]

    # 3) well-formed JSON needs no cleanup
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass

    # 4a) remove // comments
    raw = re.sub(r'//.*$', '', raw, flags=re.MULTILINE)
    # 4b) remove standalone ellipsis lines
    raw = re.sub(r'^\s*\.\.\..*$\n?', '', raw, flags=re.MULTILINE)
    # 4c) strip trailing commas before } or ]
    raw = re.sub(r',\s*(?=[}\]])', '', raw)
    try:
        return json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Failed to decode JSON.\nError: {exc}") from exc
def extract_webshop_json_objects(output: str) -> List[Dict[str, Any]]:
    """
    WebShop-specific JSON array extraction.

    The array starting at the first ``[`` is located by bracket counting
    (brackets inside double-quoted strings are skipped; any trailing text
    after the array is ignored). It is parsed as strict JSON first. Only if
    that fails are repairs applied, because they are regex-based and would
    otherwise corrupt valid content such as URLs (``//``) or escaped
    backslashes:
      - measurement quotes like 26", 16" are escaped (a pre-existing ``\\\\"``
        is collapsed to ``\\"``)
      - ``//`` comments, ``...`` lines and trailing commas are removed
      - ``\\'`` and ``\\u0027`` become a plain apostrophe (e.g. "salt n' vinegar")
    As a last resort, backslashes before quotes/apostrophes are stripped.

    Raises:
        ValueError: if no ``[`` or no matching ``]`` is found, or parsing fails
            after all repairs (the first 500 chars are printed).
    """
    # Find first [ to start the JSON array
    start = output.find('[')
    if start == -1:
        raise ValueError("No '[' found in output; can't extract JSON array.")

    # Track bracket depth to find where the array actually ends
    bracket_depth = 0
    in_string = False
    escape_next = False
    end = -1
    for i in range(start, len(output)):
        char = output[i]
        if escape_next:
            escape_next = False
            continue
        if char == '\\':
            escape_next = True
            continue
        if char == '"':
            in_string = not in_string
            continue
        if not in_string:
            if char == '[':
                bracket_depth += 1
            elif char == ']':
                bracket_depth -= 1
                if bracket_depth == 0:
                    end = i + 1
                    break
    if end == -1:
        raise ValueError("No matching ']' found in output; can't extract JSON array.")
    raw = output[start:end]

    # Well-formed JSON needs no repair.
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass

    # Measurement quotes: the LLM may emit 16" (unescaped) or 16\" (escaped).
    # Step 1: protect \\" sequences with a placeholder
    raw = raw.replace('\\\\"', '__ESCAPED_QUOTE__')
    # Step 2: escape digit + quote followed by a word char (e.g. "16" laptop")
    raw = re.sub(r'(\d+(?:\.\d+)?)"(\s+\w|\w)', r'\1\\"\2', raw)
    # Step 3: restore the placeholders as a single escaped quote
    raw = raw.replace('__ESCAPED_QUOTE__', '\\"')

    # remove // comments
    raw = re.sub(r'//.*$', '', raw, flags=re.MULTILINE)
    # remove standalone ellipsis lines
    raw = re.sub(r'^\s*\.\.\..*$\n?', '', raw, flags=re.MULTILINE)
    # strip trailing commas before } or ]
    raw = re.sub(r',\s*(?=[}\]])', '', raw)
    # \' is not a valid JSON escape; apostrophes need no escaping
    raw = raw.replace(r"\'", "'")
    raw = re.sub(r'\\u0027', "'", raw)

    try:
        return json.loads(raw)
    except json.JSONDecodeError as exc:
        try:
            # Last resort: the LLM may have over-escaped quotes/apostrophes.
            cleaned = re.sub(r'\\([\'"])', r'\1', raw)
            return json.loads(cleaned)
        except json.JSONDecodeError:
            print(f"Raw JSON that failed to parse:\n{raw[:500]}...")
            raise ValueError(f"Failed to decode JSON.\nError: {exc}") from exc
def extract_shopbrew_dict_from_output(output: str) -> Dict[str, Any]:
    """
    Extract the first well-balanced JSON dictionary from a blob of text, clean
    out common artifacts (ellipsis, trailing commas, comments), and parse it.
    Similar to extract_json_objects_from_output but for dictionaries instead of lists.
    Handles both single-quoted Python dicts and double-quoted JSON.

    The matching ``}`` is found by brace counting, skipping braces inside
    strings delimited by either ``"`` or ``'`` (with backslash escapes).
    Parsing is then attempted in three stages:
      1. ``json.loads`` on the cleaned text.
      2. ``ast.literal_eval`` after rewriting ``true``/``false``/``null`` that
         follow ``:`` or ``,`` into ``True``/``False``/``None``; a non-dict
         result is treated as a failure.
      3. ``json.loads`` after replacing every ``'`` with ``"`` and turning
         ``True``/``False``/``None`` back into JSON literals.

    Caveats: the ``//`` stripping and the literal rewrites are plain regexes,
    not string-aware, so they also affect text inside string values (e.g.
    ``"http://..."``); stage 3 also replaces apostrophes inside values. Only
    ValueError/SyntaxError from stage 2 are handled; other exception types
    raised by ``ast.literal_eval`` propagate.

    Raises:
        ValueError: if no ``{`` is found, the braces never balance, or all
            three parsing stages fail.
    """
    # 1) locate the first '{'
    start = output.find('{')
    if start == -1:
        raise ValueError("No '{' found in output; can't extract JSON dictionary.")
    # 2) walk forward to match braces, ignoring ones inside strings
    in_str = False
    escape = False
    depth = 0
    end = None
    quote_char = None  # Track which quote type we're in
    for idx, ch in enumerate(output[start:], start):
        if in_str:
            if escape:
                escape = False
            elif ch == '\\':
                escape = True
            elif ch == quote_char:  # End of string when we see the matching quote
                in_str = False
                quote_char = None
        else:
            if ch in ('"', "'"):  # Start of string with either quote type
                in_str = True
                quote_char = ch
            elif ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    end = idx + 1
                    break
    if end is None:
        raise ValueError("Unbalanced braces: never found matching '}'.")
    raw = output[start:end]
    # 3a) remove // comments (not string-aware: also truncates "http://..." values)
    raw = re.sub(r'//.*$', '', raw, flags=re.MULTILINE)
    # 3b) remove standalone ellipsis lines
    raw = re.sub(r'^\s*\.\.\..*$\n?', '', raw, flags=re.MULTILINE)
    # 3c) strip trailing commas before } or ]
    raw = re.sub(r',\s*(?=[}\]])', '', raw)
    # 4) Try to parse as JSON first
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # 5) If it fails, it might be a Python dict with single quotes
        # First, fix boolean values (true/false -> True/False) and null -> None for Python
        fixed_bools = raw
        # Rewrite true/false/null that directly follow ':' or ',' (plus optional
        # whitespace). The trailing \b only requires a word boundary; what comes
        # next is not checked, and occurrences inside string values are also hit.
        fixed_bools = re.sub(r':\s*true\b', ': True', fixed_bools)
        fixed_bools = re.sub(r':\s*false\b', ': False', fixed_bools)
        fixed_bools = re.sub(r',\s*true\b', ', True', fixed_bools)
        fixed_bools = re.sub(r',\s*false\b', ', False', fixed_bools)
        fixed_bools = re.sub(r':\s*null\b', ': None', fixed_bools)
        fixed_bools = re.sub(r',\s*null\b', ', None', fixed_bools)
        try:
            # Use ast.literal_eval for Python dict format
            import ast
            result = ast.literal_eval(fixed_bools)
            if isinstance(result, dict):
                return result
            else:
                # Caught by the except below, which falls through to stage 3.
                raise ValueError(f"Parsed result is not a dictionary: {type(result)}")
        except (ValueError, SyntaxError) as exc:
            # If ast.literal_eval also fails, try converting to JSON format
            try:
                # Convert Python-style to JSON: single quotes to double, True/False to true/false, None to null
                # NOTE: this also rewrites apostrophes inside values (e.g. "n'") — lossy.
                json_str = fixed_bools.replace("'", '"')
                json_str = re.sub(r'\bTrue\b', 'true', json_str)
                json_str = re.sub(r'\bFalse\b', 'false', json_str)
                json_str = re.sub(r'\bNone\b', 'null', json_str)
                return json.loads(json_str)
            except json.JSONDecodeError as exc2:
                raise ValueError(f"Failed to decode JSON or Python dict.\nJSON Error: {exc2}\nAST Error: {exc}")
def remove_number(goal):
    """Drop every standalone integer from ``goal`` and normalise whitespace.

    Only digit runs delimited by word boundaries are removed, so digits glued
    to letters (``desk1``) survive. Runs of whitespace left behind are
    collapsed to single spaces and the result is trimmed.
    """
    return " ".join(re.sub(r'\b\d+\b', '', goal).split())
class BaseJuicer(ABC):
    """Base class for relabeling strategies that query an OpenAI-compatible chat API.

    Owns backend/client selection, chat completions with retries, and optional
    JSONL logging of requests/responses. Subclasses implement
    :meth:`relabel_experience`.
    """

    def __init__(self, llama: str = "llama-3-1-70b", api: str = "internal") -> None:
        """Create the API client for the backend selected by ``api``.

        Args:
            llama: Currently unused; the model name is fixed per ``api`` option.
                Kept so existing callers keep working.
            api: One of ``"internal"``, ``"second"``, ``"legacy"``,
                ``"openrouter"`` or ``"gpt"``.

        Raises:
            ValueError: If ``api`` is not a known option.
        """
        # NOTE(security): API keys are hard-coded below and are exposed to anyone
        # who can read this file. They should be rotated and loaded from the
        # environment or a secret store instead.
        if api == "internal":
            endpoint = "http://pluto-prod-jackwa-llama-70b-infere-0:8000/v1"
            key = "sk-QObXQcx0GTDciGVNhkgTLw"
            api_key = "Bearer " + key
            self.llama = "meta-llama/Llama-3.3-70B-Instruct"
        elif api == "second":
            endpoint = "http://pluto-prod-jkil-job2-0:8000/v1"
            key = "sk-QObXQcx0GTDciGVNhkgTLw"
            api_key = "Bearer " + key
            self.llama = "meta-llama/Llama-3.3-70B-Instruct"
        elif api == "legacy":
            endpoint = "http://pluto-prod-hawang-llm-proxy-9qtfav-0:4000"
            key = "sk--Tz-ELdSOZ-MFBRRomTuGg"
            api_key = "Bearer " + key
            self.llama = "gpt-4.1-mini"
        elif api == "openrouter":
            endpoint = "https://openrouter.ai/api/v1"
            api_key = "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"
            self.llama = "openai/gpt-4.1-mini"
        elif api == "gpt":
            endpoint = "https://api.openai.com/v1"
            api_key = "sk-uuhu_uog_hb7OYM3xrzT_fFO-1KPUxPu2VOySrIXZGT3BlbkFJwPSRtJU5lEuyoWrf1Ayh8cEW265ABrvuCm6fwEAxUA"
            self.llama = "gpt-4.1-mini-2025-04-14"
        else:
            raise ValueError(f"Unknown API option: {api}")

        # The "legacy" proxy is called without extra generation params; every
        # other backend is called with temperature 0.
        self.llama_gen_params = {} if api == "legacy" else {"temperature": 0}
        self.client = OpenAI(api_key=api_key, base_url=endpoint)

        # Optional logging configuration; these may be set after construction.
        # relabel_log_file: JSONL file for mode == "relabel" requests.
        # mask_log_file: JSONL file for every other mode.
        # log_interval: log when step_getter() % log_interval == 0 (default 30).
        # step_getter: zero-arg callable returning the current training step;
        #   logging is disabled while it is None.
        self.relabel_log_file: str | None = None
        self.mask_log_file: str | None = None
        self.log_interval: int = 30
        self.step_getter = None

    def _chat_completion(self, messages, mode):
        """Send ``messages`` to the configured model and return the reply text.

        The request is attempted up to 3 times with a 1 s pause between
        attempts; after the third failure the error is printed and re-raised.
        ``mode`` selects the log file (see ``_log_exchange``).
        """
        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            try:
                resp = self.client.chat.completions.create(
                    model=self.llama, messages=messages, **self.llama_gen_params
                )
                resp_text = resp.choices[0].message.content
                break
            except Exception as exc:
                if attempt >= max_attempts:
                    print(f"Failed to send request: {exc}")
                    raise
                time.sleep(1)
        self._log_exchange(messages, resp_text, mode)
        return resp_text

    def _log_exchange(self, messages, resp_text, mode):
        """Append one JSON line for this exchange when logging is configured and due.

        Nothing is written if ``step_getter`` is unset, the current step is not a
        multiple of ``log_interval``, or the log file for ``mode`` is ``None``.
        """
        if self.step_getter is None:
            return
        step = self.step_getter()
        if step % self.log_interval != 0:
            return
        fn = self.relabel_log_file if mode == 'relabel' else self.mask_log_file
        if fn is None:
            return
        with open(fn, "a", encoding="utf-8") as f:
            json.dump({
                "step": step,
                "messages": messages,
                "response": resp_text,
            }, f)
            f.write("\n")

    def _rectify_chat_completion(self, original_messages, llm_response, error_msg, mode="relabel"):
        """
        Request a corrected completion from the model.

        Parameters
        ----------
        original_messages : List[dict]
            Messages originally sent to the LLM.
        llm_response : str
            The model output that failed to parse.
        error_msg : str
            Explanation of the parsing failure; appended to the correction request.
        mode : str, optional
            Logging mode forwarded to ``_chat_completion`` (default ``"relabel"``).

        Returns
        -------
        str
            The model's new response after being asked to rectify the output.
        """
        feedback = (
            "Your original output is in a WRONG format, and here is the error message "
            + str(error_msg)
        )
        messages = original_messages + [
            {"role": "assistant", "content": llm_response},
            {"role": "user", "content": feedback},
        ]
        return self._chat_completion(messages, mode)

    @abstractmethod
    def relabel_experience(self, state_history, obs, llm_out, **kwargs):
        """Return relabeled trajectory based on ``obs`` and ``llm_out``."""
        raise NotImplementedError
class Juicer(BaseJuicer):
    """Relabel a trajectory with a single goal inferred by the LLM.

    The LLM is prompted with ``ALFWORLD_GOAL_SYSTEM`` and the formatted
    action/observation trajectory. The text after its last ``"Final goal: "``
    (with standalone numbers removed) becomes the new task.
    """

    def __init__(self, task_instruction, llama="llama-3-1-70b", api="internal"):
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction

    def relabel_experience(self, state_history, obs, llm_out, **kwargs):
        """
        Args:
            state_history: chat history; the last line of
                ``state_history[0]['content']`` is replaced with
                ``"Your task is to: <new goal>"`` in a deep copy.
            obs: observation strings paired positionally with ``llm_out``;
                pairs whose observation starts with
                ``"Observation: Error Input."`` are skipped.
            llm_out: model outputs; text after the last ``"Action: "`` is the action.
            **kwargs: ignored.

        Returns:
            ``{"has_hs": True, "hs": <new history>}`` when a goal was found,
            otherwise ``{"has_hs": False, "hs": <raw LLM response>}`` (also when
            the response says ``final goal: none``).
        """
        act_obs_traj = ''
        for i, (observation, output) in enumerate(zip(obs, llm_out)):
            if observation.startswith("Observation: Error Input."):
                continue
            action = output.split("Action: ")[-1]
            observation_text = observation.split("Observation: ")[-1]
            act_obs_traj += f'Step {i+1}: Action="{action}"; Observation="{observation_text}".\n'

        relabel_inp = [
            {"role": "system", "content": ALFWORLD_GOAL_SYSTEM},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp, "relabel")

        if "Final goal: " not in chat_response or "final goal: none" in chat_response.lower():
            return {
                "has_hs": False,
                "hs": chat_response
            }
        new_goal = remove_number(chat_response.split("Final goal: ")[-1])

        new_traj = copy.deepcopy(state_history)
        prompt_without_task = "\n".join(new_traj[0]['content'].split("\n")[:-1])
        new_traj[0]['content'] = prompt_without_task + f"\nYour task is to: {new_goal}"
        return {
            'has_hs': True,
            "hs": new_traj
        }
class Lemonade(BaseJuicer):
    """
    Relabel a trajectory into its achieved goals and mask irrelevant actions.

    For each trajectory two kinds of LLM calls are made:
      1. ``self._relabel_prompt`` (default ``ALFWORLD_ALLGOAL_SYSTEM``) turns the
         trajectory into a JSON list of step dicts with ``reached_goals``,
         optionally ending with a dict holding ``final_goals``.
      2. For each achieved goal, ``self._relevant_prompt`` (default
         ``ALFWORLD_RELEVANT_SYSTEM``) annotates the steps up to that goal; the
         parsed list is converted with ``steps_to_hf_history``.

    Args:
        first_only: if True, only the first achieved goal is relabeled.
    """

    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only
        self._relabel_prompt = ALFWORLD_ALLGOAL_SYSTEM
        self._relevant_prompt = ALFWORLD_RELEVANT_SYSTEM

    def relabel_experience(self, state_history, obs, llm_out, original_task):
        """
        Relabel one episode into one chat history per achieved goal.

        Args:
            state_history: chat history; ``state_history[0]`` is copied and its
                last content line replaced with ``"Your task is to: <goal>"``.
            obs: observation strings paired positionally with ``llm_out``;
                pairs whose observation starts with
                ``"Observation: Error Input."`` are skipped.
            llm_out: model outputs; text after the last ``"Action: "`` is the action.
            original_task: task prompt used to build the environment description.

        Returns:
            ``{"has_hs": True, "hs": [history, ...]}`` or
            ``{"has_hs": False, "hs": <str>}`` when nothing could be produced.
        """
        act_obs_traj = ''
        for i, (observation, output) in enumerate(zip(obs, llm_out)):
            if observation.startswith("Observation: Error Input."):
                continue
            action = output.split("Action: ")[-1]
            observation_text = observation.split("Observation: ")[-1]
            act_obs_traj += f'Step {i+1}: Action="{action}"; Observation="{observation_text}".\n'

        relabel_inp = [
            {"role": "system", "content": self._relabel_prompt},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp, "relabel")
        parse_failure = {
            "has_hs": False,
            "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
        }
        try:
            trajs = extract_json_objects_from_output(chat_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the relabel response:\nError: {exc}\n{chat_response}")
            return parse_failure
        if not trajs:
            return parse_failure

        # Strip the trailing summary dict so only per-step dicts remain.
        if isinstance(trajs[-1], dict) and "final_goals" in trajs[-1]:
            trajs = trajs[:-1]
        else:
            print("WARNING: No finals goal from relabel model!!!!!")
            print(trajs)
        # generate_trajgoal_pairs raises on an empty list (e.g. the model returned
        # only the summary dict) and calls .get on each element.
        if not trajs or not all(isinstance(t, dict) for t in trajs):
            return {
                "has_hs": False,
                "hs": chat_response
            }

        traj_goal_pairs = generate_trajgoal_pairs(trajs)
        if not traj_goal_pairs:
            return {
                "has_hs": False,
                "hs": chat_response
            }
        if self.first_only:
            traj_goal_pairs = traj_goal_pairs[:1]

        env_desc = extract_env_desc(original_task)
        hs = []
        for pair in traj_goal_pairs:
            traj = pair['trajectory']
            goal = remove_number(pair['goal'])
            goal_relevant_inp = [
                {"role": "system", "content": self._relevant_prompt},
                {
                    "role": "user",
                    "content": f"Environment description:\n{env_desc}\n\nGoal:{goal}\n\nHere is the tracjtory:\n{traj}\n\nNow, please judge the relevance of actions at each step."
                },
            ]
            goal_relevant_response = self._chat_completion(goal_relevant_inp, "mask")
            try:
                steps_with_relevance = extract_json_objects_from_output(goal_relevant_response)
            except Exception as exc:
                print(f"ERROR: Failed to parse the RELEVANCE response:\nError: {exc}\n{goal_relevant_response}")
                continue

            task_message = copy.deepcopy(state_history[0])
            task_message['content'] = "\n".join(task_message['content'].split("\n")[:-1]) + f"\nYour task is to: {goal}"
            hs.append([task_message] + steps_to_hf_history(steps_with_relevance))

        if hs:
            return {
                "has_hs": True,
                "hs": hs
            }
        return parse_failure
class Vitamin(Lemonade):
    """
    Lemonade variant for state-annotated relabeling.

    Differences from ``Lemonade``:
      * system prompts are ``ALFWORLD_ALLGOAL_STATE_SYSTEM`` (relabel) and
        ``ALFWORLD_RELEVANT_STATE_SYSTEM`` (relevance);
      * the relabel response is parsed with ``parse_vitamin_output``, which also
        normalises ``tool_status`` and ``placement`` fields.
    The relevance response is still parsed with ``extract_json_objects_from_output``.
    """

    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        # Initialize exactly like Lemonade/BaseJuicer, then swap in the prompts.
        super().__init__(task_instruction, llama=llama, first_only=first_only, api=api)
        self._relabel_prompt = ALFWORLD_ALLGOAL_STATE_SYSTEM
        self._relevant_prompt = ALFWORLD_RELEVANT_STATE_SYSTEM

    def relabel_experience(self, state_history, obs, llm_out, original_task):
        """
        Relabel one episode into one chat history per achieved goal.

        Args:
            state_history: chat history; ``state_history[0]`` is copied and its
                last content line replaced with ``"Your task is to: <goal>"``.
            obs: observation strings paired positionally with ``llm_out``;
                pairs whose observation starts with
                ``"Observation: Error Input."`` are skipped.
            llm_out: model outputs; text after the last ``"Action: "`` is the action.
            original_task: task prompt used to build the environment description.

        Returns:
            ``{"has_hs": True, "hs": [history, ...]}`` or
            ``{"has_hs": False, "hs": <str>}`` when nothing could be produced.
            A goal whose relevance response cannot be parsed is skipped.
        """
        act_obs_traj = ''
        for i, (observation, output) in enumerate(zip(obs, llm_out)):
            if observation.startswith("Observation: Error Input."):
                continue
            action = output.split("Action: ")[-1]
            observation_text = observation.split("Observation: ")[-1]
            act_obs_traj += f'Step {i+1}: Action="{action}"; Observation="{observation_text}".\n'

        relabel_inp = [
            {"role": "system", "content": self._relabel_prompt},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp, "relabel")
        parse_failure = {
            "has_hs": False,
            "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
        }
        try:
            trajs = parse_vitamin_output(chat_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the relabel response:\nError: {exc}\n{chat_response}")
            return parse_failure
        if not trajs:
            return parse_failure

        # Strip the trailing summary dict so only per-step dicts remain.
        if isinstance(trajs[-1], dict) and "final_goals" in trajs[-1]:
            trajs = trajs[:-1]
        else:
            print("WARNING: No finals goal from relabel model!!!!!")
            print(trajs)
        # generate_trajgoal_pairs raises on an empty list (e.g. the model returned
        # only the summary dict) and calls .get on each element.
        if not trajs or not all(isinstance(t, dict) for t in trajs):
            return {
                "has_hs": False,
                "hs": chat_response
            }

        traj_goal_pairs = generate_trajgoal_pairs(trajs)
        if not traj_goal_pairs:
            return {
                "has_hs": False,
                "hs": chat_response
            }
        if self.first_only:
            traj_goal_pairs = traj_goal_pairs[:1]

        env_desc = extract_env_desc(original_task)
        hs = []
        for pair in traj_goal_pairs:
            traj = pair['trajectory']
            goal = remove_number(pair['goal'])
            goal_relevant_inp = [
                {"role": "system", "content": self._relevant_prompt},
                {
                    "role": "user",
                    "content": f"Environment description:\n{env_desc}\n\nGoal:{goal}\n\nHere is the tracjtory:\n{traj}\n\nNow, please judge the relevance of actions at each step."
                },
            ]
            goal_relevant_response = self._chat_completion(goal_relevant_inp, "mask")
            try:
                steps_with_relevance = extract_json_objects_from_output(goal_relevant_response)
            except Exception as exc:
                # Skip this goal instead of aborting the whole run.
                print(f"ERROR: Failed to parse the RELEVANCE response:\nError: {exc}\n{goal_relevant_response}")
                continue

            task_message = copy.deepcopy(state_history[0])
            task_message['content'] = "\n".join(task_message['content'].split("\n")[:-1]) + f"\nYour task is to: {goal}"
            hs.append([task_message] + steps_to_hf_history(steps_with_relevance))

        if hs:
            return {
                "has_hs": True,
                "hs": hs
            }
        return parse_failure
class OTea(BaseJuicer):
    '''
    Hindsight relabeler for ALFWorld: every goal the relabel model reports as
    achieved becomes its own relabeled trajectory, and a relevance model marks
    which actions of that trajectory served the goal.

    first_only: meant to keep only the first achieved goal. NOTE(review): the
    filter in relabel_experience is disabled, so the flag is stored but
    currently has no effect.
    '''
    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only

    def distill(self, new_trajectory):
        '''
        Placeholder for re-examining a trajectory under its relabeled goal.
        Not implemented: always returns None.
        '''
        return

    def relabel_experience(self, state_history, obs, llm_out, original_task):
        '''
        Build hindsight trajectories for the goals achieved in one episode.

        Args:
            state_history: chat history of the episode; element 0 is the task
                message, whose last line is replaced by the relabeled goal.
            obs: raw observation strings, one per step.
            llm_out: raw model outputs, aligned with `obs`.
            original_task: task text handed to extract_env_desc.

        Returns:
            {"has_hs": True, "hs": [trajectory, ...]} when at least one
            relabeled trajectory was produced, otherwise
            {"has_hs": False, "hs": <error text or raw relabel response>}.
        '''
        act_obs_traj = ''
        for i, (obs_line, out_line) in enumerate(zip(obs, llm_out)):
            # Invalid-action steps are left out of the prompt; kept steps keep
            # their original 1-based step number.
            if obs_line.startswith("Observation: Error Input."):
                continue
            action = out_line.split("Action: ")[-1]
            # Not named `obs`: that would shadow the argument being iterated.
            observation = obs_line.split("Observation: ")[-1]
            act_obs_traj += f'Step {i+1}: Action="{action}"; Observation="{observation}".\n'

        relabel_inp = [
            {"role": "system", "content": ALFWORLD_ATOMICGOAL_NOSUGAR_SYSTEM},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp, "relabel")
        try:
            trajs = extract_json_objects_from_output(chat_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the relabel response:\nError: {exc}\n{chat_response}")
            trajs = []
        if not trajs:
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
            }

        # A trailing {"final_goals": ...} object is a summary, not a step, so
        # it is dropped before steps are paired with goals.
        if isinstance(trajs[-1], dict) and "final_goals" in trajs[-1]:
            trajs = trajs[:-1]
        else:
            print("WARNING: No finals goal from relabel model!!!!!")
            print(trajs)

        traj_goal_pairs = generate_trajgoal_pairs(trajs)
        if len(traj_goal_pairs) == 0:
            return {
                "has_hs": False,
                "hs": chat_response
            }
        # NOTE: `first_only` filtering is disabled; all achieved goals are kept.

        env_desc = extract_env_desc(original_task)
        hs = []
        for step in traj_goal_pairs:
            traj = step['trajectory']
            goal = remove_number(step['goal'])
            goal_relevant_inp = [
                {"role": "system", "content": ALFWORLD_RELEVANT_SYSTEM},
                {
                    "role": "user",
                    "content": f"Environment description:\n{env_desc}\n\nGoal:{goal}\n\nHere is the tracjtory:\n{traj}\n\nNow, please judge the relevance of actions at each step."
                },
            ]
            goal_relevant_response = self._chat_completion(goal_relevant_inp, "mask")
            try:
                traj_goal_pairs_with_relevance = extract_json_objects_from_output(goal_relevant_response)
            except Exception as exc:
                # Log the response that failed to parse, not the prompt.
                print(f"ERROR: Failed to parse the RELEVANCE response:\nError: {exc}\n{goal_relevant_response}")
                continue
            # Only the task message is reused; copy it so state_history is
            # not mutated. Its last line (the original goal) is replaced.
            task_msg = copy.deepcopy(state_history[0])
            task_msg['content'] = "\n".join(task_msg['content'].split("\n")[:-1]) + f"\nYour task is to: {goal}"
            hs.append([task_msg] + steps_to_hf_history(traj_goal_pairs_with_relevance))

        if hs:
            return {
                "has_hs": True,
                "hs": hs
            }
        return {
            "has_hs": False,
            "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
        }
def build_webshop_trajectory(obs, llm_out):
    """
    Render WebShop (observation, model output) pairs as numbered step lines.

    Each kept pair becomes one line:
        Step <k>: Action="<action>"; Observation="<page text>".
    <k> is the 1-based position of the pair in the inputs, so skipped pairs
    still consume a number. A pair is skipped when its observation contains
    "Error" or "error". Two segments are removed from the page text, leaving
    only website content:
      * a leading "Instruction: [SEP] ... [SEP]" segment;
      * a trailing "[SEP] Reward [SEP] ..." segment.

    Returns the concatenated lines, or "" when nothing is kept.
    """
    lines = []
    for step_no, (raw_obs, raw_out) in enumerate(zip(obs, llm_out), start=1):
        if "Error" in raw_obs or "error" in raw_obs:
            continue
        # str.split yields the whole string when the marker is absent, so this
        # also handles outputs/observations without the prefix.
        action = raw_out.split("Action: ")[-1].strip()
        page = raw_obs.split("Observation:")[-1].strip()
        page = re.sub(r'^Instruction:\s*\[SEP\].*?\[SEP\]\s*', '', page)
        page = re.sub(r'\[SEP\]\s*Reward\s*\[SEP\].*$', '', page)
        lines.append(f'Step {step_no}: Action="{action}"; Observation="{page}".\n')
    return "".join(lines)
def build_webshop_trajectory_and_intention(obs, llm_out, intentions_traj):
    """
    Render WebShop steps together with the relabel model's per-step intention.

    Cleaning matches the plain WebShop trajectory rendering:
      * observations containing "Error"/"error" are dropped;
      * a leading "Instruction: [SEP] ... [SEP]" segment is stripped;
      * a trailing "[SEP] Reward [SEP] ..." segment is stripped.

    Each line also carries the aligned entry of `intentions_traj`:
        Step <k>: Action="<action>"; Observation="<page>"; current_intention=<intention>.

    The three sequences are consumed in lockstep, so output stops at the
    shortest one. Returns "" when no step is kept.
    """
    rendered = []
    triples = zip(obs, llm_out, intentions_traj)
    for step_no, (raw_obs, raw_out, intention) in enumerate(triples, start=1):
        if "Error" in raw_obs or "error" in raw_obs:
            continue
        # str.split yields the whole string when the marker is absent.
        action = raw_out.split("Action: ")[-1].strip()
        page = raw_obs.split("Observation:")[-1].strip()
        page = re.sub(r'^Instruction:\s*\[SEP\].*?\[SEP\]\s*', '', page)
        page = re.sub(r'\[SEP\]\s*Reward\s*\[SEP\].*$', '', page)
        rendered.append(
            f'Step {step_no}: Action="{action}"; Observation="{page}"; current_intention={intention}.\n'
        )
    return "".join(rendered)
def webshop_steps_to_hf_history(steps: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Convert relabeled WebShop step dictionaries into a chat history for use
    with history_to_sft_sample.

    Every step contributes two messages, in order:
      1. {"role": "assistant", "content": "Action: <action>", "useful": bool}
      2. {"role": "user", "content": "Observation: <observation>"}
    The "Action: " / "Observation: " prefixes are added only when missing.
    List-valued fields are joined with single spaces.

    "useful" comes from the step's "is_relevant_to_goal" field:
      * a boolean is used as-is;
      * a string counts as useful when it contains "yes" (case-insensitive);
      * anything else, including a missing field, is treated as not useful.

    Returns [] (after printing a warning) when `steps` is not a non-empty list.
    """
    if not isinstance(steps, list) or not steps:
        print("WARNING: the steps are empty!!!!!!!!!!!!!!!!!!")
        print("Steps")
        print(steps)
        return []

    def _as_text(value):
        # Lists (e.g. multi-part actions) are space-joined; JSON is a last
        # resort if an element cannot be stringified.
        if isinstance(value, list):
            try:
                return " ".join(str(x) for x in value)
            except Exception:
                return json.dumps(value)
        return str(value)

    def _is_useful(flag):
        # JSON `true`/`false` arrive as bools and must not be discarded.
        if isinstance(flag, bool):
            return flag
        if isinstance(flag, str):
            return "yes" in flag.lower()
        return False

    history: List[Dict[str, Any]] = []
    for step in steps:
        action_text = _as_text(step.get("action", ""))
        if not action_text.startswith("Action:"):
            action_text = "Action: " + action_text
        history.append({
            "role": "assistant",
            "content": action_text,
            "useful": _is_useful(step.get("is_relevant_to_goal", "no")),
        })
        obs_text = _as_text(step.get("observation", ""))
        if not obs_text.startswith("Observation:"):
            obs_text = "Observation: " + obs_text
        history.append({
            "role": "user",
            "content": obs_text,
        })
    return history
class ShopLemonade(BaseJuicer):
    """
    WebShop-specific implementation of Lemonade relabeling.

    A relabel model summarises what the trajectory actually bought. A query
    model turns that summary into candidate shopping instructions, and one
    candidate replaces the instruction in every user turn. Trajectories that
    did not end in a purchase are rejected.
    """
    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only
        # Use WebShop-specific prompts
        self._relabel_prompt = WEBSHOP_GOAL_SYSTEM
        self._query_prompt = WEBSHOP_QUERY_SYSTEM

    def relabel_experience(self, state_history, obs, llm_out, original_task):
        """
        Relabel a WebShop trajectory with a hindsight goal.

        Args:
            state_history: chat history whose user turns may contain an
                "Instruction: [SEP] <goal> [SEP] ..." segment.
            obs: raw per-step observations.
            llm_out: raw per-step model outputs, aligned with `obs`.
            original_task: not used by this relabeler; kept for interface
                compatibility.

        Returns:
            {"has_hs": True, "hs": [new_traj]} on success, otherwise
            {"has_hs": False, "hs": <reason string>}.
        """
        # Build trajectory with clean website observations
        act_obs_traj = build_webshop_trajectory(obs, llm_out)
        if not act_obs_traj:
            return {
                "has_hs": False,
                "hs": "No valid action-observation pairs found"
            }

        relabel_inp = [
            {"role": "system", "content": self._relabel_prompt},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp, "relabel")
        try:
            trajs = extract_json_objects_from_output(chat_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the relabel response:\nError: {exc}\n{chat_response}")
            trajs = []
        if not trajs:
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
            }

        # Expected layout: [..., <final summary>, {"purchase_success": ...}].
        # Without both trailing objects there is nothing to relabel.
        last = trajs[-1]
        if len(trajs) < 2 or not isinstance(last, dict) or "purchase_success" not in last:
            print("WARNING: No final goals from relabel model!")
            return {
                "has_hs": False,
                "hs": f"ERROR: No final goals in the relabel response:\n{chat_response}"
            }
        if not last['purchase_success']:
            return {
                "has_hs": False,
                "hs": "Failed to buy anything!!!"
            }
        trajs_final = trajs[-2]
        if not isinstance(trajs_final, dict):
            return {
                "has_hs": False,
                "hs": f"ERROR: Malformed final summary in the relabel response:\n{chat_response}"
            }
        # Step-related keys are dropped before the summary is sent to the
        # query model.
        trajs_final = {k: v for k, v in trajs_final.items() if "step" not in k.lower()}

        query_inp = [
            {"role": "system", "content": self._query_prompt},
            {
                "role": "user",
                "content": str(trajs_final)
            },
        ]
        query_raw = self._chat_completion(query_inp, "mask")
        queries = extract_json_list(query_raw)
        # An empty list is as unusable as a parse failure (random.choice
        # would raise IndexError).
        if not queries:
            print(f"ERROR: Failed to parse JSON list from query response: {query_raw}")
            return {
                "has_hs": False,
                "hs": "Failed to parse query!!!"
            }
        query = random.choice(queries)

        # _update_webshop_goal deep-copies every turn, so state_history is
        # left untouched.
        new_traj = self._update_webshop_goal(state_history, query)
        return {
            "has_hs": True,
            "hs": [new_traj]
        }

    def _generate_webshop_trajgoal_pairs(self, trajs):
        """
        Generate trajectory-goal pairs for WebShop.
        Similar to generate_trajgoal_pairs but adapted for WebShop format.

        For each goal newly listed in a step's 'reached_goals', emit
        {'trajectory': text, 'goal': goal}. The text renders the steps
        trajs[goal_start_idx:step['step']], where goal_start_idx is the first
        index before step['step'] whose 'reached_goals' contains the goal.
        NOTE(review): assumes step dicts carry 'step', 'action' and
        'observation' keys and that 'step' is 1-based. Not called from
        relabel_experience.
        """
        traj_goal_pairs = []
        current_goals = []
        for step in trajs:
            if 'reached_goals' in step and step['reached_goals']:
                # Check for new goals reached at this step
                new_goals = [g for g in step['reached_goals'] if g not in current_goals]
                for goal in new_goals:
                    # Find where this goal was first reached
                    goal_start_idx = 0
                    for i, s in enumerate(trajs[:step['step']]):
                        if 'reached_goals' in s and goal in s['reached_goals']:
                            goal_start_idx = i
                            break
                    # Build trajectory up to current step
                    traj_text = ""
                    for s in trajs[goal_start_idx:step['step']]:
                        traj_text += f"Step {s['step']}: Action=\"{s['action']}\"; Observation=\"{s['observation']}\"\n"
                        if 'page_type' in s:
                            traj_text += f"  Page: {s['page_type']}"
                        if 'products_viewed' in s and s['products_viewed']:
                            traj_text += f", Products viewed: {len(s['products_viewed'])}"
                        traj_text += "\n"
                    traj_goal_pairs.append({
                        'trajectory': traj_text,
                        'goal': goal
                    })
                current_goals.extend(new_goals)
        return traj_goal_pairs

    def _extract_webshop_env_desc(self, original_task):
        """
        Extract the WebShop task description (product search query).

        For a string, returns the first line after the last "Instruction:"
        marker, or the whole stripped string if there is no marker. Any other
        input gets a generic description.
        """
        if isinstance(original_task, str):
            if "Instruction:" in original_task:
                return original_task.split("Instruction:")[-1].split("\n")[0].strip()
            return original_task.strip()
        return "Shop for products online"

    def _update_webshop_goal(self, trajectory, new_goal):
        """
        Return a deep copy of `trajectory` with the instruction replaced by
        `new_goal`.

        Applies to every user turn containing "Instruction: [SEP]". Only the
        text between that marker and the next "[SEP]" is replaced; if there
        is no later "[SEP]", everything after the marker is replaced. The
        input trajectory is not modified.
        """
        marker = "Instruction: [SEP]"
        updated_trajectory = []
        for turn in trajectory:
            updated_turn = copy.deepcopy(turn)
            if turn.get("role") == "user" and marker in turn.get("content", ""):
                content = turn["content"]
                prefix = content[:content.find(marker) + len(marker)]
                remaining = content[len(prefix):]
                next_sep = remaining.find("[SEP]")
                if next_sep != -1:
                    # Keep everything from the goal's closing [SEP] onward.
                    updated_turn["content"] = f"{prefix} {new_goal} {remaining[next_sep:]}"
                else:
                    updated_turn["content"] = f"{prefix} {new_goal}"
            updated_trajectory.append(updated_turn)
        return updated_trajectory
class ShopOTea(BaseJuicer):
    """
    WebShop relabeler that rewrites the goal and masks irrelevant actions.

    The relabel model produces per-step shopping intentions and a final one.
    A relevance model judges each action against the final intention, and a
    query model turns the final intention into candidate instructions. The
    shortest candidate replaces the instruction in every user turn. Each
    assistant turn gets a boolean 'useful' flag. Trajectories without a
    successful purchase are rejected.
    """
    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only
        # Use WebShop-specific prompts
        self._relabel_prompt = WEBSHOP_GOAL_SYSTEM
        self._query_prompt = WEBSHOP_QUERY_FREE_SYSTEM
        self._relevant_prompt = WEBSHOP_RELEVANT_SYSTEM

    def relabel_experience(self, state_history, obs, llm_out, original_task):
        """
        Relabel a WebShop trajectory with a hindsight goal and relevance mask.

        Args:
            state_history: chat history whose user turns may contain an
                "Instruction: [SEP] <goal> [SEP] ..." segment.
            obs: raw per-step observations.
            llm_out: raw per-step model outputs, aligned with `obs`.
            original_task: not used by this relabeler; kept for interface
                compatibility.

        Returns:
            {"has_hs": True, "hs": [new_traj]} on success, otherwise
            {"has_hs": False, "hs": <reason string>}.
        """
        # Build trajectory with clean website observations
        act_obs_traj = build_webshop_trajectory(obs, llm_out)
        if not act_obs_traj:
            return {
                "has_hs": False,
                "hs": "No valid action-observation pairs found"
            }

        relabel_inp = [
            {"role": "system", "content": self._relabel_prompt},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp, "relabel")
        try:
            intentions = extract_json_objects_from_output(chat_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the relabel response:\nError: {exc}\n{chat_response}")
            intentions = []
        if not intentions:
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
            }

        # Expected layout: [<step intentions>..., <final intention>,
        #                   {"purchase_success": ...}].
        last = intentions[-1]
        if len(intentions) < 2 or not isinstance(last, dict) or "purchase_success" not in last:
            print("WARNING: No final goals from relabel model!")
            return {
                "has_hs": False,
                "hs": f"ERROR: No final goals in the relabel response:\n{chat_response}"
            }
        if not last['purchase_success']:
            print("WARNING: Failed to buy anything!!!")
            return {
                "has_hs": False,
                "hs": "Failed to buy anything!!!"
            }
        intentions_final = intentions[-2]
        if not isinstance(intentions_final, dict):
            return {
                "has_hs": False,
                "hs": f"ERROR: Malformed final intention in the relabel response:\n{chat_response}"
            }
        intentions = intentions[:-1]
        intentions_final = {k: v for k, v in intentions_final.items() if "step" not in k.lower()}

        # NOTE(review): intentions are zipped with obs/llm_out by position,
        # while build_webshop_trajectory skips error steps; confirm the two
        # stay aligned when a step was skipped.
        act_obs_intent_traj = build_webshop_trajectory_and_intention(obs, llm_out, intentions)
        goal_relevant_inp = [
            {"role": "system", "content": self._relevant_prompt},
            {
                "role": "user",
                "content": f"Shopping intention:{intentions_final}\n\nHere is the tracjtory:\n{act_obs_intent_traj}\n\nNow, please judge the relevance of actions at each step."
            },
        ]
        goal_relevant_response = self._chat_completion(goal_relevant_inp, "mask")
        try:
            # Use WebShop-specific JSON extraction to handle apostrophes and other issues
            traj_goal_pairs_with_relevance = extract_webshop_json_objects(goal_relevant_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the RELEVANCE response:\nError: {exc}\n{goal_relevant_response}")
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the RELEVANCE response:\n{goal_relevant_response}"
            }

        query_inp = [
            {"role": "system", "content": self._query_prompt},
            {
                "role": "user",
                "content": str(intentions_final)
            },
        ]
        query_raw = self._chat_completion(query_inp, "mask")
        # Use WebShop-specific version for measurement handling.
        queries = extract_webshop_query_list(query_raw)
        candidates = [q for q in (queries or []) if isinstance(q, str) and q.strip()]
        if not candidates:
            print(f"ERROR: Failed to parse JSON list from query response: {query_raw}")
            return {
                "has_hs": False,
                "hs": "Failed to parse query!!!"
            }
        # The shortest candidate, lower-cased, becomes the new goal.
        query = min(candidates, key=len).lower()

        # _update_webshop_goal deep-copies every turn, so state_history is
        # left untouched.
        new_traj = self._update_webshop_goal(state_history, query)
        error = self._apply_relevance(new_traj, traj_goal_pairs_with_relevance)
        if error is not None:
            print(error)
            return {
                "has_hs": False,
                "hs": error
            }
        return {
            "has_hs": True,
            "hs": [new_traj]
        }

    @staticmethod
    def _apply_relevance(new_traj, relevance):
        """
        Set a boolean 'useful' flag on every assistant turn of `new_traj`, in
        place.

        The assistant turn at index i is judged by relevance[i // 2]. An entry
        counts as useful when its 'relevance' value is True, or a string
        containing "yes" (case-insensitive). A non-list `relevance` is treated
        as empty.

        Returns None on success, or an error message when `relevance` is too
        short for the trajectory.
        """
        entries = relevance if isinstance(relevance, list) else []
        for i, turn in enumerate(new_traj):
            if turn.get('role') != "assistant":
                continue
            relevance_idx = i // 2
            if relevance_idx >= len(entries):
                return f"ERROR: relevance_idx {relevance_idx} >= len(traj_goal_pairs_with_relevance) {len(entries)}"
            entry = entries[relevance_idx]
            value = entry.get('relevance', '') if isinstance(entry, dict) else ''
            turn['useful'] = value if isinstance(value, bool) else 'yes' in str(value).lower()
        return None

    def _generate_webshop_trajgoal_pairs(self, trajs):
        """
        Generate trajectory-goal pairs for WebShop.
        Similar to generate_trajgoal_pairs but adapted for WebShop format.

        For each goal newly listed in a step's 'reached_goals', emit
        {'trajectory': text, 'goal': goal}. The text renders the steps
        trajs[goal_start_idx:step['step']], where goal_start_idx is the first
        index before step['step'] whose 'reached_goals' contains the goal.
        NOTE(review): assumes step dicts carry 'step', 'action' and
        'observation' keys and that 'step' is 1-based. Not called from
        relabel_experience.
        """
        traj_goal_pairs = []
        current_goals = []
        for step in trajs:
            if 'reached_goals' in step and step['reached_goals']:
                # Check for new goals reached at this step
                new_goals = [g for g in step['reached_goals'] if g not in current_goals]
                for goal in new_goals:
                    # Find where this goal was first reached
                    goal_start_idx = 0
                    for i, s in enumerate(trajs[:step['step']]):
                        if 'reached_goals' in s and goal in s['reached_goals']:
                            goal_start_idx = i
                            break
                    # Build trajectory up to current step
                    traj_text = ""
                    for s in trajs[goal_start_idx:step['step']]:
                        traj_text += f"Step {s['step']}: Action=\"{s['action']}\"; Observation=\"{s['observation']}\"\n"
                        if 'page_type' in s:
                            traj_text += f"  Page: {s['page_type']}"
                        if 'products_viewed' in s and s['products_viewed']:
                            traj_text += f", Products viewed: {len(s['products_viewed'])}"
                        traj_text += "\n"
                    traj_goal_pairs.append({
                        'trajectory': traj_text,
                        'goal': goal
                    })
                current_goals.extend(new_goals)
        return traj_goal_pairs

    def _extract_webshop_env_desc(self, original_task):
        """
        Extract the WebShop task description (product search query).

        For a string, returns the first line after the last "Instruction:"
        marker, or the whole stripped string if there is no marker. Any other
        input gets a generic description.
        """
        if isinstance(original_task, str):
            if "Instruction:" in original_task:
                return original_task.split("Instruction:")[-1].split("\n")[0].strip()
            return original_task.strip()
        return "Shop for products online"

    def _update_webshop_goal(self, trajectory, new_goal):
        """
        Return a deep copy of `trajectory` with the instruction replaced by
        `new_goal`.

        Applies to every user turn containing "Instruction: [SEP]". Only the
        text between that marker and the next "[SEP]" is replaced; if there
        is no later "[SEP]", everything after the marker is replaced. The
        input trajectory is not modified.
        """
        marker = "Instruction: [SEP]"
        updated_trajectory = []
        for turn in trajectory:
            updated_turn = copy.deepcopy(turn)
            if turn.get("role") == "user" and marker in turn.get("content", ""):
                content = turn["content"]
                prefix = content[:content.find(marker) + len(marker)]
                remaining = content[len(prefix):]
                next_sep = remaining.find("[SEP]")
                if next_sep != -1:
                    # Keep everything from the goal's closing [SEP] onward.
                    updated_turn["content"] = f"{prefix} {new_goal} {remaining[next_sep:]}"
                else:
                    updated_turn["content"] = f"{prefix} {new_goal}"
            updated_trajectory.append(updated_turn)
        return updated_trajectory
class ShopBrew(BaseJuicer):
    """
    WebShop relabeler built on a single structured relabel summary.

    The relabel response is parsed into a dict with 'purchase_success',
    'query_satisfaction', 'selected' (the chosen product, including 'price')
    and 'query'. The selected price is mapped to a price limit. A relevance
    model then marks each action useful or not. A query model rewrites
    "<query>, price lower than <limit> dollars" into candidate instructions,
    and a random candidate replaces the instruction in every user turn.
    """
    def __init__(self, task_instruction, llama="llama-3-1-70b", first_only=False, api="internal"):
        super().__init__(llama=llama, api=api)
        self.task_instruction = task_instruction
        self.first_only = first_only
        # Use WebShop-specific prompts
        self._relabel_prompt = WEBSHOP_GOAL_BREW_SYSTEM
        self._query_prompt = WEBSHOP_QUERY_FREE_SYSTEM
        self._relevant_prompt = WEBSHOP_RELEVANT_SYSTEM

    def relabel_experience(self, state_history, obs, llm_out, original_task):
        """
        Relabel a WebShop trajectory with a hindsight goal and relevance mask.

        Args:
            state_history: chat history whose user turns may contain an
                "Instruction: [SEP] <goal> [SEP] ..." segment.
            obs: raw per-step observations.
            llm_out: raw per-step model outputs, aligned with `obs`.
            original_task: not used by this relabeler; kept for interface
                compatibility.

        Returns:
            {"has_hs": True, "hs": [new_traj]} on success, otherwise
            {"has_hs": False, "hs": <reason string>}. Missing keys in the
            relabel summary are reported this way instead of raising.
        """
        # Build trajectory with clean website observations
        act_obs_traj = build_webshop_trajectory(obs, llm_out)
        if not act_obs_traj:
            return {
                "has_hs": False,
                "hs": "No valid action-observation pairs found"
            }

        relabel_inp = [
            {"role": "system", "content": self._relabel_prompt},
            {"role": "user", "content": f"Here is the trajectory:\n{act_obs_traj}"},
        ]
        chat_response = self._chat_completion(relabel_inp, "relabel")
        try:
            intentions = extract_shopbrew_dict_from_output(chat_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the relabel response:\nError: {exc}\n{chat_response}")
            intentions = {}
        if not intentions or not isinstance(intentions, dict):
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the relabel response:\n{chat_response}"
            }

        # A trajectory qualifies only if it bought something that satisfied
        # the query.
        purchase_success = intentions.get('purchase_success') and intentions.get('query_satisfaction')
        intentions_final = intentions.get('selected')
        if not isinstance(intentions_final, dict):
            print(f"Warning: failed to extract price {intentions}")
            return {
                "has_hs": False,
                "hs": "Failed to extract price!!!"
            }
        # 'price' is renamed to 'price_limit' and coarsened to a milestone.
        # bool is excluded explicitly because it subclasses int.
        price = intentions_final.pop('price', None)
        if isinstance(price, bool) or not isinstance(price, (int, float)):
            print("WARNING: Failed extract price!!!")
            return {
                "has_hs": False,
                "hs": "Failed to extract price!!!"
            }
        intentions_final['price_limit'] = map_to_random_milestone(price)
        if not purchase_success:
            print("WARNING: Failed to buy anything!!!")
            return {
                "has_hs": False,
                "hs": "Failed to buy anything!!!"
            }
        search_query = intentions.get('query')
        if not search_query:
            print(f"ERROR: No search query in the relabel response: {intentions}")
            return {
                "has_hs": False,
                "hs": "Failed to parse query!!!"
            }

        # The relevance prompt uses the same cleaned trajectory as the
        # relabel prompt.
        goal_relevant_inp = [
            {"role": "system", "content": self._relevant_prompt},
            {
                "role": "user",
                "content": f"Shopping intention:{intentions_final}\n\nHere is the tracjtory:\n{act_obs_traj}\n\nNow, please judge the relevance of actions at each step."
            },
        ]
        goal_relevant_response = self._chat_completion(goal_relevant_inp, "mask")
        try:
            # Use WebShop-specific JSON extraction to handle apostrophes and other issues
            traj_goal_pairs_with_relevance = extract_webshop_json_objects(goal_relevant_response)
        except Exception as exc:
            print(f"ERROR: Failed to parse the RELEVANCE response:\nError: {exc}\n{goal_relevant_response}")
            return {
                "has_hs": False,
                "hs": f"ERROR: Failed to parse the RELEVANCE response:\n{goal_relevant_response}"
            }

        query = f"{search_query}, price lower than {intentions_final['price_limit']} dollars"
        query_inp = [
            {"role": "system", "content": self._query_prompt},
            {
                "role": "user",
                "content": f"Here is the search query: {query}. Now transfrom it to three diverse and complete sentences."
            },
        ]
        query_raw = self._chat_completion(query_inp, "mask")
        # Use WebShop-specific version for measurement handling.
        queries = extract_webshop_query_list(query_raw)
        candidates = [q for q in (queries or []) if isinstance(q, str) and q.strip()]
        if not candidates:
            print(f"ERROR: Failed to parse JSON list from query response: {query_raw}")
            return {
                "has_hs": False,
                "hs": "Failed to parse query!!!"
            }
        query = random.choice(candidates)

        # _update_webshop_goal deep-copies every turn, so state_history is
        # left untouched.
        new_traj = self._update_webshop_goal(state_history, query.lower())
        error = self._apply_relevance(new_traj, traj_goal_pairs_with_relevance)
        if error is not None:
            return {
                "has_hs": False,
                "hs": error
            }
        return {
            "has_hs": True,
            "hs": [new_traj]
        }

    @staticmethod
    def _apply_relevance(new_traj, relevance):
        """
        Set a boolean 'useful' flag on every assistant turn of `new_traj`, in
        place.

        The assistant turn at index i is judged by relevance[i // 2]. An entry
        counts as useful when its 'relevance' value is True, or a string
        containing "yes" (case-insensitive). A non-list `relevance` is treated
        as empty.

        Returns None on success, or an error message when `relevance` is too
        short for the trajectory.
        """
        entries = relevance if isinstance(relevance, list) else []
        for i, turn in enumerate(new_traj):
            if turn.get('role') != "assistant":
                continue
            relevance_idx = i // 2
            if relevance_idx >= len(entries):
                return f"ERROR: relevance_idx {relevance_idx} >= len(traj_goal_pairs_with_relevance) {len(entries)}"
            entry = entries[relevance_idx]
            value = entry.get('relevance', '') if isinstance(entry, dict) else ''
            turn['useful'] = value if isinstance(value, bool) else 'yes' in str(value).lower()
        return None

    def _generate_webshop_trajgoal_pairs(self, trajs):
        """
        Generate trajectory-goal pairs for WebShop.
        Similar to generate_trajgoal_pairs but adapted for WebShop format.

        For each goal newly listed in a step's 'reached_goals', emit
        {'trajectory': text, 'goal': goal}. The text renders the steps
        trajs[goal_start_idx:step['step']], where goal_start_idx is the first
        index before step['step'] whose 'reached_goals' contains the goal.
        NOTE(review): assumes step dicts carry 'step', 'action' and
        'observation' keys and that 'step' is 1-based. Not called from
        relabel_experience.
        """
        traj_goal_pairs = []
        current_goals = []
        for step in trajs:
            if 'reached_goals' in step and step['reached_goals']:
                # Check for new goals reached at this step
                new_goals = [g for g in step['reached_goals'] if g not in current_goals]
                for goal in new_goals:
                    # Find where this goal was first reached
                    goal_start_idx = 0
                    for i, s in enumerate(trajs[:step['step']]):
                        if 'reached_goals' in s and goal in s['reached_goals']:
                            goal_start_idx = i
                            break
                    # Build trajectory up to current step
                    traj_text = ""
                    for s in trajs[goal_start_idx:step['step']]:
                        traj_text += f"Step {s['step']}: Action=\"{s['action']}\"; Observation=\"{s['observation']}\"\n"
                        if 'page_type' in s:
                            traj_text += f"  Page: {s['page_type']}"
                        if 'products_viewed' in s and s['products_viewed']:
                            traj_text += f", Products viewed: {len(s['products_viewed'])}"
                        traj_text += "\n"
                    traj_goal_pairs.append({
                        'trajectory': traj_text,
                        'goal': goal
                    })
                current_goals.extend(new_goals)
        return traj_goal_pairs

    def _extract_webshop_env_desc(self, original_task):
        """
        Extract the WebShop task description (product search query).

        For a string, returns the first line after the last "Instruction:"
        marker, or the whole stripped string if there is no marker. Any other
        input gets a generic description.
        """
        if isinstance(original_task, str):
            if "Instruction:" in original_task:
                return original_task.split("Instruction:")[-1].split("\n")[0].strip()
            return original_task.strip()
        return "Shop for products online"

    def _update_webshop_goal(self, trajectory, new_goal):
        """
        Return a deep copy of `trajectory` with the instruction replaced by
        `new_goal`.

        Applies to every user turn containing "Instruction: [SEP]". Only the
        text between that marker and the next "[SEP]" is replaced; if there
        is no later "[SEP]", everything after the marker is replaced. The
        input trajectory is not modified.
        """
        marker = "Instruction: [SEP]"
        updated_trajectory = []
        for turn in trajectory:
            updated_turn = copy.deepcopy(turn)
            if turn.get("role") == "user" and marker in turn.get("content", ""):
                content = turn["content"]
                prefix = content[:content.find(marker) + len(marker)]
                remaining = content[len(prefix):]
                next_sep = remaining.find("[SEP]")
                if next_sep != -1:
                    # Keep everything from the goal's closing [SEP] onward.
                    updated_turn["content"] = f"{prefix} {new_goal} {remaining[next_sep:]}"
                else:
                    updated_turn["content"] = f"{prefix} {new_goal}"
            updated_trajectory.append(updated_turn)
        return updated_trajectory