bytefight-policy / tokenizer.py
Broyojo's picture
Upload tokenizer.py with huggingface_hub
665fb96 verified
"""
Tokenizer for ByteFight game states.
Converts raw game log data into token sequences for a value network.
Sequence layout (971 tokens):
[CLS] [my_stam] [my_max_stam] [opp_stam] [opp_max_stam]
[my_row] [my_col] [opp_row] [opp_col] [turn]
[cell_0_0] [cell_0_1] ... [cell_30_30]
All token IDs live in a single unified vocabulary:
0: CLS
1-381: stamina (0-380)
382-412: position (0-30)
413-2413: turn (0-2000)
2414-2461: cell states (48 tokens)
Total vocab size: 2462
"""
import json
import numpy as np
from pathlib import Path
from dataclasses import dataclass
# Fixed geometry and clamping limits. Boards smaller than MAX_BOARD_SIZE are
# padded with wall tokens when the grid is built, so the sequence length is constant.
MAX_BOARD_SIZE = 31
MAX_STAMINA = 380
MAX_TURN = 2000
GRID_CELLS = MAX_BOARD_SIZE * MAX_BOARD_SIZE # 961
SEQ_LEN = 1 + 9 + GRID_CELLS # 971 = CLS + 9 scalar slots + grid cells
# Global token offsets: each field type gets its own contiguous ID range
# inside one shared vocabulary.
CLS_TOKEN = 0
STAMINA_OFFSET = 1 # 1-381
POSITION_OFFSET = STAMINA_OFFSET + MAX_STAMINA + 1 # 382-412
TURN_OFFSET = POSITION_OFFSET + MAX_BOARD_SIZE # 413-2413
CELL_OFFSET = TURN_OFFSET + MAX_TURN + 1 # 2414-2461
# Cell state tokens (local, before adding CELL_OFFSET).
# Paint levels 1-4 occupy consecutive IDs: P1 uses 2-5, P2 uses 6-9
# (only the first and last ID of each run are named here).
CELL_WALL = 0
CELL_EMPTY = 1
CELL_P1_PAINT_1 = 2
CELL_P1_PAINT_4 = 5
CELL_P2_PAINT_1 = 6
CELL_P2_PAINT_4 = 9
CELL_P1_BEACON = 10
CELL_P2_BEACON = 11
NUM_BASE_CELL_STATES = 12
# hill/powerup offsets (local): a base state (0-11) is shifted by one of these
# to mark the cell as hill, powerup, or both.
HILL_OFFSET_LOCAL = NUM_BASE_CELL_STATES # +12
POWERUP_OFFSET_LOCAL = 2 * NUM_BASE_CELL_STATES # +24
HILL_POWERUP_OFFSET_LOCAL = 3 * NUM_BASE_CELL_STATES # +36
# Max local cell token: CELL_P2_BEACON (11) + HILL_POWERUP_OFFSET_LOCAL (36) = 47
NUM_CELL_TOKENS = CELL_P2_BEACON + HILL_POWERUP_OFFSET_LOCAL + 1 # 48
VOCAB_SIZE = CELL_OFFSET + NUM_CELL_TOKENS # 2462
def _encode_cell(paint_value: int, beacon_parity: int, is_wall: bool,
                 hill: bool, powerup: bool) -> int:
    """Encode a single cell's state into a global token ID.

    Precedence of the base state is wall > beacon > paint > empty. A wall
    always maps to the plain wall token; hill/powerup flags are ignored for it.

    Args:
        paint_value: Signed paint level; positive for P1, negative for P2,
            0 for unpainted. Magnitudes above the highest paint level are
            clamped to it.
        beacon_parity: 1 for a P1 beacon, -1 for a P2 beacon; any other
            value means no beacon.
        is_wall: Whether the cell is a wall.
        hill: Whether the cell belongs to a hill.
        powerup: Whether a powerup is present on the cell.

    Returns:
        The global token ID (CELL_OFFSET + local cell state).
    """
    if is_wall:
        return CELL_OFFSET + CELL_WALL

    # Number of distinct paint levels per player (4). Without clamping, a
    # larger magnitude would spill into the next player's paint range, the
    # beacon IDs, or the hill/powerup bands and silently corrupt the state.
    max_paint = CELL_P1_PAINT_4 - CELL_P1_PAINT_1 + 1

    if beacon_parity == 1:
        base = CELL_P1_BEACON
    elif beacon_parity == -1:
        base = CELL_P2_BEACON
    elif paint_value > 0:
        base = CELL_P1_PAINT_1 + (min(paint_value, max_paint) - 1)
    elif paint_value < 0:
        base = CELL_P2_PAINT_1 + (min(-paint_value, max_paint) - 1)
    else:
        base = CELL_EMPTY

    # Shift the base state into the band matching the hill/powerup flags.
    if hill and powerup:
        local = base + HILL_POWERUP_OFFSET_LOCAL
    elif hill:
        local = base + HILL_OFFSET_LOCAL
    elif powerup:
        local = base + POWERUP_OFFSET_LOCAL
    else:
        local = base
    return CELL_OFFSET + local
@dataclass
class TokenizedState:
    """One tokenized game state paired with its training label."""

    # Token ID sequence; the tokenizer in this file emits shape (971,), int32.
    tokens: np.ndarray
    # Outcome target: 1.0 = p1 wins, 0.0 = p2 wins.
    label: float
def _parse_map(map_string: str):
"""Parse static map info (size, walls, hills) from a map string."""
parts = map_string.split("#")
size_r, size_c = int(parts[0].split(",")[0]), int(parts[0].split(",")[1])
walls = set()
for i, ch in enumerate(parts[3]):
if ch == "1":
walls.add((i // size_c, i % size_c))
hill_cells = set()
hill_ids_str = parts[4]
hill_sets_str = parts[5]
if hill_ids_str.strip():
hill_id_list = [x for x in hill_ids_str.split(",") if x.strip()]
hill_set_list = hill_sets_str.split("_")
for i, _ in enumerate(hill_id_list):
if i < len(hill_set_list) and hill_set_list[i]:
coords = hill_set_list[i].split(",")
for j in range(len(coords) // 2):
r, c = int(coords[2 * j]), int(coords[2 * j + 1])
hill_cells.add((r, c))
return size_r, size_c, walls, hill_cells
def _replay_deltas(gl: dict, size_r: int, size_c: int, up_to: int):
"""Replay paint/beacon/powerup deltas from turn 0 up to (inclusive)."""
paint = np.zeros((size_r, size_c), dtype=np.int8)
beacon = np.zeros((size_r, size_c), dtype=np.int8)
powerup = np.zeros((size_r, size_c), dtype=np.bool_)
for t in range(up_to + 1):
for cell_key, value in gl["paint_updates"][t].items():
idx = int(cell_key)
paint[idx // size_c, idx % size_c] = value
for cell_key, value in gl["beacon_updates"][t].items():
idx = int(cell_key)
beacon[idx // size_c, idx % size_c] = value
for cell_key, value in gl["powerup_updates"][t].items():
idx = int(cell_key)
powerup[idx // size_c, idx % size_c] = value
return paint, beacon, powerup
def _build_tokens(gl: dict, turn_idx: int, size_r: int, size_c: int,
                  walls: set, hill_cells: set,
                  paint: np.ndarray, beacon: np.ndarray,
                  powerup: np.ndarray) -> np.ndarray:
    """Build the 971-token sequence for a single turn.

    Layout: slot 0 is CLS, slots 1-4 are P1/P2 stamina and max stamina,
    slots 5-8 are P1/P2 row and column, slot 9 is the turn index, and
    slots 10+ are the board cells in row-major order on a
    MAX_BOARD_SIZE x MAX_BOARD_SIZE grid.

    Every scalar is clamped to its token range (stamina to
    [0, MAX_STAMINA], coordinates to [0, MAX_BOARD_SIZE - 1], turn to at
    most MAX_TURN) so an out-of-range value can never produce a token ID
    belonging to a different field.

    Args:
        gl: Game log dict with per-turn lists "p1_stamina",
            "p1_max_stamina", "p2_stamina", "p2_max_stamina", "p1_loc"
            and "p2_loc".
        turn_idx: Index of the turn to encode.
        size_r, size_c: Actual board dimensions.
        walls, hill_cells: Sets of (row, col) tuples.
        paint, beacon, powerup: Board arrays indexed as [row, col].

    Returns:
        An int32 array of length SEQ_LEN.
    """
    def _clamp(value, upper):
        return min(max(value, 0), upper)

    tokens = np.zeros(SEQ_LEN, dtype=np.int32)
    tokens[0] = CLS_TOKEN

    stamina_keys = ("p1_stamina", "p1_max_stamina",
                    "p2_stamina", "p2_max_stamina")
    for slot, key in enumerate(stamina_keys, start=1):
        tokens[slot] = STAMINA_OFFSET + _clamp(gl[key][turn_idx], MAX_STAMINA)

    # Coordinates are clamped like the other scalars; unclamped, a negative
    # or oversized value would land in the stamina or turn token ranges.
    max_pos = MAX_BOARD_SIZE - 1
    p1_loc = gl["p1_loc"][turn_idx]
    p2_loc = gl["p2_loc"][turn_idx]
    tokens[5] = POSITION_OFFSET + _clamp(p1_loc[0], max_pos)
    tokens[6] = POSITION_OFFSET + _clamp(p1_loc[1], max_pos)
    tokens[7] = POSITION_OFFSET + _clamp(p2_loc[0], max_pos)
    tokens[8] = POSITION_OFFSET + _clamp(p2_loc[1], max_pos)

    tokens[9] = TURN_OFFSET + min(turn_idx, MAX_TURN)

    # Cells outside the actual board are padding and are encoded as walls;
    # fill them all at once, then overwrite only the in-bounds cells.
    tokens[10:] = CELL_OFFSET + CELL_WALL
    for r in range(min(size_r, MAX_BOARD_SIZE)):
        row_start = 10 + r * MAX_BOARD_SIZE
        for c in range(min(size_c, MAX_BOARD_SIZE)):
            tokens[row_start + c] = _encode_cell(
                paint_value=int(paint[r, c]),
                beacon_parity=int(beacon[r, c]),
                is_wall=(r, c) in walls,
                hill=(r, c) in hill_cells,
                powerup=bool(powerup[r, c]),
            )
    return tokens
def _parse_label(result: str) -> float:
if result == "PLAYER_1":
return 1.0
elif result == "PLAYER_2":
return 0.0
return 0.5
def tokenize_turn(gl: dict, map_string: str, turn_idx: int) -> np.ndarray:
    """Tokenize one turn of a game log.

    Parses the static map, replays all deltas from turn 0 through
    ``turn_idx``, and encodes the resulting state.

    Returns:
        A (971,) int32 token array.
    """
    rows, cols, wall_set, hill_set = _parse_map(map_string)
    boards = _replay_deltas(gl, rows, cols, turn_idx)
    # boards is (paint, beacon, powerup), matching the trailing parameters.
    return _build_tokens(gl, turn_idx, rows, cols, wall_set, hill_set, *boards)
def tokenize_match(match_path: str | Path) -> list[TokenizedState]:
    """
    Tokenize all turns of a match into training examples.

    Reads the match JSON at ``match_path``, takes its "game_log" entry, and
    emits one TokenizedState per turn, all sharing the match's final result
    as the label.

    Returns examples from P1's perspective. To get P2's perspective,
    the caller can use flip_perspective().
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(match_path, encoding="utf-8") as f:
        data = json.load(f)
    gl = data["game_log"]
    size_r, size_c, walls, hill_cells = _parse_map(gl["map_string"])
    label = _parse_label(gl["result"])
    num_turns = len(gl["p1_stamina"])

    # Board state is carried across turns and updated in place, so each
    # turn applies only its own deltas instead of replaying from turn 0.
    paint = np.zeros((size_r, size_c), dtype=np.int8)
    beacon = np.zeros((size_r, size_c), dtype=np.int8)
    powerup = np.zeros((size_r, size_c), dtype=np.bool_)
    layers = (
        ("paint_updates", paint),
        ("beacon_updates", beacon),
        ("powerup_updates", powerup),
    )

    examples = []
    for t in range(num_turns):
        for log_key, board in layers:
            for cell_key, value in gl[log_key][t].items():
                # Keys are flat row-major cell indices stored as strings.
                board[divmod(int(cell_key), size_c)] = value
        # _build_tokens allocates a fresh array, so later in-place board
        # updates do not affect examples already collected.
        tokens = _build_tokens(gl, t, size_r, size_c, walls, hill_cells,
                               paint, beacon, powerup)
        examples.append(TokenizedState(tokens=tokens, label=label))
    return examples
def flip_perspective(state: TokenizedState) -> TokenizedState:
    """
    Return a copy of ``state`` seen from the other player's side.

    The two players' stamina and position slots are exchanged, every grid
    cell's ownership is mirrored (P1 paint <-> P2 paint, P1 beacon <->
    P2 beacon) while its hill/powerup band is kept, and the label becomes
    ``1.0 - label``. The input state is not modified.
    """
    tokens = state.tokens.copy()

    # Exchange slots (1,2)<->(3,4) for stamina and (5,6)<->(7,8) for position.
    # The right-hand side is materialised before assignment, so this is a
    # true swap.
    tokens[[1, 2, 3, 4, 5, 6, 7, 8]] = tokens[[3, 4, 1, 2, 7, 8, 5, 6]]

    # Base-state swap table: each P1 paint level <-> same P2 level, beacons too.
    mirrored = {CELL_P1_BEACON: CELL_P2_BEACON, CELL_P2_BEACON: CELL_P1_BEACON}
    for level in range(CELL_P1_PAINT_4 - CELL_P1_PAINT_1 + 1):
        mirrored[CELL_P1_PAINT_1 + level] = CELL_P2_PAINT_1 + level
        mirrored[CELL_P2_PAINT_1 + level] = CELL_P1_PAINT_1 + level

    # Bands checked from highest to lowest; the first one reached wins.
    bands = (HILL_POWERUP_OFFSET_LOCAL, POWERUP_OFFSET_LOCAL, HILL_OFFSET_LOCAL)

    for pos in range(10, SEQ_LEN):
        local = tokens[pos] - CELL_OFFSET
        if local == CELL_WALL:
            continue
        band = next((off for off in bands if local >= off), 0)
        base = local - band
        # States with no owner (empty, or anything outside the table) stay as-is.
        base = mirrored.get(base, base)
        tokens[pos] = CELL_OFFSET + base + band

    return TokenizedState(tokens=tokens, label=1.0 - state.label)