trl-mcsd / examples /scripts /openenv /browsergym.py
ihbkaiser's picture
Implement MCSD for experimental SDPO
1fa3c6c verified
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
# "trl[vllm,peft]",
# "trackio",
# "kernels",
# "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env",
# ]
# ///
"""
GRPO training with OpenEnv's BrowserGym environment for VLMs (Vision Language Models).
This script uses `environment_factory` with multimodal tool responses: each tool action
returns a screenshot (PIL Image) alongside the accessibility tree text, allowing the VLM
to see the page visually after each action.
Setup:
```sh
pip install "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env"
```
Usage:
```sh
# Without vLLM (default, 1 GPU)
python examples/scripts/openenv/browsergym.py
# With vLLM colocate (1 GPU, requires vLLM support for the model)
python examples/scripts/openenv/browsergym.py --use-vllm
# With vLLM server (2 GPUs)
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3.5-2B --host 0.0.0.0 --port 8000
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/browsergym.py --use-vllm --vllm-mode server
```
"""
from __future__ import annotations
import argparse
from datetime import datetime
from pathlib import Path
import numpy as np
from browsergym_env import BrowserGymAction, BrowserGymEnv
from datasets import Dataset
from PIL import Image
from trl import GRPOConfig, GRPOTrainer
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="GRPO training with BrowserGym VLM environment.")
parser.add_argument("--model-id", default="Qwen/Qwen3.5-2B")
parser.add_argument("--space-url", default="https://openenv-browsergym-env.hf.space")
parser.add_argument("--dataset-prompt", default="Complete the web task successfully.")
parser.add_argument("--dataset-size", type=int, default=1000)
parser.add_argument("--max-steps", type=int, default=10)
parser.add_argument("--max-completion-length", type=int, default=1024)
parser.add_argument("--image-size", type=int, default=512, help="Resize screenshots to this size. 0 to disable.")
parser.add_argument("--num-generations", type=int, default=4)
parser.add_argument("--gradient-accumulation-steps", type=int, default=32)
parser.add_argument("--learning-rate", type=float, default=5e-6)
parser.add_argument("--num-epochs", type=int, default=1)
parser.add_argument("--logging-steps", type=int, default=1)
parser.add_argument("--output-dir", default=None)
parser.add_argument("--use-vllm", action="store_true", default=False, help="Enable vLLM for generation.")
parser.add_argument("--vllm-mode", choices=("colocate", "server"), default="colocate")
parser.add_argument("--vllm-server-url", default="http://localhost:8000")
return parser.parse_args()
def sanitize_name(name: str) -> str:
return name.replace("/", "-")
SYSTEM_PROMPT = """You control a web browser to complete tasks.
The page structure shows elements as: [bid] element_type 'element_text'
For example: [13] button 'Click Me!' means the element has bid='13'.
You will see a screenshot of the page after each action. Use the visual information
along with the page structure to decide your next action.
Use the available tools to interact with the page:
- click: Click an element by its bid
- fill: Fill an input field with text
- send_keys: Send keyboard input
- scroll: Scroll the page
- noop: Do nothing
Complete the given task as efficiently as possible."""
def reward_completion(completions, environments, **kwargs) -> list[float]:
return [env.reward for env in environments]
def main() -> None:
args = parse_args()
space_url = args.space_url
max_steps = args.max_steps
image_size = args.image_size
dataset = Dataset.from_dict(
{
"prompt": [
[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": args.dataset_prompt},
]
]
* args.dataset_size
}
)
class BrowserGymVLMEnv:
def __init__(self):
self.client = BrowserGymEnv(base_url=space_url)
self.reward = 0.0
self.done = False
self._step_count = 0
def reset(self, **kwargs) -> str | None:
self.reward = 0.0
self.done = False
self._step_count = 0
result = self.client.reset()
self.done = result.done
return self._format_observation(result.observation)
def click(self, bid: str) -> list:
"""Click an element on the page.
Args:
bid: The BrowserGym ID of the element to click.
Returns:
The updated page observation with screenshot.
"""
return self._do_action(f"click('{bid}')")
def fill(self, bid: str, text: str) -> list:
"""Fill an input field with text.
Args:
bid: The BrowserGym ID of the input field.
text: The text to type into the field.
Returns:
The updated page observation with screenshot.
"""
return self._do_action(f"fill('{bid}', '{text}')")
def send_keys(self, text: str) -> list:
"""Send keyboard input to the page.
Args:
text: The keyboard input to send.
Returns:
The updated page observation with screenshot.
"""
return self._do_action(f"send_keys('{text}')")
def scroll(self, direction: str) -> list:
"""Scroll the page.
Args:
direction: Direction to scroll, either 'up' or 'down'.
Returns:
The updated page observation with screenshot.
"""
return self._do_action(f"scroll('{direction}')")
def noop(self) -> list:
"""Do nothing and observe the current page state.
Returns:
The current page observation with screenshot.
"""
return self._do_action("noop()")
def _do_action(self, action_str: str) -> list:
if self.done:
raise ValueError("Episode is done.")
self._step_count += 1
result = self.client.step(BrowserGymAction(action_str=action_str))
observation = result.observation
step_reward = float(result.reward or 0.0)
self.done = result.done
if self.done and step_reward > 0:
self.reward = 1.0
elif self.done:
self.reward = 0.0
else:
self.reward = step_reward
if self._step_count >= max_steps:
self.done = True
return self._format_observation_multimodal(observation)
def _format_observation(self, observation) -> str:
"""Format initial observation as text (for reset, appended to prompt)."""
parts = []
if observation.goal:
parts.append(f"Goal: {observation.goal}")
if observation.axtree_txt:
axtree = observation.axtree_txt
if len(axtree) > 2000:
axtree = axtree[:2000] + "..."
parts.append(f"Page structure:\n{axtree}")
return "\n\n".join(parts) if parts else "No observation available."
def _format_observation_multimodal(self, observation) -> list:
"""Format observation as multimodal content blocks (screenshot + text)."""
content = []
# Add screenshot if available
if observation.screenshot is not None:
screenshot_array = np.array(observation.screenshot, dtype=np.uint8)
screenshot_image = Image.fromarray(screenshot_array)
if image_size > 0:
screenshot_image.thumbnail((image_size, image_size), Image.LANCZOS)
content.append({"type": "image", "image": screenshot_image})
# Add text observation
parts = []
if observation.goal:
parts.append(f"Goal: {observation.goal}")
if observation.last_action_error and observation.error:
parts.append(f"Error: {observation.error}")
if observation.axtree_txt:
axtree = observation.axtree_txt
if len(axtree) > 2000:
axtree = axtree[:2000] + "..."
parts.append(f"Page structure:\n{axtree}")
text = "\n\n".join(parts) if parts else "No observation available."
content.append({"type": "text", "text": text})
return content
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
default_output_dir = Path("outputs") / f"browsergym-vlm-grpo-{sanitize_name(args.model_id)}-{timestamp}"
output_dir = Path(args.output_dir or default_output_dir)
trainer = GRPOTrainer(
model=args.model_id,
reward_funcs=reward_completion,
train_dataset=dataset,
args=GRPOConfig(
use_vllm=args.use_vllm,
vllm_mode=args.vllm_mode if args.use_vllm else "colocate",
vllm_server_base_url=args.vllm_server_url if args.use_vllm and args.vllm_mode == "server" else None,
output_dir=str(output_dir),
num_train_epochs=args.num_epochs,
learning_rate=args.learning_rate,
gradient_accumulation_steps=args.gradient_accumulation_steps,
num_generations=args.num_generations,
max_completion_length=args.max_completion_length,
logging_steps=args.logging_steps,
log_completions=True,
report_to="trackio",
trackio_space_id=f"browsergym-vlm-grpo-{sanitize_name(args.model_id)}",
chat_template_kwargs={"enable_thinking": False},
),
environment_factory=BrowserGymVLMEnv,
)
trainer.train()
if __name__ == "__main__":
main()