""" Reference Implementation: OpenEnv + GRPO Reinforcement Learning for 2048 Game ============================================================================== Extracted from the Unsloth / OpenEnv notebook: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/ OpenEnv_gpt_oss_(20B)_Reinforcement_Learning_2048_Game.ipynb This file contains ALL code cells from the notebook, organized into sections. It serves as a reference for how to build an OpenEnv-based RL environment and connect it to GRPO training via TRL. KEY ARCHITECTURE: 1. OpenEnv provides a server-based game environment (2048 via OpenSpiel) 2. The LLM generates a Python *strategy function* (code-as-action) 3. The strategy function is executed against the environment 4. Three reward functions score the output: - function_works: Does the generated code parse and compile? - no_cheating: Does it only use stdlib imports? - strategy_succeeds: Does the strategy actually play the game well? 5. GRPO (from TRL) uses these rewards to train the model PROMPT/RESPONSE FORMAT: - Prompt asks the LLM to write a Python function `strategy(board)` that takes a list-of-lists board state and returns "0","1","2","3" (up/right/down/left) - Response is wrapped in ```python ... 
``` backticks - The function is extracted, sandboxed, and executed against the live game REWARD STRUCTURE: - function_works: +1.0 if valid Python, -2.0 if no function / syntax error, -0.5 if exec fails - no_cheating: +1.0 if only stdlib imports, -20.0 if non-stdlib imports, -1.0 if no function - strategy_succeeds: +20.0 if reaches 2048, +2.0 if function runs but doesn't win, -1.0 on timeout, -3.0 on exception, 0 if function broken """ # ============================================================================= # CELL 1: Installation (pip installs - shown for reference, not executable here) # ============================================================================= """ %%capture import os, importlib.util !pip install --upgrade -qqq uv if importlib.util.find_spec("torch") is None or "COLAB_" in "".join(os.environ.keys()): try: import numpy; get_numpy = f"numpy=={numpy.__version__}" except: get_numpy = "numpy" !uv pip install -qqq \\ "torch>=2.8.0" "triton>=3.4.0" {get_numpy} torchvision bitsandbytes "transformers==4.56.2" trackio \\ "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \\ "unsloth[base] @ git+https://github.com/unslothai/unsloth" \\ git+https://github.com/triton-lang/triton.git@0add68262ab0a2e33b84524346cb27cbb2787356#subdirectory=python/triton_kernels elif importlib.util.find_spec("unsloth") is None: !uv pip install -qqq unsloth trackio !uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo """ # ============================================================================= # CELL 2: Install OpenEnv from source # ============================================================================= """ %%capture !pip install -qqq fastapi uvicorn requests open_spiel !pip install fastapi uvicorn requests !pip install open_spiel --prefer-binary !git clone https://github.com/meta-pytorch/OpenEnv.git > /dev/null 2>&1 %cd OpenEnv !git checkout 83dda10 """ import subprocess, sys, os from pathlib import 
Path # sys.path.insert(0, '.') # Add OpenEnv root for envs module # sys.path.insert(0, './src') # working_directory = str(Path.cwd().parent.absolute() / "OpenEnv") # ============================================================================= # CELL 3: Load the model with Unsloth # ============================================================================= import os from unsloth import FastLanguageModel import torch max_seq_length = 768 # Can increase for longer RL output lora_rank = 4 # Larger rank = smarter, but slower model, tokenizer = FastLanguageModel.from_pretrained( model_name = "unsloth/gpt-oss-20b", load_in_4bit = True, max_seq_length = max_seq_length, offload_embedding = True, # Offload embeddings to save more VRAM ) # ============================================================================= # CELL 4: Apply LoRA adapters # ============================================================================= model = FastLanguageModel.get_peft_model( model, r = lora_rank, # Choose any number > 0 ! 
Suggested 8, 16, 32, 64, 128 target_modules = [ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], lora_alpha = lora_rank * 2, # *2 speeds up training use_gradient_checkpointing = "unsloth", # Reduces memory usage random_state = 3407, ) # ============================================================================= # CELL 5: OpenEnv imports (environment-specific) # ============================================================================= from envs.openspiel_env import OpenSpielEnv from envs.openspiel_env.models import OpenSpielAction, OpenSpielObservation # ============================================================================= # CELL 6: OpenEnv process launch configuration # ============================================================================= global port global openenv_process port = 9000 openenv_process = None server = "envs.openspiel_env.server.app:app" environment = { **os.environ, "PYTHONPATH": f"{working_directory}/src", "OPENSPIEL_GAME": "2048", "OPENSPIEL_AGENT_PLAYER": "0", "OPENSPIEL_OPPONENT_POLICY": "random", } # Augment Unsloth's OpenEnv creation function import functools from unsloth import is_port_open, launch_openenv launch_openenv = functools.partial( launch_openenv, working_directory = working_directory, server = server, environment = environment, openenv_class = OpenSpielEnv, ) # ============================================================================= # CELL 7: Reset the environment and observe initial state # ============================================================================= port, openenv_process = launch_openenv(port, openenv_process) result = openenv_process.reset() current_state = result.observation # current_state is an OpenSpielObservation with: # .done -> bool # .reward -> float or None # .info_state -> list of floats (flat board) # .legal_actions -> list of ints (e.g. 
#   .game_phase        -> str
#   .current_player_id -> int

# =============================================================================
# CELL 8: Convert flat info_state to 2D board
# =============================================================================
import numpy as np


def convert_to_board(current_state):
    """Reshape an observation's flat ``info_state`` into rows of a square board.

    Args:
        current_state: an observation exposing ``info_state`` (a flat sequence
            of numbers).

    Returns:
        ``(board, size)``: ``size`` is ``int(sqrt(len(info_state)))`` and
        ``board`` is a list of ``size`` row lists of Python ints.
    """
    n = len(current_state.info_state)
    size = int(np.sqrt(n))
    board = np.array_split(np.array(current_state.info_state, dtype=int), size)
    board = [x.tolist() for x in board]
    return board, size


# =============================================================================
# CELL 9: Pretty-print the 2048 board (collapsible in notebook)
# =============================================================================
def render_board(obs, colors: bool = True, border: bool = True,
                 dot_for_zero: bool = True, target: int = 2048) -> str:
    """
    Pretty-print the board of observation *obs* as a (optionally boxed) grid.

    Tile colours scale with log2(value) from 0 up to *target* (default 2048)
    using ANSI 256-color codes (works in most terminals). Set colors=False to
    disable colouring, border=False to drop the box-drawing frame, and
    dot_for_zero=False to print empty cells as "0" instead of ".".
    """
    import math

    b, size = convert_to_board(obs)
    mx = max((max(row) for row in b), default=0)
    cell_w = max(3, len(str(mx)))
    RESET = "\x1b[0m"
    # A smooth-ish gradient from cool -> warm
    GRAD = [33, 39, 45, 51, 50, 49, 48, 47, 46, 82, 118, 154, 190, 226,
            220, 214, 208, 202, 196]
    ZERO_FG = 239  # dim gray
    # Clamp to >= 2 so the denominator log2(target) is never 0.
    log_target = math.log2(max(2, target))

    def color_code(v: int) -> str:
        if not colors:
            return ""
        if v == 0:
            return f"\x1b[38;5;{ZERO_FG}m"
        try:
            r = max(0.0, min(1.0, math.log2(v) / log_target))
        except ValueError:  # log2 of a negative value
            r = 0.0
        idx = int(round(r * (len(GRAD) - 1)))
        return f"\x1b[38;5;{GRAD[idx]}m"

    def fmt(v: int) -> str:
        s = "." if (v == 0 and dot_for_zero) else str(v)
        s = s.rjust(cell_w)
        return color_code(v) + s + (RESET if colors else "")

    def hline(left: str, mid: str, right: str) -> str:
        return left + mid.join("\u2500" * cell_w for _ in range(size)) + right

    rows = []
    if border:
        rows.append(hline("\u250c", "\u252c", "\u2510"))
    for r in range(size):
        content = "\u2502".join(fmt(v) for v in b[r])
        rows.append(("\u2502" + content + "\u2502") if border else content)
        if border:
            last = r == size - 1
            rows.append(hline("\u2514" if last else "\u251c",
                              "\u2534" if last else "\u253c",
                              "\u2518" if last else "\u2524"))
    return "\n".join(rows)


# =============================================================================
# CELL 10: Demonstrate stepping through the environment
# =============================================================================
# Action mapping: 0 = up, 1 = right, 2 = down, 3 = left
action = OpenSpielAction(action_id=0, game_name="2048")
result = openenv_process.step(action)
current_state = result.observation
print(render_board(current_state))

# =============================================================================
# CELL 11: RL Environment - Strategy execution with time limit
# =============================================================================
from typing import Callable
from unsloth import execute_with_time_limit
import itertools


def _max_tile(state) -> int:
    """Return the largest value on *state*'s board (0 for an empty board)."""
    board, _ = convert_to_board(state)
    return max(itertools.chain.from_iterable(board), default=0)


def _execute_strategy(strategy, current_state: OpenSpielObservation):
    """
    Execute a strategy function against the 2048 environment.

    The strategy receives the board (list of lists of ints) and returns an
    action convertible with ``int()``. Play continues until the game is done,
    the strategy returns something non-numeric, or it returns an action that
    is not in ``legal_actions``.

    Returns:
        ``(steps, reached_2048)``. ``steps`` counts the numeric actions taken,
        including a final illegal one. ``reached_2048`` is judged on the
        *final* observation, so a move that creates the 2048 tile (and ends
        the game) is counted. A non-numeric action always yields
        ``(steps, False)``.

    Raises:
        TypeError: if *strategy* is not callable.
        Any exception raised by *strategy* itself is propagated.
    """
    global port, openenv_process
    if not callable(strategy):
        raise TypeError("strategy must be callable")
    steps = 0
    while not current_state.done:
        board, _ = convert_to_board(current_state)
        action = strategy(board)
        try:
            action = int(action)
        except (TypeError, ValueError, OverflowError):
            # Narrow on purpose: a bare except here could also swallow the
            # TimeoutError raised by the time-limit wrapper.
            return steps, False
        steps += 1
        if action not in current_state.legal_actions:
            break  # illegal move ends the episode; judge the current board
        # Re-launch the server if needed (returns the live port/process).
        port, openenv_process = launch_openenv(port, openenv_process)
        result = openenv_process.step(
            OpenSpielAction(action_id=action, game_name="2048"))
        current_state = result.observation
    # Inspect the final state, not the board captured before the last step.
    # This also covers an already-finished state, where the loop never runs.
    return steps, _max_tile(current_state) >= 2048


# Time-limited wrapper (2 seconds default, later changed to 5)
@execute_with_time_limit(2)
def execute_strategy(strategy: Callable, current_state: OpenSpielObservation):
    """Run :func:`_execute_strategy` under a 2-second time limit."""
    return _execute_strategy(strategy, current_state)


# =============================================================================
# CELL 12: Test with a trivial strategy
# =============================================================================
def always_move_left(board):
    """Baseline strategy: always return action 3 (left)."""
    return 3

# Reset OpenEnv to an initial state!
port, openenv_process = launch_openenv(port, openenv_process)
result = openenv_process.reset()
current_state = result.observation
try:
    steps, if_done = execute_strategy(always_move_left, current_state)
except TimeoutError as e:
    print(f"Timed out with error = {str(e)}")
else:
    # Only reachable when the run finished: on timeout `steps`/`if_done` are
    # never assigned, and printing them would raise NameError.
    print(f"steps={steps}, if_done={if_done}")

# =============================================================================
# CELL 13: Extend time limit to 5 seconds for actual RL training
# =============================================================================
@execute_with_time_limit(5)
def execute_strategy(strategy: Callable, current_state: OpenSpielObservation):
    """Run ``_execute_strategy`` under a 5-second time limit (training setting)."""
    return _execute_strategy(strategy, current_state)

# =============================================================================
# CELL 14: Code safety - check_python_modules (anti-reward-hacking)
# =============================================================================
from unsloth import check_python_modules

# Example: allowed (only stdlib)
sample_ok = """
def strategy(board):
    import math
    from typing import Callable
    return "0"
"""
ok, info = check_python_modules(sample_ok)
print("Only Python imports?", ok)  # True
print(info)

# Example: disallowed (numpy is non-stdlib)
sample_bad = """
def strategy(board):
    from numpy import matmul
    return "0"
"""
ok, info = check_python_modules(sample_bad)
print("Only Python imports?", ok)  # False
print(info)

# =============================================================================
# CELL 15: Sandboxed function execution (no global variable leakage)
# =============================================================================
from unsloth import create_locked_down_function

# This will fail - np is not defined inside the sandbox
function_bad = """
def import_numpy():
    np.matmul
    print("Success")
"""
f = create_locked_down_function(function_bad)
try:
    f()
except Exception as e:
    print(str(e))  # "name 'np' is not defined"

# This will work - no external references
function_good = """
def add(a, b):
    def adder(a):
        return a + b
    return adder(b) + b
"""
f = create_locked_down_function(function_good)
try:
    print(f(10, 20))  # 60
except Exception as e:
    print(str(e))

# =============================================================================
# CELL 16: THE PROMPT - How the LLM interacts with the environment
# =============================================================================
prompt = """
Create a new short 2048 strategy using only native Python code.
You are given a list of list of numbers for the current board state.
Output one action for "0", "1", "2", "3" on what is the optimal next step.
Output your new short function in backticks using the format below:
```python
def strategy(board):
    return "0" # Example
```
All helper functions should be inside def strategy. Only output the short function `strategy`.
""".strip()
print(prompt)

# =============================================================================
# CELL 17: Test inference before RL training
# =============================================================================
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize = False,
    add_generation_prompt = True,
    reasoning_effort = "low",
)

from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors="pt").to("cuda"),
    temperature = 1.0,
    max_new_tokens = 512,
    streamer = TextStreamer(tokenizer, skip_prompt=False),
)

# =============================================================================
# CELL 18: REWARD FUNCTION 1 - extract_function (helper)
# =============================================================================
def extract_function(text):
    """
    Extract the ``strategy`` function from the first fenced code block.

    Takes the text between the first two ``` fences, drops an optional leading
    ``python`` language tag, and trims everything before the first ``def``.

    Returns:
        The function source if it starts with ``def strategy(board):``,
        otherwise ``None`` (also when there is no fenced block or no ``def``).
    """
    if text.count("```") < 2:
        return None
    first = text.find("```") + 3
    second = text.find("```", first)
    fx = text[first:second].strip()
    fx = fx.removeprefix("python\n")
    def_at = fx.find("def")
    if def_at == -1:
        return None
    fx = fx[def_at:]
    if fx.startswith("def strategy(board):"):
        return fx
    return None

# =============================================================================
# CELL 19: REWARD FUNCTION 2 - function_works
# =============================================================================
def function_works(completions, **kwargs):
    """
    Reward: Does the generated code parse as valid Python and compile?

    +1.0 if a valid function can be created
    -0.5 if it has the right structure but creating it (exec) fails
    -2.0 if no function is extracted or ``check_python_modules`` reports an
         error (presumably a syntax error)

    Args:
        completions: TRL-style completions; ``completion[0]["content"]`` is
            the model's response text.

    Returns:
        One float score per completion.
    """
    scores = []
    for completion in completions:
        function = extract_function(completion[0]["content"])
        if function is None:
            scores.append(-2.0)
            continue
        _, info = check_python_modules(function)
        if "error" in info:
            scores.append(-2.0)
            continue
        try:
            create_locked_down_function(function)
        except (Exception, SystemExit):
            # Untrusted model code: a failure (even a sys.exit() during
            # definition) must only cost reward, never abort training. Ctrl-C
            # (KeyboardInterrupt) is deliberately not caught.
            scores.append(-0.5)
        else:
            scores.append(1.0)
    return scores

# =============================================================================
# CELL 20: REWARD FUNCTION 3 - no_cheating
# =============================================================================
def no_cheating(completions, **kwargs):
    """
    Reward: Does the function only use Python stdlib imports?

    +1.0 if only stdlib imports
    -20.0 if non-stdlib imports (heavy penalty!)
    -1.0 if function extraction failed

    Returns:
        One float score per completion.
    """
    scores = []
    for completion in completions:
        function = extract_function(completion[0]["content"])
        if function is None:
            scores.append(-1.0)  # Failed creating function
            continue
        ok, _ = check_python_modules(function)
        scores.append(1.0 if ok else -20.0)  # Penalize heavily!
    return scores

# =============================================================================
# CELL 21: REWARD FUNCTION 4 - strategy_succeeds
# =============================================================================
import numpy as np

# Number of completions seen so far; every 5th one is echoed before it runs.
PRINTER = 0


def strategy_succeeds(completions, **kwargs):
    """
    Reward: Does the strategy actually play 2048 successfully?

    +20.0 if the strategy reaches 2048 (massive reward!)
    +2.0 if the function runs and plays but doesn't reach 2048
    -1.0 if timeout (strategy takes too long)
    -3.0 if exception during execution
    0 if function is broken/can't be created

    Side effects: resets the OpenEnv game (re-launching the server if needed)
    for each runnable completion and prints progress to stdout.

    Returns:
        One float score per completion.
    """
    global PRINTER, port, openenv_process
    scores = []
    for completion in completions:
        function = extract_function(completion[0]["content"])
        printed = PRINTER % 5 == 0
        if printed:
            print(function)
        PRINTER += 1

        if function is None:
            scores.append(0)
            continue
        _, info = check_python_modules(function)
        if "error" in info:
            scores.append(0)
            continue
        try:
            new_strategy = create_locked_down_function(function)
        except (Exception, SystemExit):
            # Untrusted model code must not abort training (see function_works).
            scores.append(0)
            continue

        try:
            # Reset OpenEnv to an initial state!
            port, openenv_process = launch_openenv(port, openenv_process)
            result = openenv_process.reset()
            current_state = result.observation
            steps, if_done = execute_strategy(new_strategy, current_state)
            print(f"Steps = {steps} If Done = {if_done}")
            if not printed:
                print(function)
            # NOTE: this renders the board *before* the strategy ran;
            # execute_strategy does not return the final observation.
            print(render_board(current_state))
            # Success is massively rewarded; a working but losing function
            # still earns a little.
            scores.append(20.0 if if_done else 2.0)
        except TimeoutError:
            print("Timeout")
            scores.append(-1.0)  # Failed with timeout
        except Exception as e:
            print(f"Exception = {str(e)}")
            scores.append(-3.0)  # Failed
    return scores

# =============================================================================
# CELL 22: Create the dataset (replicated prompt)
# =============================================================================
from datasets import Dataset

dataset = Dataset.from_list([
    {
        "prompt": [{"role": "user", "content": prompt.strip()}],
        "answer": 0,
        "reasoning_effort": "low",
    }
] * 1000)

# Measure with the same reasoning_effort the dataset rows use, so the
# rendered chat template matches what GRPO will actually feed the model.
maximum_length = len(tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt.strip()}],
    add_generation_prompt=True,
    reasoning_effort="low",
))
print(f"Prompt token length: {maximum_length}")

# =============================================================================
# CELL 23: GRPO Training Configuration
# =============================================================================
max_prompt_length = maximum_length + 1  # + 1 just in case!
max_completion_length = max_seq_length - max_prompt_length

from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    temperature = 1.0,
    learning_rate = 2e-4,
    weight_decay = 0.001,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1,  # Increase to 4 for smoother training
    num_generations = 2,  # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    # num_train_epochs = 1,  # Set to 1 for a full training run
    max_steps = 600,
    save_steps = 100,
    report_to = "trackio",  # Can use Weights & Biases, TrackIO
    output_dir = "outputs",
)

# =============================================================================
# CELL 24: Create GRPO Trainer and Train
# =============================================================================
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        function_works,     # Reward 1: Is it valid Python?
        no_cheating,        # Reward 2: Only stdlib imports?
        strategy_succeeds,  # Reward 3: Does it actually play 2048?
    ],
    args = training_args,
    train_dataset = dataset,
)

# Start training! (~5 hours for 600 steps on T4)
# Expect 0 reward for ~first 100 steps, then gradual improvement
trainer.train()

# =============================================================================
# CELL 25: Inference after RL training
# =============================================================================
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize = False,
    add_generation_prompt = True,
    reasoning_effort = "low",
)

from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors="pt").to("cuda"),
    temperature = 1.0,
    max_new_tokens = 1024,
    streamer = TextStreamer(tokenizer, skip_prompt=False),
)

# =============================================================================
# CELL 26: Save the model (optional)
# =============================================================================
# Merge and push to hub in mxfp4 4bit format
if False:
    model.save_pretrained_merged("finetuned_model", tokenizer, save_method="mxfp4")
if False:
    model.push_to_hub_merged("repo_id/repo_name", tokenizer, token="hf...", save_method="mxfp4")

# Merge and push to hub in 16bit
if False:
    model.save_pretrained_merged("finetuned_model", tokenizer, save_method="merged_16bit")
if False:
    model.push_to_hub_merged("hf/gpt-oss-finetune", tokenizer, save_method="merged_16bit", token="")