# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
# "trl[peft]",
# ]
# ///
import random
from datasets import Dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
# User message placed in every row of the training dataset built in `main`.
# It names the `move` tool, matching the public `move` method of `Game2048Env`.
PROMPT = "Play 2048 on a 4x4 board. Use the tool `move` with one of: up, down, left, right. Maximize the score."
class Game2048Env:
    """Minimal 2048 game with a text interface.

    The board is a square grid of ints (side length ``_SIZE``) in which ``0``
    marks an empty cell. ``reset`` starts a new game and ``move`` plays one
    turn; both return the same text observation (score, rendered board and
    done flag).

    Game state (``board``, ``score``, ``done``) is created by ``reset``, so
    ``reset`` must be called before ``move``.
    """

    # Side length of the square board. Every helper reads this constant instead
    # of a literal, so a subclass can change the board size in one place.
    _SIZE = 4

    def reset(self, **kwargs) -> str:
        """Start a new game and return the initial observation.

        Args:
            **kwargs: Accepted and ignored.

        Returns:
            Observation text: score, rendered board and done flag.
        """
        self.board = [[0] * self._SIZE for _ in range(self._SIZE)]
        self.score = 0.0
        self.done = False
        # A game opens with two random tiles on the board.
        self._spawn()
        self._spawn()
        return self._observation()

    # NOTE(review): presumably the trainer derives the tool description from
    # this docstring, so its wording is kept unchanged - confirm against the
    # TRL documentation before editing it.
    # Raises ValueError("Game over.") when called after the game has ended.
    # An unknown direction leaves the board untouched and still returns the
    # current observation.
    def move(self, direction: str) -> str:
        """
        Play one move in 2048.

        Args:
            direction: One of "up", "down", "left", "right".

        Returns:
            Environment feedback after the move.
        """
        if self.done:
            raise ValueError("Game over.")
        moved, gained = self._apply_move(direction.strip().lower())
        if moved:
            self.score += gained
            # A new tile appears only after a move that changed the board.
            self._spawn()
        self.done = not self._can_move()
        return self._observation()

    def _observation(self) -> str:
        """Return the text observation shared by ``reset`` and ``move``."""
        return f"score={self.score}\n{self._render()}\ndone={self.done}"

    def _spawn(self) -> None:
        """Put a new tile on a random empty cell.

        The tile is 4 with probability 0.1 and 2 otherwise. Does nothing when
        the board has no empty cell.
        """
        empty = [
            (r, c)
            for r, row in enumerate(self.board)
            for c, value in enumerate(row)
            if value == 0
        ]
        if not empty:
            return
        r, c = random.choice(empty)
        self.board[r][c] = 4 if random.random() < 0.1 else 2

    @staticmethod
    def _merge_line(line: list[int]) -> tuple[list[int], int]:
        """Slide one line towards index 0, merging equal neighbours once.

        Args:
            line: Tile values of one row or column; ``0`` means empty.

        Returns:
            A tuple ``(merged, gained)``: the new line, padded with zeros to
            the length of ``line``, and the sum of the tiles created by merges.
        """
        vals = [x for x in line if x]
        out = []
        gained = 0
        i = 0
        while i < len(vals):
            if i + 1 < len(vals) and vals[i] == vals[i + 1]:
                v = vals[i] * 2
                out.append(v)
                gained += v
                # Skip the partner so a merged tile cannot merge again this turn.
                i += 2
            else:
                out.append(vals[i])
                i += 1
        # Pad to the input length rather than a fixed 4, so the result always
        # has as many cells as the line that was passed in.
        out += [0] * (len(line) - len(out))
        return out, gained

    def _apply_move(self, direction: str) -> tuple[bool, int]:
        """Slide the whole board in ``direction`` without spawning a tile.

        Args:
            direction: Normalised direction name; anything other than "up",
                "down", "left" or "right" is treated as a no-op.

        Returns:
            A tuple ``(moved, gained)``: whether the board changed, and the
            score gained from merges.
        """
        if direction not in {"up", "down", "left", "right"}:
            return False, 0
        before = [row[:] for row in self.board]
        horizontal = direction in {"left", "right"}
        # `_merge_line` always slides towards index 0, so "right" and "down"
        # are handled by reversing the line before and after merging.
        towards_end = direction in {"right", "down"}
        size = self._SIZE
        gained_total = 0
        for i in range(size):
            if horizontal:
                line = self.board[i][:]
            else:
                line = [self.board[r][i] for r in range(size)]
            if towards_end:
                line.reverse()
            merged, gained = self._merge_line(line)
            if towards_end:
                merged.reverse()
            if horizontal:
                self.board[i] = merged
            else:
                for r in range(size):
                    self.board[r][i] = merged[r]
            gained_total += gained
        moved = self.board != before
        return moved, gained_total

    def _can_move(self) -> bool:
        """Return True if any move could still change the board."""
        if any(0 in row for row in self.board):
            return True
        size = self._SIZE
        # With a full board, a move is possible only if two equal tiles touch;
        # checking the cell below and the cell to the right covers every pair.
        for r in range(size):
            for c in range(size):
                if r + 1 < size and self.board[r][c] == self.board[r + 1][c]:
                    return True
                if c + 1 < size and self.board[r][c] == self.board[r][c + 1]:
                    return True
        return False

    def _render(self) -> str:
        """Return the board as text: one line per row, values padded to width 3."""
        return "\n".join(" ".join(f"{v:3d}" for v in row) for row in self.board)
def reward_score(environments, **kwargs):
    """Return the ``score`` attribute of each environment.

    Args:
        environments: Iterable of objects exposing a ``score`` attribute.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one score per environment, in input order.
    """
    rewards = []
    for environment in environments:
        rewards.append(environment.score)
    return rewards
def main() -> None:
    """Build the prompt dataset, configure the GRPO trainer and run training.

    The dataset holds 1000 rows, each a one-message conversation containing
    ``PROMPT``. ``Game2048Env`` is passed as the environment factory and
    ``reward_score`` as the reward function.
    """
    # Build a fresh message list per row so that no two rows share an object.
    conversations = [[{"role": "user", "content": PROMPT}] for _ in range(1000)]
    dataset = Dataset.from_dict({"prompt": conversations})

    training_args = GRPOConfig(
        chat_template_kwargs={"enable_thinking": False},
        logging_steps=1,
        log_completions=True,
        num_completions_to_print=2,
        report_to="trackio",
        trackio_space_id="trl-2048",
        max_completion_length=2048,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
    )
    # LoRA adapter with the library's default settings.
    lora_config = LoraConfig()

    trainer = GRPOTrainer(
        model="Qwen/Qwen3-4B",
        train_dataset=dataset,
        reward_funcs=reward_score,
        args=training_args,
        environment_factory=Game2048Env,
        peft_config=lora_config,
    )
    trainer.train()
# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|