|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| Simple script to run GRPO training with OpenEnv's BrowserGym environment and vLLM for LLMs.
|
|
|
| This script is optimized for text-only Language Models (LLMs). It uses the accessibility
|
| tree text from BrowserGym, making it memory-efficient.
|
|
|
| The environment runs on a Hugging Face Space by default.
|
|
|
| Setup (Option A - Install from HF Space, recommended):
|
|
|
| ```sh
|
| uv pip install git+https://huggingface.co/spaces/openenv/browsergym_env
|
| ```
|
|
|
| Setup (Option B - Clone OpenEnv repo, for development):
|
|
|
| ```sh
|
| git clone https://github.com/meta-pytorch/OpenEnv.git
|
| cd OpenEnv/envs/browsergym_env
|
| uv pip install -e .
|
| ```
|
|
|
| # Option 1: HF Spaces + Colocated vLLM (1 GPU required)
|
| ```sh
|
| python examples/scripts/openenv/browsergym_llm.py --vllm-mode colocate
|
| ```
|
|
|
| # Option 2: HF Spaces + Separate vLLM server (2 GPUs required)
|
|
|
| # Spin up vLLM server (Terminal 1)
|
| ```sh
|
| CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-0.6B --host 0.0.0.0 --port 8001
|
| ```
|
|
|
| # Run training (Terminal 2)
|
| ```sh
|
| CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/browsergym_llm.py --vllm-mode server --vllm-server-url http://localhost:8001
|
| ```
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| from datetime import datetime
|
| from pathlib import Path
|
|
|
| from browsergym_env import BrowserGymAction, BrowserGymEnv
|
| from datasets import Dataset
|
|
|
| from trl import GRPOConfig, GRPOTrainer
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the BrowserGym GRPO training script.

    Args:
        argv: Argument strings to parse instead of the process command line. When `None`
            (the default) argparse reads `sys.argv[1:]`, so a bare `parse_args()` call
            behaves exactly as before.

    Returns:
        Namespace holding one attribute per option; dashes in the flag become underscores
        (e.g. `--model-id` is available as `model_id`).
    """
    parser = argparse.ArgumentParser(description="Run GRPO training for BrowserGym MiniWoB using OpenEnv environment.")

    # --- Model and environment ---
    parser.add_argument(
        "--model-id",
        default="Qwen/Qwen3-0.6B",
        help="Model identifier passed to GRPOTrainer for fine-tuning.",
    )
    parser.add_argument(
        "--space-url",
        type=str,
        default="https://openenv-browsergym-env.hf.space",
        help="URL for the Hugging Face Space running the BrowserGym environment.",
    )
    parser.add_argument(
        "--benchmark",
        default="miniwob",
        help="BrowserGym benchmark to use (miniwob, webarena, etc.).",
    )
    parser.add_argument(
        "--task-name",
        default="click-test",
        help="Specific task within the benchmark (e.g., click-test, click-button).",
    )

    # --- Dataset and rollout shape ---
    parser.add_argument(
        "--dataset-prompt",
        default="Complete the web task successfully.",
        help="Prompt text used to seed the training dataset.",
    )
    parser.add_argument(
        "--dataset-size",
        type=int,
        default=1000,
        help="Number of entries to include in the synthetic training dataset.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=10,
        help="Maximum number of steps per episode.",
    )
    parser.add_argument(
        "--max-completion-length",
        type=int,
        default=1024,
        help="Maximum completion length in tokens for tool-calling generation.",
    )

    # --- Sampling ---
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature used during rollout generation.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=50,
        help="Top-k sampling parameter forwarded to vLLM.",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Optional top-p sampling parameter forwarded to vLLM.",
    )

    # --- Optimization ---
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=5e-6,
        help="Learning rate for GRPO training.",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        default=0.0,
        help="Weight decay applied during optimization.",
    )
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=32,
        help="Gradient accumulation steps for GRPO training.",
    )
    parser.add_argument(
        "--warmup-steps",
        type=int,
        default=10,
        help="Warmup steps for the scheduler.",
    )
    parser.add_argument(
        "--per-device-batch-size",
        type=int,
        default=1,
        help="Per-device train batch size.",
    )
    parser.add_argument(
        "--num-generations",
        type=int,
        default=4,
        help="Number of rollout generations per dataset prompt.",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=1,
        help="Number of training epochs.",
    )

    # --- Checkpointing and logging ---
    parser.add_argument(
        "--save-interval",
        type=int,
        default=50,
        help="Interval (in steps) between checkpoint saves.",
    )
    parser.add_argument(
        "--save-total-limit",
        type=int,
        default=None,
        help="Maximum number of checkpoints to keep.",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Directory where training outputs and checkpoints are stored.",
    )
    parser.add_argument(
        "--run-name",
        default=None,
        help="Optional run name for logging systems.",
    )
    parser.add_argument(
        "--project",
        default=None,
        help="Optional project identifier for logging systems.",
    )

    # --- vLLM ---
    parser.add_argument(
        "--vllm-mode",
        choices=("colocate", "server"),
        default="colocate",
        help="vLLM execution mode: 'colocate' or 'server'.",
    )
    parser.add_argument(
        "--vllm-server-url",
        type=str,
        default="http://localhost:8001",
        help="URL for the vLLM server (only used when --vllm-mode=server).",
    )
    parser.add_argument(
        "--logging-steps",
        type=int,
        default=1,
        help="Frequency of logging steps for GRPO training.",
    )

    # argv=None makes argparse fall back to sys.argv[1:], preserving the old behavior.
    return parser.parse_args(argv)
|
|
|
|
|
def sanitize_name(name: str) -> str:
    """Return `name` with every `/` replaced by `-`.

    Args:
        name: Arbitrary string, e.g. a model identifier such as `org/model`.

    Returns:
        The same string with each slash swapped for a dash.
    """
    # Splitting on "/" and re-joining with "-" substitutes each slash one-for-one.
    pieces = name.split("/")
    return "-".join(pieces)
|
|
|
|
|
|
|
|
|
|
|
|
|
# System message used as the first chat turn of every training prompt (see the dataset
# built in main()). The tool names listed here mirror the public methods of the
# BrowserGymLLMEnv wrapper defined in main(): click, fill, send_keys, scroll, noop.
SYSTEM_PROMPT = """You control a web browser to complete tasks.

The page structure shows elements as: [bid] element_type 'element_text'
For example: [13] button 'Click Me!' means the element has bid='13'.

Use the available tools to interact with the page:
- click: Click an element by its bid
- fill: Fill an input field with text
- send_keys: Send keyboard input
- scroll: Scroll the page
- noop: Do nothing

Complete the given task as efficiently as possible."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reward_completion(environments, **kwargs) -> list[float]:
    """Reward for task completion.

    Args:
        environments: Iterable of environment objects, each exposing a `reward` attribute.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        The `reward` attribute of every environment, in input order.
    """
    rewards = []
    for environment in environments:
        rewards.append(environment.reward)
    return rewards
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main() -> None:
    """Run GRPO training against a BrowserGym environment served from a Hugging Face Space.

    Steps, in order: parse CLI options, build a synthetic dataset that repeats a single
    system+user prompt, define the `BrowserGymLLMEnv` wrapper (it closes over the Space URL
    and the per-episode step budget), assemble the `GRPOConfig`, construct the `GRPOTrainer`
    with the wrapper as `environment_factory`, print a summary banner and start training.
    """
    args = parse_args()

    # Captured by the BrowserGymLLMEnv closure below.
    space_url = args.space_url
    max_steps = args.max_steps

    # Every row is the same two-message conversation, repeated `dataset_size` times.
    dataset = Dataset.from_dict(
        {
            "prompt": [
                [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": args.dataset_prompt},
                ]
            ]
            * args.dataset_size
        }
    )

    # NOTE(review): the docstrings of the public action methods (click, fill, send_keys,
    # scroll, noop) are presumably consumed by GRPOTrainer to describe the tools to the
    # model - keep their wording and the method/parameter names stable; confirm against
    # the TRL `environment_factory` documentation.
    class BrowserGymLLMEnv:
        def __init__(self):
            self.client = BrowserGymEnv(base_url=space_url)
            self.reward = 0.0  # read by reward_completion after the rollout
            self._done = False
            self._step_count = 0

        def _ensure_large_max_size(self):
            """Raise WebSocket max message size for large observations (e.g. accessibility trees).

            openenv-core<=0.2.1 does not pass max_size to ws_connect, so the websockets library
            defaults to 1MB. We force a connection and patch it to 100MB before any messages are sent.

            This is a best-effort workaround: when the client exposes no `_ws`, the socket has
            no `protocol`, or the protocol has neither size attribute, the patch is skipped
            instead of raising `AttributeError`.
            """
            self.client.connect()
            ws = getattr(self.client, "_ws", None)
            proto = getattr(ws, "protocol", None)
            if proto is None:
                return

            attr = "max_size" if hasattr(proto, "max_size") else "max_message_size"
            # Only override the 1MB default; leave any other configured limit untouched.
            if getattr(proto, attr, None) == 2**20:
                setattr(proto, attr, 100 * 1024 * 1024)

        def reset(self, **kwargs) -> str:
            """Start a new episode: clear the bookkeeping, reset the remote environment
            and return the formatted first observation. `kwargs` are accepted and ignored."""
            self.reward = 0.0
            self._done = False
            self._step_count = 0
            self._ensure_large_max_size()
            result = self.client.reset()
            self._done = result.done
            return self._format_observation(result.observation)

        def click(self, bid: str) -> str:
            """Click an element on the page.

            Args:
                bid: The BrowserGym ID of the element to click.

            Returns:
                The updated page observation.
            """
            return self._do_action(f"click({bid!r})")

        def fill(self, bid: str, text: str) -> str:
            """Fill an input field with text.

            Args:
                bid: The BrowserGym ID of the input field.
                text: The text to type into the field.

            Returns:
                The updated page observation.
            """
            return self._do_action(f"fill({bid!r}, {text!r})")

        def send_keys(self, text: str) -> str:
            """Send keyboard input to the page.

            Args:
                text: The keyboard input to send.

            Returns:
                The updated page observation.
            """
            return self._do_action(f"send_keys({text!r})")

        def scroll(self, direction: str) -> str:
            """Scroll the page.

            Args:
                direction: Direction to scroll, either 'up' or 'down'.

            Returns:
                The updated page observation.
            """
            return self._do_action(f"scroll({direction!r})")

        def noop(self) -> str:
            """Do nothing and observe the current page state.

            Returns:
                The current page observation.
            """
            return self._do_action("noop()")

        def _do_action(self, action_str: str) -> str:
            """Send one action string to the environment and update reward/done state.

            Raises:
                ValueError: If the episode has already finished (or hit the step budget).
            """
            if self._done:
                raise ValueError("Episode is done.")

            self._step_count += 1
            result = self.client.step(BrowserGymAction(action_str=action_str))
            observation = result.observation
            step_reward = float(result.reward or 0.0)  # a missing reward counts as 0.0
            self._done = result.done

            # Terminal step: collapse to a binary success signal. Otherwise pass the
            # step reward through unchanged.
            if self._done and step_reward > 0:
                self.reward = 1.0
            elif self._done:
                self.reward = 0.0
            else:
                self.reward = step_reward

            # Truncate once the step budget is used up; the reward set above is kept.
            if self._step_count >= max_steps:
                self._done = True

            return self._format_observation(observation)

        def _format_observation(self, observation) -> str:
            """Render goal, last-action error and (truncated) accessibility tree as text."""
            parts = []
            if observation.goal:
                parts.append(f"Goal: {observation.goal}")
            if observation.last_action_error and observation.error:
                parts.append(f"Error: {observation.error}")
            if observation.axtree_txt:
                axtree = observation.axtree_txt
                # Cap the tree at 2000 characters to bound the prompt length.
                if len(axtree) > 2000:
                    axtree = axtree[:2000] + "..."
                parts.append(f"Page structure:\n{axtree}")
            return "\n\n".join(parts) if parts else "No observation available."

    # The timestamp makes the default output directory and trackio space id unique per run.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    default_output_dir = Path("outputs") / f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}"
    output_dir = Path(args.output_dir or default_output_dir)

    grpo_config = GRPOConfig(
        use_vllm=True,
        vllm_mode=args.vllm_mode,
        # The server URL is only meaningful in "server" mode.
        vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
        vllm_gpu_memory_utilization=0.4,
        output_dir=str(output_dir),
        num_train_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        per_device_train_batch_size=args.per_device_batch_size,
        warmup_steps=args.warmup_steps,
        num_generations=args.num_generations,
        generation_batch_size=args.num_generations,
        max_completion_length=args.max_completion_length,
        logging_steps=args.logging_steps,
        report_to="trackio",
        trackio_space_id=f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}",
        save_strategy="steps",
        save_steps=args.save_interval,
        save_total_limit=args.save_total_limit,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        chat_template_kwargs={"enable_thinking": False},
    )

    grpo_config.run_name = args.run_name or f"run-{timestamp}"
    grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}"

    trainer = GRPOTrainer(
        model=args.model_id,
        reward_funcs=[reward_completion],
        train_dataset=dataset,
        args=grpo_config,
        environment_factory=BrowserGymLLMEnv,
    )

    # NOTE(review): args.benchmark and args.task_name are only echoed in this banner;
    # nothing in this function forwards them to BrowserGymEnv. Confirm the Space at
    # --space-url is configured for the benchmark/task printed here.
    print("=" * 80)
    print("Starting GRPO training with BrowserGym environment (LLM mode)")
    print(f"Benchmark: {args.benchmark}")
    print(f"Task: {args.task_name}")
    print(f"Model: {args.model_id}")
    print("Mode: LLM (text-only, using accessibility tree)")
    print(f"Using {args.num_generations} rollouts per dataset prompt")
    print(f"Output directory: {output_dir}")
    print("=" * 80)

    trainer.train()
|
|
|
|
|
# Start training only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
|
|