# Copyright 2020-2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # /// script # dependencies = [ # "trl[vllm,peft]", # "trackio", # "kernels", # "openenv-textarena @ git+https://huggingface.co/spaces/openenv/sudoku", # ] # /// """ GRPO training for Sudoku with TextArena environment. Setup (Option A - Install from HF Space, recommended): ```sh uv pip install git+https://huggingface.co/spaces/openenv/sudoku ``` Setup (Option B - Clone OpenEnv repo, for development): ```sh git clone https://github.com/meta-pytorch/OpenEnv.git cd OpenEnv/envs/textarena_env uv pip install -e . ``` # Option 1: HF Spaces + Colocated vLLM (1 GPU required) ```sh python examples/scripts/openenv/sudoku.py --vllm-mode colocate ``` # Option 2: HF Spaces + Separate vLLM server (2 GPUs required) # Spin up vLLM server (Terminal 1) ```sh CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000 ``` # Run training (Terminal 2) ```sh CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/sudoku.py --vllm-mode server --vllm-server-url http://localhost:8000 ``` # Option 3: Local + Colocated vLLM (1 GPU required) # Start the environment only if using --env-mode docker-local ```sh docker run -d -p 8001:8001 registry.hf.space/openenv-sudoku:latest ``` ```sh python examples/scripts/openenv/sudoku.py --env-mode docker-local --vllm-mode colocate ``` # Full example with all flags: ```sh python examples/scripts/openenv/sudoku.py \ --vllm-mode colocate \ --env-mode space \ --env-host https://openenv-sudoku.hf.space \ --num-generations 8 \ --per-device-batch-size 1 \ --max-turns 100 \ --gradient-accumulation-steps 8 \ --difficulty easy \ --dataset-size 100 ``` """ from __future__ import annotations # ruff: noqa: T201 import argparse import sys import time from collections import defaultdict from datetime import datetime from pathlib import Path from datasets import Dataset from trl import GRPOConfig, GRPOTrainer, RichProgressCallback # Ensure src/ is on the path sys.path.insert(0, str(Path(__file__).parent / "src")) from textarena_env import TextArenaAction, TextArenaEnv # --------------------------------------------------------------------------- # Argument parsing # --------------------------------------------------------------------------- def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="GRPO training for Sudoku") # Model parser.add_argument("--model-id", default="Qwen/Qwen3-1.7B") # Environment parser.add_argument("--env-host", type=str, default="https://openenv-sudoku.hf.space") parser.add_argument("--env-port", type=int, default=8001) parser.add_argument("--env-mode", choices=["docker-local", "docker-image", "docker-hub", "space"], default="space") parser.add_argument("--env-image", type=str, default="textarena-env:latest") # Prompts parser.add_argument("--system-prompt-path", default="sudoku_prompt.txt") parser.add_argument("--dataset-prompt", default="Play Sudoku like an expert.") parser.add_argument("--dataset-size", type=int, default=1000) # Game settings parser.add_argument("--max-turns", type=int, default=100) parser.add_argument( "--difficulty", type=str, choices=["easy", "medium", "hard"], default="easy", help="Training difficulty: easy=guaranteed+options, medium=only options, hard=no hints", ) parser.add_argument( "--api-delay", type=float, default=0.0, help="Delay in seconds between API calls to avoid rate limiting" ) # Sampling parser.add_argument("--temperature", type=float, default=0.8) parser.add_argument("--top-k", type=int, default=10) parser.add_argument("--top-p", type=float, default=None, help="Top-p sampling parameter") # Training parser.add_argument("--learning-rate", type=float, default=5e-6) parser.add_argument("--weight-decay", type=float, default=0.0) parser.add_argument("--gradient-accumulation-steps", type=int, default=64) parser.add_argument("--warmup-steps", type=int, default=20) parser.add_argument("--per-device-batch-size", type=int, default=1) parser.add_argument("--num-generations", type=int, default=8) parser.add_argument("--num-epochs", type=int, default=1) parser.add_argument("--max-completion-length", type=int, default=16384) # Checkpoints parser.add_argument("--save-interval", type=int, default=10) parser.add_argument("--save-total-limit", type=int, default=None) parser.add_argument("--output-dir", default=None) # Logging parser.add_argument("--run-name", default=None) parser.add_argument("--project", default=None) parser.add_argument("--trackio-space-id", default="Sudoku-GRPO") parser.add_argument("--logging-steps", type=int, default=1) parser.add_argument( "--gradient-checkpointing", action=argparse.BooleanOptionalAction, default=True, help="Enable gradient checkpointing to save memory", ) # LoRA / PEFT parser.add_argument( "--use-lora", action="store_true", default=False, help="Use LoRA for memory-efficient training" ) parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank") parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha") # vLLM parser.add_argument("--vllm-mode", choices=("colocate", "server"), default="colocate") parser.add_argument("--vllm-server-url", type=str, default="http://localhost:8000") parser.add_argument("--vllm-gpu-memory-utilization", type=float, default=0.15) return parser.parse_args() # --------------------------------------------------------------------------- # Helper functions # --------------------------------------------------------------------------- def resolve_system_prompt(path: str) -> str: prompt_path = Path(path) if not prompt_path.is_file(): prompt_path = Path(__file__).parent / path return prompt_path.read_text() def sanitize_name(name: str) -> str: return name.replace("/", "-") def is_valid_board_state(board_str: str) -> bool: """Check if the string contains an actual Sudoku board.""" return "R1" in board_str and "R9" in board_str and "|" in board_str def parse_board(board_str: str) -> list[list[int]]: """Parse board string into 9x9 grid (0 = empty).""" grid = [[0] * 9 for _ in range(9)] if not is_valid_board_state(board_str): return grid for line in board_str.split("\n"): line_stripped = line.strip() if line_stripped and line_stripped[0] == "R" and len(line_stripped) > 1 and line_stripped[1].isdigit(): row = int(line_stripped[1]) - 1 # 0-indexed cell_part = line_stripped[2:] col = 0 for char in cell_part: if char == ".": grid[row][col] = 0 col += 1 elif char.isdigit(): grid[row][col] = int(char) col += 1 return grid def count_filled_cells(board_str: str) -> int: """Count the number of filled cells in the board.""" if not is_valid_board_state(board_str): return 0 grid = parse_board(board_str) return sum(1 for row in grid for cell in row if cell != 0) def get_valid_numbers(grid: list[list[int]], row: int, col: int) -> set[int]: """Get valid numbers for a cell based on Sudoku rules.""" if grid[row][col] != 0: return set() used = set() # Check row for c in range(9): if grid[row][c] != 0: used.add(grid[row][c]) # Check column for r in range(9): if grid[r][col] != 0: used.add(grid[r][col]) # Check 3x3 box box_row, box_col = 3 * (row // 3), 3 * (col // 3) for r in range(box_row, box_row + 3): for c in range(box_col, box_col + 3): if grid[r][c] != 0: used.add(grid[r][c]) return set(range(1, 10)) - used def extract_empty_cells_with_candidates( board_str: str, sort_by_difficulty: bool = True ) -> list[tuple[int, int, set[int]]]: """Extract empty cells with their valid candidate numbers. Args: sort_by_difficulty: If True, sort by number of candidates (easiest first). If False, keep natural order (top-left to bottom-right). """ grid = parse_board(board_str) cells_with_candidates = [] for row in range(9): for col in range(9): if grid[row][col] == 0: candidates = get_valid_numbers(grid, row, col) cells_with_candidates.append((row + 1, col + 1, candidates)) # 1-indexed if sort_by_difficulty: # Sort by number of candidates (easiest first = naked singles) cells_with_candidates.sort(key=lambda x: len(x[2])) return cells_with_candidates def extract_empty_cells(board_str: str) -> list[tuple[int, int]]: """Extract list of empty cells (row, col) from board string.""" empty_cells = [] if not is_valid_board_state(board_str): return empty_cells for line in board_str.split("\n"): line_stripped = line.strip() if line_stripped and line_stripped[0] == "R" and len(line_stripped) > 1 and line_stripped[1].isdigit(): row = int(line_stripped[1]) cell_part = line_stripped[2:] col = 0 for char in cell_part: if char == ".": col += 1 empty_cells.append((row, col)) elif char.isdigit(): col += 1 return empty_cells def extract_board_only(text: str) -> str: """Extract just the Sudoku grid from a message.""" if not text: return "" lines = text.split("\n") board_lines = [] in_board = False for line in lines: stripped = line.strip() if stripped.startswith("C1") or ( stripped and stripped[0] == "R" and len(stripped) > 1 and stripped[1].isdigit() ): in_board = True if in_board and (stripped.startswith("-") or stripped.startswith("R") or stripped.startswith("C1")): board_lines.append(line) elif ( in_board and stripped and not stripped.startswith("-") and not (stripped[0] == "R" and len(stripped) > 1 and stripped[1].isdigit()) ): break return "\n".join(board_lines) if board_lines else "" # --------------------------------------------------------------------------- # Reward functions # --------------------------------------------------------------------------- def reward_empty_cell(environments, **kwargs) -> list[float]: """Reward for targeting empty cells (learn to pick valid positions first).""" return [env.empty_cell_reward for env in environments] def reward_valid_moves(environments, **kwargs) -> list[float]: """Reward for making valid moves.""" return [env.valid_move_reward for env in environments] def reward_correct(environments, **kwargs) -> list[float]: """Reward for solving the puzzle.""" return [env.correct_reward for env in environments] def reward_repetition(environments, **kwargs) -> list[float]: """Penalty for repeating moves.""" return [env.repetition_reward for env in environments] def reward_progress(environments, **kwargs) -> list[float]: """Reward for filling more cells in the board.""" return [env.progress_reward for env in environments] # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: args = parse_args() # Setup environment — all modes resolve to env_url if args.env_mode == "docker-local": env_url = f"http://{args.env_host}:{args.env_port}" elif args.env_mode == "docker-image": _bootstrap = TextArenaEnv.from_docker_image(args.env_image) env_url = _bootstrap.base_url elif args.env_mode == "docker-hub": _bootstrap = TextArenaEnv.from_hub(args.env_image) env_url = _bootstrap.base_url elif args.env_mode == "space": env_url = args.env_host else: raise ValueError(f"Unknown environment mode: {args.env_mode}") print(f"Environment: {args.env_mode} ({env_url})") system_prompt = resolve_system_prompt(args.system_prompt_path) dataset = Dataset.from_dict( { "prompt": [ [ {"role": "system", "content": system_prompt}, {"role": "user", "content": args.dataset_prompt}, ] ] * args.dataset_size } ) # Capture args for use in the environment class closure difficulty = args.difficulty max_turns = args.max_turns api_delay = args.api_delay class SudokuEnv: def __init__(self): self.client = TextArenaEnv(base_url=env_url) self._difficulty = difficulty self._max_turns = max_turns self._api_delay = api_delay self._reset_state() def _reset_state(self): self._move_counts: defaultdict[str, int] = defaultdict(int) self._successful_moves: list[str] = [] self._failed_moves: list[str] = [] self._valid_move_scores: list[float] = [] self._empty_cell_scores: list[float] = [] self._correct_scores: list[float] = [] self._repetition_scores: list[float] = [] self._last_board_state = "" self._initial_filled = 0 self._max_filled = 0 self._turn = 0 self._done = False def reset(self, **kwargs) -> str: self._reset_state() result = self.client.reset() time.sleep(self._api_delay) observation = result.observation self._done = result.done # Store full message content for diffing (messages are cumulative) self._last_full_content = observation.messages[0].content if observation.messages else "" if is_valid_board_state(self._last_full_content): self._last_board_state = self._last_full_content self._initial_filled = count_filled_cells(self._last_board_state) self._max_filled = self._initial_filled board = extract_board_only(self._last_board_state) if self._last_board_state else "No board available." hints = self._format_hints() return f"Step 0. Progress: 0 cells filled.\n\nBoard:\n{board}{hints}" def place(self, row: int, col: int, number: int) -> str: """Place a number on the Sudoku board. Args: row: Row number (1-9). col: Column number (1-9). number: Number to place (1-9). Returns: The result of the move and updated board state. """ if self._done: raise ValueError("Game is over. No more moves allowed.") self._turn += 1 move = f"[{row} {col} {number}]" # Step environment result = self.client.step(TextArenaAction(message=move)) time.sleep(self._api_delay) observation = result.observation correct_score = float(result.reward or 0.0) self._done = result.done # Only check the NEW content for feedback (messages are cumulative) full_content = observation.messages[0].content if observation.messages else "" new_content = full_content[len(self._last_full_content) :] self._last_full_content = full_content new_content_lower = new_content.lower() env_says_invalid = any( kw in new_content_lower for kw in ["invalid", "error", "cannot", "already", "violation", "lost"] ) got_warning = "please resubmit" in new_content_lower or "avoid penalties" in new_content_lower # Also verify against our own board state: placing on a non-empty cell is always invalid if self._last_board_state: empty_cells = extract_empty_cells(self._last_board_state) targets_empty = (row, col) in empty_cells else: empty_cells = [] targets_empty = True # Can't verify, assume valid is_valid = not env_says_invalid and targets_empty # Empty cell score: did the model target an empty cell? empty_cell_score = 1.0 if targets_empty else -1.0 # Repetition tracking is_new_move = self._move_counts[move] == 0 repetition_count = self._move_counts[move] self._move_counts[move] += 1 repetition_score = -min(2 ** (repetition_count - 1), 10.0) if repetition_count > 0 else 0.0 # Valid move score if is_valid and is_new_move: valid_move_score = 1.0 self._successful_moves.append(move) elif got_warning: valid_move_score = -0.5 self._failed_moves.append(move) else: valid_move_score = 0.0 # Update board state from new content if is_valid and is_valid_board_state(new_content): self._last_board_state = new_content current_filled = count_filled_cells(self._last_board_state) if current_filled > self._max_filled: self._max_filled = current_filled self._valid_move_scores.append(valid_move_score) self._empty_cell_scores.append(empty_cell_score) self._correct_scores.append(correct_score) self._repetition_scores.append(repetition_score) # Enforce max turns if self._turn >= self._max_turns: self._done = True # Build response board = extract_board_only(self._last_board_state) if self._last_board_state else "No board available." status = "valid" if is_valid else "invalid" cells_filled = len(self._successful_moves) progress = f"Step {self._turn}. Progress: {cells_filled} cells filled." hints = self._format_hints() if self._done: return f"Move {move}: {status}. Game over.\n{progress}\n\nFinal board:\n{board}" return f"Move {move}: {status}\n{progress}\n\nBoard:\n{board}{hints}" def _format_hints(self) -> str: parts = [] # Already tried moves (avoid repetitions) all_tried = self._successful_moves + self._failed_moves if all_tried: parts.append(f"\nMOVES ALREADY TRIED (do not repeat): {', '.join(all_tried)}") if not self._last_board_state: return "\n".join(parts) if self._difficulty == "easy": cells = extract_empty_cells_with_candidates(self._last_board_state, sort_by_difficulty=True) if cells: guaranteed = [] other = [] for r, c, candidates in cells[:10]: if len(candidates) == 1: guaranteed.append(f"[{r} {c} {list(candidates)[0]}]") elif len(candidates) <= 3: nums = ",".join(str(n) for n in sorted(candidates)) other.append(f"({r},{c})->{nums}") if guaranteed: parts.append(f"\nGUARANTEED MOVES: {', '.join(guaranteed[:5])}") if other: parts.append(f"Other options: {' | '.join(other[:5])}") elif self._difficulty == "medium": cells = extract_empty_cells_with_candidates(self._last_board_state, sort_by_difficulty=False) if cells: cell_hints = [] for r, c, candidates in cells[:10]: nums = ",".join(str(n) for n in sorted(candidates)) cell_hints.append(f"({r},{c})->{nums}") parts.append(f"\nEmpty cells: {' | '.join(cell_hints)}") return "\n".join(parts) # Reward properties — properties are not detected by inspect.ismethod, # so they won't be exposed as tools. @property def correct_reward(self) -> float: return self._correct_scores[-1] if self._correct_scores else 0.0 @property def valid_move_reward(self) -> float: return sum(self._valid_move_scores) / len(self._valid_move_scores) if self._valid_move_scores else 0.0 @property def empty_cell_reward(self) -> float: return sum(self._empty_cell_scores) / len(self._empty_cell_scores) if self._empty_cell_scores else 0.0 @property def repetition_reward(self) -> float: return sum(self._repetition_scores) / len(self._repetition_scores) if self._repetition_scores else 0.0 @property def progress_reward(self) -> float: remaining = 81 - self._initial_filled if remaining > 0: return (self._max_filled - self._initial_filled) / remaining return 1.0 timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") output_dir = Path(args.output_dir or f"outputs/sudoku-grpo-{sanitize_name(args.model_id)}-{timestamp}") grpo_config = GRPOConfig( use_vllm=True, vllm_mode=args.vllm_mode, vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None, vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization or 0.2, output_dir=str(output_dir), num_train_epochs=args.num_epochs, learning_rate=args.learning_rate, weight_decay=args.weight_decay, gradient_accumulation_steps=args.gradient_accumulation_steps, per_device_train_batch_size=args.per_device_batch_size, warmup_steps=args.warmup_steps, num_generations=args.num_generations, max_completion_length=args.max_completion_length, logging_steps=args.logging_steps, save_strategy="steps", save_steps=args.save_interval, save_total_limit=args.save_total_limit, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, report_to="trackio", log_completions=True, num_completions_to_print=1, chat_template_kwargs={"enable_thinking": False}, ) grpo_config.run_name = args.run_name or f"run-{timestamp}" grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}" grpo_config.trackio_space_id = args.trackio_space_id grpo_config.gradient_checkpointing = args.gradient_checkpointing peft_config = None if args.use_lora: from peft import LoraConfig peft_config = LoraConfig(r=args.lora_r, lora_alpha=args.lora_alpha, task_type="CAUSAL_LM") trainer = GRPOTrainer( model=args.model_id, reward_funcs=[ reward_empty_cell, # Learn to pick empty cells reward_valid_moves, # Learn valid numbers reward_repetition, # Penalize repeating moves reward_progress, # Reward filling more cells reward_correct, # Solve the puzzle ], peft_config=peft_config, train_dataset=dataset, args=grpo_config, environment_factory=SudokuEnv, callbacks=[RichProgressCallback()], ) print(f"Starting GRPO training: {args.num_generations} generations, {args.max_turns} max turns") trainer.train() if __name__ == "__main__": main()