trl-mcsd / examples /scripts /openenv /browsergym_llm.py
ihbkaiser's picture
Implement MCSD for experimental SDPO
1fa3c6c verified
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
# "trl[vllm,peft]",
# "trackio",
# "kernels",
# "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env",
# ]
# ///
"""
Simple script to run GRPO training with OpenEnv's BrowserGym environment and vLLM for LLMs.
This script is optimized for text-only Language Models (LLMs). It uses the accessibility
tree text from BrowserGym, making it memory-efficient.
The environment runs on a Hugging Face Space by default.
Setup (Option A - Install from HF Space, recommended):
```sh
uv pip install git+https://huggingface.co/spaces/openenv/browsergym_env
```
Setup (Option B - Clone OpenEnv repo, for development):
```sh
git clone https://github.com/meta-pytorch/OpenEnv.git
cd OpenEnv/envs/browsergym_env
uv pip install -e .
```
# Option 1: HF Spaces + Colocated vLLM (1 GPU required)
```sh
python examples/scripts/openenv/browsergym_llm.py --vllm-mode colocate
```
# Option 2: HF Spaces + Separate vLLM server (2 GPUs required)
# Spin up vLLM server (Terminal 1)
```sh
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-0.6B --host 0.0.0.0 --port 8001
```
# Run training (Terminal 2)
```sh
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/browsergym_llm.py --vllm-mode server --vllm-server-url http://localhost:8001
```
"""
from __future__ import annotations
import argparse
from datetime import datetime
from pathlib import Path
from browsergym_env import BrowserGymAction, BrowserGymEnv
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run GRPO training for BrowserGym MiniWoB using OpenEnv environment.")
parser.add_argument(
"--model-id",
default="Qwen/Qwen3-0.6B",
help="Model identifier passed to GRPOTrainer for fine-tuning.",
)
parser.add_argument(
"--space-url",
type=str,
default="https://openenv-browsergym-env.hf.space",
help="URL for the Hugging Face Space running the BrowserGym environment.",
)
parser.add_argument(
"--benchmark",
default="miniwob",
help="BrowserGym benchmark to use (miniwob, webarena, etc.).",
)
parser.add_argument(
"--task-name",
default="click-test",
help="Specific task within the benchmark (e.g., click-test, click-button).",
)
parser.add_argument(
"--dataset-prompt",
default="Complete the web task successfully.",
help="Prompt text used to seed the training dataset.",
)
parser.add_argument(
"--dataset-size",
type=int,
default=1000,
help="Number of entries to include in the synthetic training dataset.",
)
parser.add_argument(
"--max-steps",
type=int,
default=10,
help="Maximum number of steps per episode.",
)
parser.add_argument(
"--max-completion-length",
type=int,
default=1024,
help="Maximum completion length in tokens for tool-calling generation.",
)
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Sampling temperature used during rollout generation.",
)
parser.add_argument(
"--top-k",
type=int,
default=50,
help="Top-k sampling parameter forwarded to vLLM.",
)
parser.add_argument(
"--top-p",
type=float,
default=None,
help="Optional top-p sampling parameter forwarded to vLLM.",
)
parser.add_argument(
"--learning-rate",
type=float,
default=5e-6,
help="Learning rate for GRPO training.",
)
parser.add_argument(
"--weight-decay",
type=float,
default=0.0,
help="Weight decay applied during optimization.",
)
parser.add_argument(
"--gradient-accumulation-steps",
type=int,
default=32,
help="Gradient accumulation steps for GRPO training.",
)
parser.add_argument(
"--warmup-steps",
type=int,
default=10,
help="Warmup steps for the scheduler.",
)
parser.add_argument(
"--per-device-batch-size",
type=int,
default=1,
help="Per-device train batch size.",
)
parser.add_argument(
"--num-generations",
type=int,
default=4,
help="Number of rollout generations per dataset prompt.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=1,
help="Number of training epochs.",
)
parser.add_argument(
"--save-interval",
type=int,
default=50,
help="Interval (in steps) between checkpoint saves.",
)
parser.add_argument(
"--save-total-limit",
type=int,
default=None,
help="Maximum number of checkpoints to keep.",
)
parser.add_argument(
"--output-dir",
default=None,
help="Directory where training outputs and checkpoints are stored.",
)
parser.add_argument(
"--run-name",
default=None,
help="Optional run name for logging systems.",
)
parser.add_argument(
"--project",
default=None,
help="Optional project identifier for logging systems.",
)
parser.add_argument(
"--vllm-mode",
choices=("colocate", "server"),
default="colocate",
help="vLLM execution mode: 'colocate' or 'server'.",
)
parser.add_argument(
"--vllm-server-url",
type=str,
default="http://localhost:8001",
help="URL for the vLLM server (only used when --vllm-mode=server).",
)
parser.add_argument(
"--logging-steps",
type=int,
default=1,
help="Frequency of logging steps for GRPO training.",
)
return parser.parse_args()
def sanitize_name(name: str) -> str:
return name.replace("/", "-")
# ---------------------------------------------------------------------------
# System Prompt
# ---------------------------------------------------------------------------
SYSTEM_PROMPT = """You control a web browser to complete tasks.
The page structure shows elements as: [bid] element_type 'element_text'
For example: [13] button 'Click Me!' means the element has bid='13'.
Use the available tools to interact with the page:
- click: Click an element by its bid
- fill: Fill an input field with text
- send_keys: Send keyboard input
- scroll: Scroll the page
- noop: Do nothing
Complete the given task as efficiently as possible."""
# ---------------------------------------------------------------------------
# Reward
# ---------------------------------------------------------------------------
def reward_completion(environments, **kwargs) -> list[float]:
"""Reward for task completion."""
return [env.reward for env in environments]
# ---------------------------------------------------------------------------
# Main entrypoint
# ---------------------------------------------------------------------------
def main() -> None:
args = parse_args()
space_url = args.space_url
max_steps = args.max_steps
dataset = Dataset.from_dict(
{
"prompt": [
[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": args.dataset_prompt},
]
]
* args.dataset_size
}
)
class BrowserGymLLMEnv:
def __init__(self):
self.client = BrowserGymEnv(base_url=space_url)
self.reward = 0.0
self._done = False
self._step_count = 0
def _ensure_large_max_size(self):
"""Raise WebSocket max message size for large observations (e.g. accessibility trees).
openenv-core<=0.2.1 does not pass max_size to ws_connect, so the websockets library
defaults to 1MB. We force a connection and patch it to 100MB before any messages are sent.
"""
self.client.connect()
ws = self.client._ws
if ws is not None and hasattr(ws, "protocol"):
proto = ws.protocol
# websockets <16: max_size; websockets >=16: max_message_size
attr = "max_size" if hasattr(proto, "max_size") else "max_message_size"
if getattr(proto, attr) == 2**20:
setattr(proto, attr, 100 * 1024 * 1024)
def reset(self, **kwargs) -> str:
self.reward = 0.0
self._done = False
self._step_count = 0
self._ensure_large_max_size()
result = self.client.reset()
self._done = result.done
return self._format_observation(result.observation)
def click(self, bid: str) -> str:
"""Click an element on the page.
Args:
bid: The BrowserGym ID of the element to click.
Returns:
The updated page observation.
"""
return self._do_action(f"click({bid!r})")
def fill(self, bid: str, text: str) -> str:
"""Fill an input field with text.
Args:
bid: The BrowserGym ID of the input field.
text: The text to type into the field.
Returns:
The updated page observation.
"""
return self._do_action(f"fill({bid!r}, {text!r})")
def send_keys(self, text: str) -> str:
"""Send keyboard input to the page.
Args:
text: The keyboard input to send.
Returns:
The updated page observation.
"""
return self._do_action(f"send_keys({text!r})")
def scroll(self, direction: str) -> str:
"""Scroll the page.
Args:
direction: Direction to scroll, either 'up' or 'down'.
Returns:
The updated page observation.
"""
return self._do_action(f"scroll({direction!r})")
def noop(self) -> str:
"""Do nothing and observe the current page state.
Returns:
The current page observation.
"""
return self._do_action("noop()")
def _do_action(self, action_str: str) -> str:
if self._done:
raise ValueError("Episode is done.")
self._step_count += 1
result = self.client.step(BrowserGymAction(action_str=action_str))
observation = result.observation
step_reward = float(result.reward or 0.0)
self._done = result.done
# Reward shaping: binary success/failure on completion
if self._done and step_reward > 0:
self.reward = 1.0
elif self._done:
self.reward = 0.0
else:
self.reward = step_reward
# Enforce max steps
if self._step_count >= max_steps:
self._done = True
return self._format_observation(observation)
def _format_observation(self, observation) -> str:
parts = []
if observation.goal:
parts.append(f"Goal: {observation.goal}")
if observation.last_action_error and observation.error:
parts.append(f"Error: {observation.error}")
if observation.axtree_txt:
axtree = observation.axtree_txt
if len(axtree) > 2000:
axtree = axtree[:2000] + "..."
parts.append(f"Page structure:\n{axtree}")
return "\n\n".join(parts) if parts else "No observation available."
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
default_output_dir = Path("outputs") / f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}"
output_dir = Path(args.output_dir or default_output_dir)
grpo_config = GRPOConfig(
use_vllm=True,
vllm_mode=args.vllm_mode,
vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
vllm_gpu_memory_utilization=0.4,
output_dir=str(output_dir),
num_train_epochs=args.num_epochs,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
gradient_accumulation_steps=args.gradient_accumulation_steps,
per_device_train_batch_size=args.per_device_batch_size,
warmup_steps=args.warmup_steps,
num_generations=args.num_generations,
generation_batch_size=args.num_generations,
max_completion_length=args.max_completion_length,
logging_steps=args.logging_steps,
report_to="trackio",
trackio_space_id=f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}",
save_strategy="steps",
save_steps=args.save_interval,
save_total_limit=args.save_total_limit,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
chat_template_kwargs={"enable_thinking": False},
)
grpo_config.run_name = args.run_name or f"run-{timestamp}"
grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}"
trainer = GRPOTrainer(
model=args.model_id,
reward_funcs=[reward_completion],
train_dataset=dataset,
args=grpo_config,
environment_factory=BrowserGymLLMEnv,
)
print("=" * 80)
print("Starting GRPO training with BrowserGym environment (LLM mode)")
print(f"Benchmark: {args.benchmark}")
print(f"Task: {args.task_name}")
print(f"Model: {args.model_id}")
print("Mode: LLM (text-only, using accessibility tree)")
print(f"Using {args.num_generations} rollouts per dataset prompt")
print(f"Output directory: {output_dir}")
print("=" * 80)
trainer.train()
if __name__ == "__main__":
main()