🌐 Browser Control Agent - Gemma 3 1B

A fine-tuned Gemma-3-1B-IT model for autonomous browser control and web automation tasks. This model can understand web page contexts and generate appropriate browser actions.

🎯 Model Description

This model is trained to act as a browser automation agent, capable of understanding user goals, analyzing web page elements, and generating precise browser actions to accomplish tasks.

Key Capabilities

  • Element Interaction: Click buttons, links, and interactive elements
  • Form Filling: Input text into form fields
  • Navigation: Scroll pages and navigate to URLs
  • Hover Actions: Trigger hover-based UI elements

🔧 Technical Details

Base Model

| Property | Value |
|---|---|
| Base Model | google/gemma-3-1b-it |
| Architecture | Gemma 3 Decoder-Only Transformer |
| Parameters | 1 Billion |
| Context Length | 8K tokens |

Fine-Tuning Configuration

| Property | Value |
|---|---|
| Method | LoRA (Low-Rank Adaptation) |
| LoRA Rank | 32 |
| LoRA Alpha | 32.0 |
| Target Modules | q_einsum, kv_einsum, gate_proj, down_proj, up_proj, attn_vec_einsum |
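
As a rough illustration of what rank 32 and alpha 32.0 imply, the sketch below computes the scaling factor used in the standard LoRA formulation (alpha / rank) and the number of extra trainable parameters for a single adapted matrix. The layer shape is an assumption chosen for illustration, not a value read from the checkpoint, and the helper names are hypothetical.

```python
def lora_scale(alpha: float, rank: int) -> float:
    """Scaling applied to the low-rank update in standard LoRA: alpha / rank."""
    return alpha / rank


def lora_extra_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters added by one adapter: A is (d_in x r), B is (r x d_out)."""
    return rank * (d_in + d_out)


RANK, ALPHA = 32, 32.0
D_IN, D_OUT = 1152, 6912  # assumed shape of one projection, for illustration only

scale = lora_scale(ALPHA, RANK)
extra = lora_extra_params(D_IN, D_OUT, RANK)
full = D_IN * D_OUT

print(f"scale = {scale}")
print(f"adapter params = {extra} ({extra / full:.1%} of the full matrix)")
```

With alpha equal to the rank, the low-rank update is added to the frozen weights without extra rescaling.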

Training Infrastructure

| Property | Value |
|---|---|
| Hardware | 8x Google TPU v5e |
| Total HBM | 126 GB |
| Framework | JAX + Flax NNX |
| Library | Google Tunix |
| LoRA Library | Qwix |
| Precision | bfloat16 |
| Sharding | FSDP (2) × TP (4) |

Training Process

Phase 1: Supervised Fine-Tuning (SFT)

| Property | Value |
|---|---|
| Steps | 300 |
| Batch Size | 2 |
| Learning Rate | 5e-5 |
| Scheduler | Warmup Cosine Decay |
| Initial Loss | 4.0987 |
| Final Loss | 0.0446 |
| Loss Reduction | 99% |
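
For readers unfamiliar with a warmup cosine decay schedule, here is a minimal pure-Python sketch. The peak learning rate (5e-5) and the 300 total steps come from the table above; the number of warmup steps (30) and the final learning rate (0.0) are assumptions, since this card does not state them.

```python
import math


def warmup_cosine_lr(
    step: int,
    peak_lr: float = 5e-5,
    total_steps: int = 300,
    warmup_steps: int = 30,  # assumed; not stated in this card
    end_lr: float = 0.0,     # assumed; not stated in this card
) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay down to end_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return end_lr + 0.5 * (peak_lr - end_lr) * (1.0 + math.cos(math.pi * progress))


for s in (0, 15, 30, 165, 300):
    print(s, warmup_cosine_lr(s))
```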

Phase 2: Group Relative Policy Optimization (GRPO)

| Property | Value |
|---|---|
| Algorithm | GRPO (Reinforcement Learning) |
| Steps | 50-150 |
| Generations per Prompt | 1-2 |
| Temperature | 0.7 |
| Reward Function | Custom action-matching scorer |
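
The exact reward function is not published in this card. The sketch below shows one plausible shape for an action-matching scorer: graded credit for a valid format, a correct action type, and an exact match. The tier values (0.1, 0.5, 1.0) and all names are hypothetical and are not the scorer used in training.

```python
import re

ACTION_RE = re.compile(r"<action>\s*(\w+)\((.*?)\)\s*</action>", re.DOTALL)


def _normalize(args: str) -> str:
    """Treat quote style and spacing around commas as equivalent."""
    return re.sub(r"\s*,\s*", ", ", args.strip().replace('"', "'"))


def action_reward(prediction: str, reference: str) -> float:
    """Hypothetical graded reward; the tiers are illustrative assumptions."""
    ref = ACTION_RE.search(reference)
    if ref is None:
        raise ValueError("reference must contain an <action>...</action> tag")
    pred = ACTION_RE.search(prediction)
    if pred is None:
        return 0.0  # no parseable action at all
    if pred.group(1) != ref.group(1):
        return 0.1  # valid format, wrong action type
    if _normalize(pred.group(2)) != _normalize(ref.group(2)):
        return 0.5  # right action type, wrong arguments
    return 1.0      # exact match


print(action_reward("<action>click('2')</action>", "<action>click('1')</action>"))
```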

📋 Supported Actions

The model generates actions in the following format: <action>ACTION_TYPE(args)</action>

| Action | Syntax | Description | Example |
|---|---|---|---|
| Click | `click(bid)` | Click on an element by browser ID | `<action>click('1')</action>` |
| Fill | `fill(bid, text)` | Fill an input field with text | `<action>fill('2', 'user@email.com')</action>` |
| Scroll | `scroll(dx, dy)` | Scroll the page by pixels | `<action>scroll(0, 500)</action>` |
| Hover | `hover(bid)` | Hover over an element | `<action>hover('3')</action>` |
| Navigate | `goto(url)` | Navigate to a URL | `<action>goto('https://example.com')</action>` |
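
Because the action syntax above is also valid Python call syntax, one way to validate model output is with the standard-library `ast` module instead of hand-written regular expressions. The following is a sketch under that assumption; `parse_action_call` and `ARITY` are hypothetical names, and this is not the parser used during training.

```python
import ast
import re

# Expected number of positional arguments per supported action
ARITY = {"click": 1, "fill": 2, "scroll": 2, "hover": 1, "goto": 1}


def parse_action_call(text: str):
    """Return (name, args) for a well-formed action, or None otherwise."""
    m = re.search(r"<action>(.*?)</action>", text, re.DOTALL)
    if not m:
        return None
    try:
        node = ast.parse(m.group(1).strip(), mode="eval").body
    except (SyntaxError, ValueError):
        return None
    if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)):
        return None
    if node.keywords or ARITY.get(node.func.id) != len(node.args):
        return None
    try:
        # literal_eval only accepts literals, so no code is executed
        args = [ast.literal_eval(arg) for arg in node.args]
    except (ValueError, TypeError):
        return None
    return node.func.id, args


print(parse_action_call("<action>fill('2', 'user@email.com')</action>"))
```

This approach handles quotes inside `fill` text and negative scroll offsets without special cases, and it rejects unknown action names.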

💻 Usage

Prompt Format

<start_of_turn>user
You are a browser automation agent. Your task is to interact with web pages to complete goals.

Available actions:
- click(bid): Click element with browser ID
- fill(bid, text): Fill input field with text
- scroll(dx, dy): Scroll by pixels
- hover(bid): Hover over element
- goto(url): Navigate to URL

Output your chosen action between <action> and </action> tags.

URL: https://example.com
Goal: Click the login button

Available Elements:
[1] <button id='login'>Login</button>
[2] <a href='/signup'>Sign up</a><end_of_turn>
<start_of_turn>model

Expected Output

<action>click('1')</action>
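
The "Available Elements" block in the prompt uses one line per element, prefixed with its browser ID in square brackets. A small sketch of producing that layout from structured data follows; the input dictionary shape (`tag`, `attrs`, `text`) and the function name are assumptions made for illustration.

```python
def format_elements(elements: list, limit: int = 20) -> str:
    """Render elements as '[bid] <tag attr='value'>text</tag>' lines, bids starting at 1."""
    lines = []
    for bid, el in enumerate(elements[:limit], start=1):
        attrs = "".join(f" {k}='{v}'" for k, v in el.get("attrs", {}).items())
        lines.append(f"[{bid}] <{el['tag']}{attrs}>{el.get('text', '')}</{el['tag']}>")
    return "\n".join(lines)


example = [
    {"tag": "button", "attrs": {"id": "login"}, "text": "Login"},
    {"tag": "a", "attrs": {"href": "/signup"}, "text": "Sign up"},
]
print(format_elements(example))
```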

🚀 Quick Start - Download & Use

Step 1: Install Dependencies

# accelerate is required for device_map="auto"
pip install transformers torch accelerate huggingface_hub playwright
playwright install chromium

Step 2: Download Model from HuggingFace

from huggingface_hub import snapshot_download

# Download the full merged model
model_path = snapshot_download(
    repo_id="batuhanozkose/browser-control-gemma"
)
print(f"Model downloaded to: {model_path}")

Step 3: Load Model for Inference

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

# Load the fine-tuned model
# Option A: If using PyTorch converted weights
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# Option B: For JAX/Tunix format (original)
# See JAX inference section below

🌐 Full Browser Automation with Playwright/Chromium

Here's a complete example of using this model to automate a browser:

import asyncio
import re
from playwright.async_api import async_playwright
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# ========== MODEL SETUP ==========
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
model = AutoModelForCausalLM.from_pretrained(
    "batuhanozkose/browser-control-gemma",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# ========== PROMPT TEMPLATE ==========
SYSTEM_PROMPT = """You are a browser automation agent. Your task is to interact with web pages to complete goals.

Available actions:
- click(bid): Click element with browser ID
- fill(bid, text): Fill input field with text  
- scroll(dx, dy): Scroll by pixels
- hover(bid): Hover over element
- goto(url): Navigate to URL

Output your chosen action between <action> and </action> tags."""

def create_prompt(url: str, goal: str, elements: str) -> str:
    return f"""<start_of_turn>user
{SYSTEM_PROMPT}

URL: {url}
Goal: {goal}

Available Elements:
{elements}<end_of_turn>
<start_of_turn>model
"""

# ========== ELEMENT EXTRACTION ==========
async def get_page_elements(page) -> str:
    """Extract interactive elements from page with browser IDs."""
    elements = await page.evaluate("""
        () => {
            const interactiveElements = document.querySelectorAll(
                'button, a, input, select, textarea, [onclick], [role="button"]'
            );
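            // Clear IDs left over from a previous step so each bid stays unique
            document.querySelectorAll('[data-bid]').forEach(el => el.removeAttribute('data-bid'));
            let result = [];
            let bid = 1;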
            interactiveElements.forEach(el => {
                if (el.offsetParent !== null) {  // visible elements only
                    const tag = el.tagName.toLowerCase();
                    const text = el.textContent?.trim().slice(0, 50) || '';
                    const type = el.type || '';
                    const placeholder = el.placeholder || '';
                    
                    // Set browser ID as data attribute
                    el.setAttribute('data-bid', bid);
                    
                    let desc = `[${bid}] <${tag}`;
                    if (el.id) desc += ` id='${el.id}'`;
                    if (type) desc += ` type='${type}'`;
                    if (placeholder) desc += ` placeholder='${placeholder}'`;
                    desc += `>${text}</${tag}>`;
                    
                    result.push(desc);
                    bid++;
                }
            });
            return result.slice(0, 20).join('\\n');  // Limit to 20 elements
        }
    """)
    return elements

# ========== ACTION PARSER ==========
def parse_action(response: str) -> tuple:
    """Parse model output to extract action."""
    match = re.search(r'<action>(.*?)</action>', response, re.DOTALL)
    if not match:
        return None, None
    
    action = match.group(1).strip()
    
    # Parse action type and arguments
    if action.startswith('click'):
        bid = re.search(r"click\(['\"]?(\d+)['\"]?\)", action)
        return 'click', bid.group(1) if bid else None
    elif action.startswith('fill'):
        match = re.search(r"fill\(['\"]?(\d+)['\"]?,\s*['\"](.+?)['\"]\)", action)
        if match:
            return 'fill', (match.group(1), match.group(2))
    elif action.startswith('scroll'):
        match = re.search(r"scroll\((-?\d+),\s*(-?\d+)\)", action)
        if match:
            return 'scroll', (int(match.group(1)), int(match.group(2)))
    elif action.startswith('hover'):
        bid = re.search(r"hover\(['\"]?(\d+)['\"]?\)", action)
        return 'hover', bid.group(1) if bid else None
    elif action.startswith('goto'):
        url = re.search(r"goto\(['\"](.+?)['\"]\)", action)
        return 'goto', url.group(1) if url else None
    
    return None, None

# ========== ACTION EXECUTOR ==========
async def execute_action(page, action_type: str, args):
    """Execute the parsed action on the browser."""
    if action_type == 'click':
        element = await page.query_selector(f'[data-bid="{args}"]')
        if element:
            await element.click()
            print(f"✅ Clicked element {args}")
    
    elif action_type == 'fill':
        bid, text = args
        element = await page.query_selector(f'[data-bid="{bid}"]')
        if element:
            await element.fill(text)
            print(f"✅ Filled element {bid} with '{text}'")
    
    elif action_type == 'scroll':
        dx, dy = args
        await page.evaluate(f"window.scrollBy({dx}, {dy})")
        print(f"✅ Scrolled by ({dx}, {dy})")
    
    elif action_type == 'hover':
        element = await page.query_selector(f'[data-bid="{args}"]')
        if element:
            await element.hover()
            print(f"✅ Hovered element {args}")
    
    elif action_type == 'goto':
        await page.goto(args)
        print(f"✅ Navigated to {args}")

# ========== INFERENCE ==========
def get_model_action(prompt: str) -> str:
    """Get action from model."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    
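    # Decode only the newly generated tokens. Slicing the decoded string by
    # len(prompt) is unreliable because special tokens are stripped on decode.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)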

# ========== MAIN AGENT LOOP ==========
async def run_browser_agent(goal: str, start_url: str, max_steps: int = 10):
    """Run the browser automation agent."""
    async with async_playwright() as p:
        browser = await p.chromium.launch(headless=False)  # Set True for headless
        page = await browser.new_page()
        
        await page.goto(start_url)
        print(f"🌐 Starting at: {start_url}")
        print(f"🎯 Goal: {goal}\n")
        
        for step in range(max_steps):
            print(f"--- Step {step + 1} ---")
            
            # Get current page state
            url = page.url
            elements = await get_page_elements(page)
            
            # Create prompt and get model response
            prompt = create_prompt(url, goal, elements)
            response = get_model_action(prompt)
            print(f"Model output: {response}")
            
            # Parse and execute action
            action_type, args = parse_action(response)
            if action_type:
                await execute_action(page, action_type, args)
                await asyncio.sleep(1)  # Wait for page to update
            else:
                print("⚠️ Could not parse action, stopping.")
                break
        
        await browser.close()

# ========== RUN EXAMPLE ==========
if __name__ == "__main__":
    asyncio.run(run_browser_agent(
        goal="Search for 'machine learning' and click the first result",
        start_url="https://www.google.com"
    ))
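
A small model can get stuck emitting the same action on an unchanged page. One possible safeguard, sketched here as an optional addition rather than part of the script above, is to stop once the same (URL, action, arguments) triple repeats several times in a row. `RepeatGuard` is a hypothetical helper name.

```python
from collections import deque


class RepeatGuard:
    """Signals a stop when the last `max_repeats` steps were identical."""

    def __init__(self, max_repeats: int = 3):
        self.max_repeats = max_repeats
        self.recent = deque(maxlen=max_repeats)

    def should_stop(self, url: str, action_type: str, args) -> bool:
        # args is a str or tuple, as returned by parse_action, so it is hashable
        self.recent.append((url, action_type, args))
        return len(self.recent) == self.max_repeats and len(set(self.recent)) == 1


guard = RepeatGuard(max_repeats=3)
for _ in range(3):
    stop = guard.should_stop("https://example.com", "click", "1")
print(stop)
```

In the agent loop, `guard.should_stop(url, action_type, args)` could be checked right after `parse_action` and before `execute_action`.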

🔧 JAX/Tunix Inference (Original Format)

If you want to use the model in its original JAX/Tunix format:

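import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from tunix.models.gemma3 import model as gemma_lib
from tunix.models.gemma3 import params_safetensors_lib

# The tokenizer is shared with the base model
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")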

# Setup mesh for TPU/GPU
MESH = [(2, 4), ("fsdp", "tp")]
mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * 2)

# Load model
model_config = gemma_lib.ModelConfig.gemma3_1b()
with mesh:
    model = params_safetensors_lib.create_model_from_safe_tensors(
        "path/to/downloaded/model", 
        model_config, 
        mesh
    )

# Inference
def generate(prompt: str, max_tokens: int = 50):
    inputs = tokenizer(prompt, return_tensors="np")
    input_ids = jnp.array(inputs["input_ids"])
    # ... generation loop

📦 API Server (Optional)

Wrap the model as a REST API:

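from fastapi import FastAPI
from pydantic import BaseModel

# create_prompt, get_model_action, and parse_action are the helpers from the
# agent script above; define or import them in this module before serving.

app = FastAPI()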

class TaskRequest(BaseModel):
    goal: str
    url: str
    elements: str

@app.post("/predict")
def predict(request: TaskRequest):
    # A plain `def` lets FastAPI run the blocking model call in its threadpool
    prompt = create_prompt(request.url, request.goal, request.elements)
    response = get_model_action(prompt)
    action_type, args = parse_action(response)
    return {"action": action_type, "args": args, "raw": response}

# Run: uvicorn server:app --host 0.0.0.0 --port 8000
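
A client for the endpoint above could look like the following standard-library sketch. It assumes the server is reachable at the given base URL and returns the JSON shape shown in `predict`; the function names are hypothetical.

```python
import json
import urllib.request


def build_predict_request(base_url: str, goal: str, url: str, elements: str) -> urllib.request.Request:
    """Build a POST request whose JSON body matches the TaskRequest fields."""
    payload = json.dumps({"goal": goal, "url": url, "elements": elements}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url.rstrip('/')}/predict",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def request_action(base_url: str, goal: str, url: str, elements: str, timeout: float = 30.0) -> dict:
    """Send the request and return the decoded JSON response."""
    req = build_predict_request(base_url, goal, url, elements)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


req = build_predict_request(
    "http://localhost:8000",
    goal="Click the login button",
    url="https://example.com",
    elements="[1] <button id='login'>Login</button>",
)
print(req.full_url, req.get_method())
```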

📊 Performance

Training Metrics

  • SFT Loss Reduction: 4.09 → 0.04 (99% reduction)
  • Action Format Accuracy: High (model consistently produces valid action syntax)

Benchmark Results

Evaluation on synthetic browser control tasks

| Metric | Score |
|---|---|
| Exact Action Match | ~85-90% |
| Correct Action Type | ~95% |
| Valid Format | ~98% |
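
The evaluation script is not included in this card. As a sketch of how the three metrics above could be computed from model outputs and reference actions, assuming exact string comparison of the argument list:

```python
import re

ACTION_RE = re.compile(r"<action>\s*(\w+)\((.*?)\)\s*</action>", re.DOTALL)


def evaluate(predictions: list, references: list) -> dict:
    """Fractions of outputs with valid format, correct action type, and exact match."""
    if not references or len(predictions) != len(references):
        raise ValueError("predictions and references must be non-empty and equal in length")
    valid = type_ok = exact = 0
    for pred, ref in zip(predictions, references):
        p, r = ACTION_RE.search(pred), ACTION_RE.search(ref)
        if p is None:
            continue  # not a parseable action
        valid += 1
        if r is not None and p.group(1) == r.group(1):
            type_ok += 1
            if p.group(2).strip() == r.group(2).strip():
                exact += 1
    n = len(references)
    return {
        "valid_format": valid / n,
        "correct_action_type": type_ok / n,
        "exact_action_match": exact / n,
    }


preds = [
    "<action>click('1')</action>",
    "<action>click('2')</action>",
    "<action>hover('1')</action>",
    "click 1",
]
refs = ["<action>click('1')</action>"] * 4
print(evaluate(preds, refs))
```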

🏗️ Training Data

The model was trained on synthetic browser interaction data including:

  • Click actions on various UI elements (buttons, links, checkboxes)
  • Form filling scenarios (email, password, search, address fields)
  • Scroll interactions (page navigation)
  • Hover actions (dropdown menus, tooltips)
  • Navigation commands

Data Generation

  • Samples: 1000 synthetic examples
  • Train/Val Split: 90/10
  • Diversity: Multiple action types, element types, and goal formulations
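
The data generator itself is not published here. The sketch below illustrates the general idea for click examples only and reproduces the 90/10 split on 1000 samples; the label list, field names, and seeds are all assumptions.

```python
import random

LABELS = ["Login", "Sign up", "Submit", "Search", "Checkout"]  # illustrative only


def make_click_example(rng: random.Random) -> dict:
    """Build one synthetic sample: a goal, an element list, and the target action."""
    labels = rng.sample(LABELS, k=3)
    target = rng.randrange(len(labels))
    elements = "\n".join(
        f"[{i}] <button>{text}</button>" for i, text in enumerate(labels, start=1)
    )
    return {
        "goal": f"Click the {labels[target]} button",
        "elements": elements,
        "action": f"<action>click('{target + 1}')</action>",
    }


def train_val_split(examples: list, val_fraction: float = 0.1, seed: int = 0):
    """Shuffle a copy of the data and hold out val_fraction of it for validation."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_val = int(round(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


rng = random.Random(42)
dataset = [make_click_example(rng) for _ in range(1000)]
train, val = train_val_split(dataset)
print(len(train), len(val))
```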

🔬 Methodology

Why Tunix?

Tunix is Google's lightweight LLM post-training library, providing:

  • Native JAX/TPU optimization
  • Efficient LoRA implementation via Qwix
  • Built-in GRPO support for reinforcement learning
  • Seamless multi-device sharding

Why GRPO?

Group Relative Policy Optimization improves upon SFT by:

  • Exploring multiple action candidates
  • Learning from comparative rewards
  • Improving robustness to input variations
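
The "comparative rewards" idea can be made concrete with the group-relative advantage commonly used in GRPO: each sampled response is scored against the mean and standard deviation of its own group. This is a minimal sketch of that normalization step only, not of Tunix's implementation; the epsilon value is an assumption.

```python
import statistics


def group_relative_advantages(rewards: list, eps: float = 1e-6) -> list:
    """Normalize each reward by the mean and population std of its group."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Rewards for several candidate actions sampled for the same prompt
print(group_relative_advantages([1.0, 0.5, 0.1, 0.5]))

# With a single sample the reward equals the group mean, so the advantage is zero
print(group_relative_advantages([1.0]))
```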

⚠️ Limitations

  • Trained on synthetic data only (not real web interactions)
  • Limited to predefined action types
  • May require additional fine-tuning for specific websites
  • Browser element ID format must match training format

📜 License

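This model is released under the Apache 2.0 License. The base Gemma 3 weights are distributed under Google's Gemma Terms of Use rather than Apache 2.0, so those terms also apply when using this model.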

🙏 Acknowledgments

  • Google DeepMind for Gemma 3 and Tunix
  • Google for providing access to a TPU pod

Last Updated: December 2025
