|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| GRPO training for Sudoku with TextArena environment.
|
|
|
| Setup (Option A - Install from HF Space, recommended):
|
|
|
| ```sh
|
| uv pip install git+https://huggingface.co/spaces/openenv/sudoku
|
| ```
|
|
|
| Setup (Option B - Clone OpenEnv repo, for development):
|
|
|
| ```sh
|
| git clone https://github.com/meta-pytorch/OpenEnv.git
|
| cd OpenEnv/envs/textarena_env
|
| uv pip install -e .
|
| ```
|
|
|
| # Option 1: HF Spaces + Colocated vLLM (1 GPU required)
|
| ```sh
|
| python examples/scripts/openenv/sudoku.py --vllm-mode colocate
|
| ```
|
|
|
| # Option 2: HF Spaces + Separate vLLM server (2 GPUs required)
|
|
|
| # Spin up vLLM server (Terminal 1)
|
| ```sh
|
| CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000
|
| ```
|
|
|
| # Run training (Terminal 2)
|
| ```sh
|
| CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/sudoku.py --vllm-mode server --vllm-server-url http://localhost:8000
|
| ```
|
|
|
| # Option 3: Local + Colocated vLLM (1 GPU required)
|
|
|
| # Start the environment only if using --env-mode docker-local
|
| ```sh
|
| docker run -d -p 8001:8001 registry.hf.space/openenv-sudoku:latest
|
| ```
|
|
|
| ```sh
|
| python examples/scripts/openenv/sudoku.py --env-mode docker-local --env-host localhost --vllm-mode colocate
|
| ```
|
|
|
| # Full example with all flags:
|
| ```sh
|
| python examples/scripts/openenv/sudoku.py \
|
| --vllm-mode colocate \
|
| --env-mode space \
|
| --env-host https://openenv-sudoku.hf.space \
|
| --num-generations 8 \
|
| --per-device-batch-size 1 \
|
| --max-turns 100 \
|
| --gradient-accumulation-steps 8 \
|
| --difficulty easy \
|
| --dataset-size 100
|
| ```
|
| """
|
|
|
| from __future__ import annotations
|
|
|
|
|
| import argparse
|
| import sys
|
| import time
|
| from collections import defaultdict
|
| from datetime import datetime
|
| from pathlib import Path
|
|
|
| from datasets import Dataset
|
|
|
| from trl import GRPOConfig, GRPOTrainer, RichProgressCallback
|
|
|
|
|
|
|
| sys.path.insert(0, str(Path(__file__).parent / "src"))
|
|
|
| from textarena_env import TextArenaAction, TextArenaEnv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse the training options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        The populated ``argparse.Namespace``; attribute names are the flag
        names with dashes replaced by underscores (e.g. ``args.env_mode``).
    """
    parser = argparse.ArgumentParser(description="GRPO training for Sudoku")

    # Model
    parser.add_argument("--model-id", default="Qwen/Qwen3-1.7B")

    # Environment connection
    parser.add_argument("--env-host", type=str, default="https://openenv-sudoku.hf.space")
    parser.add_argument("--env-port", type=int, default=8001)
    parser.add_argument("--env-mode", choices=["docker-local", "docker-image", "docker-hub", "space"], default="space")
    parser.add_argument("--env-image", type=str, default="textarena-env:latest")

    # Prompt / dataset
    parser.add_argument("--system-prompt-path", default="sudoku_prompt.txt")
    parser.add_argument("--dataset-prompt", default="Play Sudoku like an expert.")
    parser.add_argument("--dataset-size", type=int, default=1000)

    # Game settings
    parser.add_argument("--max-turns", type=int, default=100)
    parser.add_argument(
        "--difficulty",
        type=str,
        choices=["easy", "medium", "hard"],
        default="easy",
        help="Training difficulty: easy=guaranteed+options, medium=only options, hard=no hints",
    )
    parser.add_argument(
        "--api-delay", type=float, default=0.0, help="Delay in seconds between API calls to avoid rate limiting"
    )

    # Sampling
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=10)
    parser.add_argument("--top-p", type=float, default=None, help="Top-p sampling parameter")

    # Optimisation
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=64)
    parser.add_argument("--warmup-steps", type=int, default=20)
    parser.add_argument("--per-device-batch-size", type=int, default=1)
    parser.add_argument("--num-generations", type=int, default=8)
    parser.add_argument("--num-epochs", type=int, default=1)
    parser.add_argument("--max-completion-length", type=int, default=16384)

    # Checkpointing
    parser.add_argument("--save-interval", type=int, default=10)
    parser.add_argument("--save-total-limit", type=int, default=None)
    parser.add_argument("--output-dir", default=None)

    # Logging / tracking
    parser.add_argument("--run-name", default=None)
    parser.add_argument("--project", default=None)
    parser.add_argument("--trackio-space-id", default="Sudoku-GRPO")
    parser.add_argument("--logging-steps", type=int, default=1)
    parser.add_argument(
        "--gradient-checkpointing",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Enable gradient checkpointing to save memory",
    )

    # LoRA
    parser.add_argument(
        "--use-lora", action="store_true", default=False, help="Use LoRA for memory-efficient training"
    )
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")

    # vLLM
    parser.add_argument("--vllm-mode", choices=("colocate", "server"), default="colocate")
    parser.add_argument("--vllm-server-url", type=str, default="http://localhost:8000")
    parser.add_argument("--vllm-gpu-memory-utilization", type=float, default=0.15)

    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_system_prompt(path: str) -> str:
    """Read the system prompt text from ``path``.

    The path is used as given when it points to an existing file; otherwise it
    is re-resolved relative to the directory that contains this script.

    Args:
        path: Absolute path, or a path relative to the current working
            directory or to this script's directory.

    Returns:
        The full contents of the prompt file.

    Raises:
        OSError: Propagated from ``Path.read_text`` when the fallback location
            cannot be read either (e.g. ``FileNotFoundError``).
    """
    prompt_path = Path(path)
    if not prompt_path.is_file():
        prompt_path = Path(__file__).parent / path
    # Decode explicitly as UTF-8 so the prompt text does not depend on the
    # platform's locale-default encoding.
    return prompt_path.read_text(encoding="utf-8")
|
|
|
|
|
def sanitize_name(name: str) -> str:
    """Return ``name`` with every ``/`` turned into ``-``.

    Used to embed identifiers such as ``org/model`` in directory and project
    names.
    """
    pieces = name.split("/")
    return "-".join(pieces)
|
|
|
|
|
def is_valid_board_state(board_str: str) -> bool:
    """Tell whether ``board_str`` looks like it contains a rendered board.

    The check is a cheap heuristic: the text must mention the first and last
    row labels (``R1`` and ``R9``) and contain at least one ``|`` separator.
    """
    required_markers = ("R1", "R9", "|")
    return all(marker in board_str for marker in required_markers)
|
|
|
|
|
def parse_board(board_str: str) -> list[list[int]]:
    """Parse a rendered board into a 9x9 grid of ints (0 = empty).

    Only lines whose stripped text starts with ``R`` followed by a digit are
    read. On such a line every ``.`` is an empty cell and every digit is a
    filled cell; all other characters (spaces, ``|``) are ignored. If the same
    row label appears more than once, the later line overwrites the earlier.

    Args:
        board_str: Text that may contain a rendered board.

    Returns:
        A 9x9 list of lists. An all-zero grid is returned when ``board_str``
        does not pass ``is_valid_board_state``.
    """
    grid = [[0] * 9 for _ in range(9)]
    if not is_valid_board_state(board_str):
        return grid

    for line in board_str.split("\n"):
        text = line.strip()
        if len(text) < 2 or text[0] != "R" or not text[1].isdigit():
            continue

        row = int(text[1]) - 1
        if not 0 <= row < 9:
            # "R0" would otherwise index row -1 and silently clobber row 9.
            continue

        col = 0
        for char in text[2:]:
            if col >= 9:
                # Malformed line with extra cells: ignore the overflow instead
                # of raising IndexError.
                break
            if char == ".":
                grid[row][col] = 0
                col += 1
            elif char.isdigit():
                grid[row][col] = int(char)
                col += 1
    return grid
|
|
|
|
|
def count_filled_cells(board_str: str) -> int:
    """Return how many cells of the rendered board hold a non-zero digit.

    Text that does not pass ``is_valid_board_state`` counts as zero filled
    cells.
    """
    if is_valid_board_state(board_str):
        filled = 0
        for cells in parse_board(board_str):
            for value in cells:
                if value != 0:
                    filled += 1
        return filled
    return 0
|
|
|
|
|
def get_valid_numbers(grid: list[list[int]], row: int, col: int) -> set[int]:
    """Return the digits 1-9 that do not clash with the cell's row, column or box.

    Args:
        grid: 9x9 grid where 0 marks an empty cell.
        row: Zero-based row index of the cell.
        col: Zero-based column index of the cell.

    Returns:
        The set of candidate digits, or an empty set when the cell is already
        filled.
    """
    if grid[row][col] != 0:
        return set()

    # Top-left corner of the 3x3 box that contains (row, col).
    top = 3 * (row // 3)
    left = 3 * (col // 3)

    # Zeros may end up in `taken`; they are harmless because only 1-9 are
    # subtracted from below.
    taken = {grid[row][c] for c in range(9)}
    taken.update(grid[r][col] for r in range(9))
    taken.update(grid[r][c] for r in range(top, top + 3) for c in range(left, left + 3))

    return set(range(1, 10)) - taken
|
|
|
|
|
def extract_empty_cells_with_candidates(
    board_str: str, sort_by_difficulty: bool = True
) -> list[tuple[int, int, set[int]]]:
    """List every empty cell together with the digits that could go there.

    Args:
        board_str: Rendered board text; it is parsed with ``parse_board``.
        sort_by_difficulty: If True, order cells by ascending candidate count
            (fewest options first). If False, keep row-major order
            (top-left to bottom-right).

    Returns:
        ``(row, col, candidates)`` tuples with 1-based row and column numbers.
    """
    grid = parse_board(board_str)

    found = [
        (r + 1, c + 1, get_valid_numbers(grid, r, c))
        for r in range(9)
        for c in range(9)
        if grid[r][c] == 0
    ]

    if sort_by_difficulty:
        # list.sort is stable, so ties keep their row-major order.
        found.sort(key=lambda entry: len(entry[2]))

    return found
|
|
|
|
|
def extract_empty_cells(board_str: str) -> list[tuple[int, int]]:
    """Return the 1-based ``(row, col)`` position of each ``.`` on the board.

    Only lines whose stripped text starts with ``R`` followed by a digit are
    scanned; on those lines each ``.`` or digit occupies one column. Text that
    does not pass ``is_valid_board_state`` yields an empty list.
    """
    positions = []
    if not is_valid_board_state(board_str):
        return positions

    for raw_line in board_str.split("\n"):
        text = raw_line.strip()
        if len(text) < 2 or text[0] != "R" or not text[1].isdigit():
            continue

        row_number = int(text[1])
        # Keep only the characters that represent cells, then number them from 1.
        cell_chars = [ch for ch in text[2:] if ch == "." or ch.isdigit()]
        for col_number, ch in enumerate(cell_chars, start=1):
            if ch == ".":
                positions.append((row_number, col_number))
    return positions
|
|
|
|
|
def extract_board_only(text: str) -> str:
    """Extract just the Sudoku grid lines from a longer message.

    The board starts at the first column-header line (stripped text starts
    with ``C1``) or row line (``R`` followed by a digit). From there, header
    lines, row lines and separator lines (starting with ``-``) are collected
    with their original whitespace; blank lines are skipped; the first other
    non-blank line ends the board.

    Args:
        text: Message that may contain a rendered board.

    Returns:
        The collected board lines joined with newlines, or ``""`` when the
        input is empty or contains no board.
    """
    if not text:
        return ""

    board_lines = []
    in_board = False

    for line in text.split("\n"):
        stripped = line.strip()
        is_header = stripped.startswith("C1")
        is_row = len(stripped) > 1 and stripped[0] == "R" and stripped[1].isdigit()

        if is_header or is_row:
            in_board = True
        if not in_board:
            continue

        if is_header or is_row or stripped.startswith("-"):
            board_lines.append(line)
        elif stripped:
            # Any other non-blank line marks the end of the grid. Requiring a
            # real row label here (not just a leading "R") keeps prose such as
            # "Remember: ..." out of the extracted board.
            break

    return "\n".join(board_lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reward_empty_cell(environments, **kwargs) -> list[float]:
    """Reward for targeting empty cells (learn to pick valid positions first).

    Args:
        environments: Iterable of objects exposing ``empty_cell_reward``.
        **kwargs: Accepted and ignored.

    Returns:
        One ``empty_cell_reward`` value per environment, in input order.
    """
    return [rollout_env.empty_cell_reward for rollout_env in environments]
|
|
|
|
|
def reward_valid_moves(environments, **kwargs) -> list[float]:
    """Reward for making valid moves.

    Args:
        environments: Iterable of objects exposing ``valid_move_reward``.
        **kwargs: Accepted and ignored.

    Returns:
        One ``valid_move_reward`` value per environment, in input order.
    """
    return [rollout_env.valid_move_reward for rollout_env in environments]
|
|
|
|
|
def reward_correct(environments, **kwargs) -> list[float]:
    """Reward for solving the puzzle.

    Args:
        environments: Iterable of objects exposing ``correct_reward``.
        **kwargs: Accepted and ignored.

    Returns:
        One ``correct_reward`` value per environment, in input order.
    """
    return [rollout_env.correct_reward for rollout_env in environments]
|
|
|
|
|
def reward_repetition(environments, **kwargs) -> list[float]:
    """Penalty for repeating moves.

    Args:
        environments: Iterable of objects exposing ``repetition_reward``.
        **kwargs: Accepted and ignored.

    Returns:
        One ``repetition_reward`` value per environment, in input order.
    """
    return [rollout_env.repetition_reward for rollout_env in environments]
|
|
|
|
|
def reward_progress(environments, **kwargs) -> list[float]:
    """Reward for filling more cells in the board.

    Args:
        environments: Iterable of objects exposing ``progress_reward``.
        **kwargs: Accepted and ignored.

    Returns:
        One ``progress_reward`` value per environment, in input order.
    """
    return [rollout_env.progress_reward for rollout_env in environments]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main() -> None:
    """Run GRPO training on Sudoku against a TextArena environment.

    Parses the CLI options, resolves the environment URL for the selected
    ``--env-mode``, builds a dataset that repeats one system+user prompt
    ``--dataset-size`` times, defines the per-rollout ``SudokuEnv`` wrapper,
    then configures ``GRPOTrainer`` and starts training.

    Raises:
        ValueError: If ``args.env_mode`` is not one of the handled modes.
    """
    args = parse_args()

    # Resolve the base URL of the environment server.
    if args.env_mode == "docker-local":
        # --env-host may be a bare host ("localhost") or a full URL (the CLI
        # default carries "https://"). Only prepend a scheme when none is
        # present so we never build "http://https://...".
        host = args.env_host.rstrip("/")
        if "://" not in host:
            host = f"http://{host}"
        env_url = f"{host}:{args.env_port}"
    elif args.env_mode == "docker-image":
        _bootstrap = TextArenaEnv.from_docker_image(args.env_image)
        env_url = _bootstrap.base_url
    elif args.env_mode == "docker-hub":
        _bootstrap = TextArenaEnv.from_hub(args.env_image)
        env_url = _bootstrap.base_url
    elif args.env_mode == "space":
        env_url = args.env_host
    else:
        raise ValueError(f"Unknown environment mode: {args.env_mode}")

    print(f"Environment: {args.env_mode} ({env_url})")

    # Every dataset row is the same conversation; the puzzle itself comes from
    # the environment at rollout time.
    system_prompt = resolve_system_prompt(args.system_prompt_path)
    dataset = Dataset.from_dict(
        {
            "prompt": [
                [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": args.dataset_prompt},
                ]
            ]
            * args.dataset_size
        }
    )

    # Captured by SudokuEnv below, which is constructed without arguments.
    difficulty = args.difficulty
    max_turns = args.max_turns
    api_delay = args.api_delay

    class SudokuEnv:
        # One Sudoku episode backed by a TextArenaEnv client. It records
        # per-move scores that the module-level reward_* functions read through
        # the *_reward properties.
        # NOTE(review): presumably the trainer exposes public methods such as
        # `place` (and their docstrings) to the model as tools; keep helper
        # methods underscore-prefixed and confirm against the TRL docs before
        # editing the `place` docstring.

        def __init__(self):
            self.client = TextArenaEnv(base_url=env_url)
            self._difficulty = difficulty
            self._max_turns = max_turns
            self._api_delay = api_delay
            self._reset_state()

        def _reset_state(self):
            # Per-episode bookkeeping; called from __init__ and from reset().
            self._move_counts: defaultdict[str, int] = defaultdict(int)
            self._successful_moves: list[str] = []
            self._failed_moves: list[str] = []
            self._valid_move_scores: list[float] = []
            self._empty_cell_scores: list[float] = []
            self._correct_scores: list[float] = []
            self._repetition_scores: list[float] = []
            self._last_board_state = ""
            # Initialised here (not only in reset()) so that place() cannot hit
            # an AttributeError if it runs before the first reset().
            self._last_full_content = ""
            self._initial_filled = 0
            self._max_filled = 0
            self._turn = 0
            self._done = False

        def reset(self, **kwargs) -> str:
            self._reset_state()
            result = self.client.reset()
            time.sleep(self._api_delay)
            observation = result.observation
            self._done = result.done

            # Remember the full transcript so place() can isolate what each
            # step appended to it.
            self._last_full_content = observation.messages[0].content if observation.messages else ""

            if is_valid_board_state(self._last_full_content):
                self._last_board_state = self._last_full_content
                self._initial_filled = count_filled_cells(self._last_board_state)
                self._max_filled = self._initial_filled

            board = extract_board_only(self._last_board_state) if self._last_board_state else "No board available."
            hints = self._format_hints()
            return f"Step 0. Progress: 0 cells filled.\n\nBoard:\n{board}{hints}"

        def place(self, row: int, col: int, number: int) -> str:
            """Place a number on the Sudoku board.

            Args:
                row: Row number (1-9).
                col: Column number (1-9).
                number: Number to place (1-9).

            Returns:
                The result of the move and updated board state.
            """
            if self._done:
                raise ValueError("Game is over. No more moves allowed.")

            self._turn += 1
            move = f"[{row} {col} {number}]"

            # Send the move to the environment.
            result = self.client.step(TextArenaAction(message=move))
            time.sleep(self._api_delay)
            observation = result.observation
            correct_score = float(result.reward or 0.0)
            self._done = result.done

            # Keep only the text this step added. The slice is meaningful only
            # when the new content extends the previous content; otherwise
            # treat the whole message as new.
            full_content = observation.messages[0].content if observation.messages else ""
            if full_content.startswith(self._last_full_content):
                new_content = full_content[len(self._last_full_content) :]
            else:
                new_content = full_content
            self._last_full_content = full_content

            # Keyword heuristics on the environment's reply.
            new_content_lower = new_content.lower()
            env_says_invalid = any(
                kw in new_content_lower for kw in ["invalid", "error", "cannot", "already", "violation", "lost"]
            )
            got_warning = "please resubmit" in new_content_lower or "avoid penalties" in new_content_lower

            # Did the move target a cell that was empty on the last known
            # board? With no board to check against, give the benefit of the
            # doubt.
            if self._last_board_state:
                empty_cells = extract_empty_cells(self._last_board_state)
                targets_empty = (row, col) in empty_cells
            else:
                empty_cells = []
                targets_empty = True

            is_valid = not env_says_invalid and targets_empty

            empty_cell_score = 1.0 if targets_empty else -1.0

            # Repetition penalty: 0 for the first attempt, then -1, -2, -4, -8,
            # capped at -10.
            is_new_move = self._move_counts[move] == 0
            repetition_count = self._move_counts[move]
            self._move_counts[move] += 1
            repetition_score = -min(2 ** (repetition_count - 1), 10.0) if repetition_count > 0 else 0.0

            if is_valid and is_new_move:
                valid_move_score = 1.0
                self._successful_moves.append(move)
            elif got_warning:
                valid_move_score = -0.5
                self._failed_moves.append(move)
            else:
                valid_move_score = 0.0

            # Adopt the newly printed board only after a valid move.
            if is_valid and is_valid_board_state(new_content):
                self._last_board_state = new_content
                current_filled = count_filled_cells(self._last_board_state)
                if current_filled > self._max_filled:
                    self._max_filled = current_filled

            self._valid_move_scores.append(valid_move_score)
            self._empty_cell_scores.append(empty_cell_score)
            self._correct_scores.append(correct_score)
            self._repetition_scores.append(repetition_score)

            # Enforce the turn budget locally.
            if self._turn >= self._max_turns:
                self._done = True

            board = extract_board_only(self._last_board_state) if self._last_board_state else "No board available."
            status = "valid" if is_valid else "invalid"
            cells_filled = len(self._successful_moves)
            progress = f"Step {self._turn}. Progress: {cells_filled} cells filled."
            hints = self._format_hints()

            if self._done:
                return f"Move {move}: {status}. Game over.\n{progress}\n\nFinal board:\n{board}"
            return f"Move {move}: {status}\n{progress}\n\nBoard:\n{board}{hints}"

        def _format_hints(self) -> str:
            # Build the hint text appended after the board; its content depends
            # on the difficulty ("hard" adds no candidate hints).
            parts = []

            all_tried = self._successful_moves + self._failed_moves
            if all_tried:
                parts.append(f"\nMOVES ALREADY TRIED (do not repeat): {', '.join(all_tried)}")

            if not self._last_board_state:
                return "\n".join(parts)

            if self._difficulty == "easy":
                cells = extract_empty_cells_with_candidates(self._last_board_state, sort_by_difficulty=True)
                if cells:
                    guaranteed = []
                    other = []
                    for r, c, candidates in cells[:10]:
                        if len(candidates) == 1:
                            guaranteed.append(f"[{r} {c} {list(candidates)[0]}]")
                        elif len(candidates) <= 3:
                            nums = ",".join(str(n) for n in sorted(candidates))
                            other.append(f"({r},{c})->{nums}")
                    if guaranteed:
                        parts.append(f"\nGUARANTEED MOVES: {', '.join(guaranteed[:5])}")
                    if other:
                        parts.append(f"Other options: {' | '.join(other[:5])}")

            elif self._difficulty == "medium":
                cells = extract_empty_cells_with_candidates(self._last_board_state, sort_by_difficulty=False)
                if cells:
                    cell_hints = []
                    for r, c, candidates in cells[:10]:
                        nums = ",".join(str(n) for n in sorted(candidates))
                        cell_hints.append(f"({r},{c})->{nums}")
                    parts.append(f"\nEmpty cells: {' | '.join(cell_hints)}")

            return "\n".join(parts)

        # Episode-level rewards read by the module-level reward_* functions.

        @property
        def correct_reward(self) -> float:
            # Environment reward of the most recent step (0.0 before any move).
            return self._correct_scores[-1] if self._correct_scores else 0.0

        @property
        def valid_move_reward(self) -> float:
            return sum(self._valid_move_scores) / len(self._valid_move_scores) if self._valid_move_scores else 0.0

        @property
        def empty_cell_reward(self) -> float:
            return sum(self._empty_cell_scores) / len(self._empty_cell_scores) if self._empty_cell_scores else 0.0

        @property
        def repetition_reward(self) -> float:
            return sum(self._repetition_scores) / len(self._repetition_scores) if self._repetition_scores else 0.0

        @property
        def progress_reward(self) -> float:
            # Fraction of the initially empty cells that were filled at the
            # episode's best point.
            remaining = 81 - self._initial_filled
            if remaining > 0:
                return (self._max_filled - self._initial_filled) / remaining
            return 1.0

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_dir = Path(args.output_dir or f"outputs/sudoku-grpo-{sanitize_name(args.model_id)}-{timestamp}")

    grpo_config = GRPOConfig(
        use_vllm=True,
        vllm_mode=args.vllm_mode,
        vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
        vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization or 0.2,
        output_dir=str(output_dir),
        num_train_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        per_device_train_batch_size=args.per_device_batch_size,
        warmup_steps=args.warmup_steps,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        logging_steps=args.logging_steps,
        save_strategy="steps",
        save_steps=args.save_interval,
        save_total_limit=args.save_total_limit,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        report_to="trackio",
        log_completions=True,
        num_completions_to_print=1,
        chat_template_kwargs={"enable_thinking": False},
    )

    grpo_config.run_name = args.run_name or f"run-{timestamp}"
    grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}"
    grpo_config.trackio_space_id = args.trackio_space_id
    grpo_config.gradient_checkpointing = args.gradient_checkpointing

    peft_config = None
    if args.use_lora:
        # Imported lazily so peft is only needed when LoRA is requested.
        from peft import LoraConfig

        peft_config = LoraConfig(r=args.lora_r, lora_alpha=args.lora_alpha, task_type="CAUSAL_LM")

    trainer = GRPOTrainer(
        model=args.model_id,
        reward_funcs=[
            reward_empty_cell,
            reward_valid_moves,
            reward_repetition,
            reward_progress,
            reward_correct,
        ],
        peft_config=peft_config,
        train_dataset=dataset,
        args=grpo_config,
        environment_factory=SudokuEnv,
        callbacks=[RichProgressCallback()],
    )

    print(f"Starting GRPO training: {args.num_generations} generations, {args.max_turns} max turns")
    trainer.train()
|
|
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|