|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| GRPO training with OpenEnv's BrowserGym environment for VLMs (Vision Language Models).
|
|
|
| This script uses `environment_factory` with multimodal tool responses: each tool action
|
| returns a screenshot (PIL Image) alongside the accessibility tree text, allowing the VLM
|
| to see the page visually after each action.
|
|
|
| Setup:
|
| ```sh
|
| pip install "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env"
|
| ```
|
|
|
| Usage:
|
| ```sh
|
| # Without vLLM (default, 1 GPU)
|
| python examples/scripts/openenv/browsergym.py
|
|
|
| # With vLLM colocate (1 GPU, requires vLLM support for the model)
|
| python examples/scripts/openenv/browsergym.py --use-vllm
|
|
|
| # With vLLM server (2 GPUs)
|
| CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3.5-2B --host 0.0.0.0 --port 8000
|
| CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/browsergym.py --use-vllm --vllm-mode server
|
| ```
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| from datetime import datetime
|
| from pathlib import Path
|
|
|
| import numpy as np
|
| from browsergym_env import BrowserGymAction, BrowserGymEnv
|
| from datasets import Dataset
|
| from PIL import Image
|
|
|
| from trl import GRPOConfig, GRPOTrainer
|
|
|
|
|
| def parse_args() -> argparse.Namespace:
|
| parser = argparse.ArgumentParser(description="GRPO training with BrowserGym VLM environment.")
|
| parser.add_argument("--model-id", default="Qwen/Qwen3.5-2B")
|
| parser.add_argument("--space-url", default="https://openenv-browsergym-env.hf.space")
|
| parser.add_argument("--dataset-prompt", default="Complete the web task successfully.")
|
| parser.add_argument("--dataset-size", type=int, default=1000)
|
| parser.add_argument("--max-steps", type=int, default=10)
|
| parser.add_argument("--max-completion-length", type=int, default=1024)
|
| parser.add_argument("--image-size", type=int, default=512, help="Resize screenshots to this size. 0 to disable.")
|
| parser.add_argument("--num-generations", type=int, default=4)
|
| parser.add_argument("--gradient-accumulation-steps", type=int, default=32)
|
| parser.add_argument("--learning-rate", type=float, default=5e-6)
|
| parser.add_argument("--num-epochs", type=int, default=1)
|
| parser.add_argument("--logging-steps", type=int, default=1)
|
| parser.add_argument("--output-dir", default=None)
|
| parser.add_argument("--use-vllm", action="store_true", default=False, help="Enable vLLM for generation.")
|
| parser.add_argument("--vllm-mode", choices=("colocate", "server"), default="colocate")
|
| parser.add_argument("--vllm-server-url", default="http://localhost:8000")
|
| return parser.parse_args()
|
|
|
|
|
| def sanitize_name(name: str) -> str:
|
| return name.replace("/", "-")
|
|
|
|
|
| SYSTEM_PROMPT = """You control a web browser to complete tasks.
|
|
|
| The page structure shows elements as: [bid] element_type 'element_text'
|
| For example: [13] button 'Click Me!' means the element has bid='13'.
|
|
|
| You will see a screenshot of the page after each action. Use the visual information
|
| along with the page structure to decide your next action.
|
|
|
| Use the available tools to interact with the page:
|
| - click: Click an element by its bid
|
| - fill: Fill an input field with text
|
| - send_keys: Send keyboard input
|
| - scroll: Scroll the page
|
| - noop: Do nothing
|
|
|
| Complete the given task as efficiently as possible."""
|
|
|
|
|
| def reward_completion(completions, environments, **kwargs) -> list[float]:
|
| return [env.reward for env in environments]
|
|
|
|
|
| def main() -> None:
|
| args = parse_args()
|
|
|
| space_url = args.space_url
|
| max_steps = args.max_steps
|
| image_size = args.image_size
|
|
|
| dataset = Dataset.from_dict(
|
| {
|
| "prompt": [
|
| [
|
| {"role": "system", "content": SYSTEM_PROMPT},
|
| {"role": "user", "content": args.dataset_prompt},
|
| ]
|
| ]
|
| * args.dataset_size
|
| }
|
| )
|
|
|
| class BrowserGymVLMEnv:
|
| def __init__(self):
|
| self.client = BrowserGymEnv(base_url=space_url)
|
| self.reward = 0.0
|
| self.done = False
|
| self._step_count = 0
|
|
|
| def reset(self, **kwargs) -> str | None:
|
| self.reward = 0.0
|
| self.done = False
|
| self._step_count = 0
|
| result = self.client.reset()
|
| self.done = result.done
|
| return self._format_observation(result.observation)
|
|
|
| def click(self, bid: str) -> list:
|
| """Click an element on the page.
|
|
|
| Args:
|
| bid: The BrowserGym ID of the element to click.
|
|
|
| Returns:
|
| The updated page observation with screenshot.
|
| """
|
| return self._do_action(f"click('{bid}')")
|
|
|
| def fill(self, bid: str, text: str) -> list:
|
| """Fill an input field with text.
|
|
|
| Args:
|
| bid: The BrowserGym ID of the input field.
|
| text: The text to type into the field.
|
|
|
| Returns:
|
| The updated page observation with screenshot.
|
| """
|
| return self._do_action(f"fill('{bid}', '{text}')")
|
|
|
| def send_keys(self, text: str) -> list:
|
| """Send keyboard input to the page.
|
|
|
| Args:
|
| text: The keyboard input to send.
|
|
|
| Returns:
|
| The updated page observation with screenshot.
|
| """
|
| return self._do_action(f"send_keys('{text}')")
|
|
|
| def scroll(self, direction: str) -> list:
|
| """Scroll the page.
|
|
|
| Args:
|
| direction: Direction to scroll, either 'up' or 'down'.
|
|
|
| Returns:
|
| The updated page observation with screenshot.
|
| """
|
| return self._do_action(f"scroll('{direction}')")
|
|
|
| def noop(self) -> list:
|
| """Do nothing and observe the current page state.
|
|
|
| Returns:
|
| The current page observation with screenshot.
|
| """
|
| return self._do_action("noop()")
|
|
|
| def _do_action(self, action_str: str) -> list:
|
| if self.done:
|
| raise ValueError("Episode is done.")
|
|
|
| self._step_count += 1
|
| result = self.client.step(BrowserGymAction(action_str=action_str))
|
| observation = result.observation
|
| step_reward = float(result.reward or 0.0)
|
| self.done = result.done
|
|
|
| if self.done and step_reward > 0:
|
| self.reward = 1.0
|
| elif self.done:
|
| self.reward = 0.0
|
| else:
|
| self.reward = step_reward
|
|
|
| if self._step_count >= max_steps:
|
| self.done = True
|
|
|
| return self._format_observation_multimodal(observation)
|
|
|
| def _format_observation(self, observation) -> str:
|
| """Format initial observation as text (for reset, appended to prompt)."""
|
| parts = []
|
| if observation.goal:
|
| parts.append(f"Goal: {observation.goal}")
|
| if observation.axtree_txt:
|
| axtree = observation.axtree_txt
|
| if len(axtree) > 2000:
|
| axtree = axtree[:2000] + "..."
|
| parts.append(f"Page structure:\n{axtree}")
|
| return "\n\n".join(parts) if parts else "No observation available."
|
|
|
| def _format_observation_multimodal(self, observation) -> list:
|
| """Format observation as multimodal content blocks (screenshot + text)."""
|
| content = []
|
|
|
|
|
| if observation.screenshot is not None:
|
| screenshot_array = np.array(observation.screenshot, dtype=np.uint8)
|
| screenshot_image = Image.fromarray(screenshot_array)
|
| if image_size > 0:
|
| screenshot_image.thumbnail((image_size, image_size), Image.LANCZOS)
|
| content.append({"type": "image", "image": screenshot_image})
|
|
|
|
|
| parts = []
|
| if observation.goal:
|
| parts.append(f"Goal: {observation.goal}")
|
| if observation.last_action_error and observation.error:
|
| parts.append(f"Error: {observation.error}")
|
| if observation.axtree_txt:
|
| axtree = observation.axtree_txt
|
| if len(axtree) > 2000:
|
| axtree = axtree[:2000] + "..."
|
| parts.append(f"Page structure:\n{axtree}")
|
| text = "\n\n".join(parts) if parts else "No observation available."
|
| content.append({"type": "text", "text": text})
|
|
|
| return content
|
|
|
| timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| default_output_dir = Path("outputs") / f"browsergym-vlm-grpo-{sanitize_name(args.model_id)}-{timestamp}"
|
| output_dir = Path(args.output_dir or default_output_dir)
|
|
|
| trainer = GRPOTrainer(
|
| model=args.model_id,
|
| reward_funcs=reward_completion,
|
| train_dataset=dataset,
|
| args=GRPOConfig(
|
| use_vllm=args.use_vllm,
|
| vllm_mode=args.vllm_mode if args.use_vllm else "colocate",
|
| vllm_server_base_url=args.vllm_server_url if args.use_vllm and args.vllm_mode == "server" else None,
|
| output_dir=str(output_dir),
|
| num_train_epochs=args.num_epochs,
|
| learning_rate=args.learning_rate,
|
| gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| num_generations=args.num_generations,
|
| max_completion_length=args.max_completion_length,
|
| logging_steps=args.logging_steps,
|
| log_completions=True,
|
| report_to="trackio",
|
| trackio_space_id=f"browsergym-vlm-grpo-{sanitize_name(args.model_id)}",
|
| chat_template_kwargs={"enable_thinking": False},
|
| ),
|
| environment_factory=BrowserGymVLMEnv,
|
| )
|
| trainer.train()
|
|
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|