File size: 9,779 Bytes
665fb96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
"""
Tokenizer for ByteFight game states.

Converts raw game log data into token sequences for a value network.

Sequence layout (971 tokens):
  [CLS] [my_stam] [my_max_stam] [opp_stam] [opp_max_stam]
  [my_row] [my_col] [opp_row] [opp_col] [turn]
  [cell_0_0] [cell_0_1] ... [cell_30_30]

tokenize_match() fills the "my" slots with P1 and the "opp" slots with P2;
flip_perspective() swaps them. Boards smaller than 31x31 are padded with
wall tokens.

All token IDs live in a single unified vocabulary:
  0:          CLS
  1-381:      stamina (0-380)
  382-412:    position (0-30)
  413-2413:   turn (0-2000)
  2414-2461:  cell states (48 tokens: 12 base states x 4 hill/powerup variants)

Total vocab size: 2462
"""

import json
import numpy as np
from pathlib import Path
from dataclasses import dataclass

MAX_BOARD_SIZE = 31
MAX_STAMINA = 380
MAX_TURN = 2000
GRID_CELLS = MAX_BOARD_SIZE * MAX_BOARD_SIZE  # 961
SEQ_LEN = 1 + 9 + GRID_CELLS  # 971: CLS + 9 scalar slots + grid

# Global token offsets (each band is contiguous; ranges are inclusive)
CLS_TOKEN = 0
STAMINA_OFFSET = 1                                  # 1-381 (stamina 0-380)
POSITION_OFFSET = STAMINA_OFFSET + MAX_STAMINA + 1  # 382-412 (coord 0-30)
TURN_OFFSET = POSITION_OFFSET + MAX_BOARD_SIZE      # 413-2413 (turn 0-2000)
CELL_OFFSET = TURN_OFFSET + MAX_TURN + 1            # 2414-2461 (48 cell tokens)

# Cell state tokens (local, before adding CELL_OFFSET)
CELL_WALL = 0
CELL_EMPTY = 1
CELL_P1_PAINT_1 = 2   # P1 paint levels 1..4 occupy local IDs 2..5
CELL_P1_PAINT_4 = 5
CELL_P2_PAINT_1 = 6   # P2 paint levels 1..4 occupy local IDs 6..9
CELL_P2_PAINT_4 = 9
CELL_P1_BEACON = 10
CELL_P2_BEACON = 11
NUM_BASE_CELL_STATES = 12

# hill/powerup offsets (local): each variant shifts a base state into its own
# band of NUM_BASE_CELL_STATES IDs. The wall state never receives a variant,
# so local IDs 12, 24 and 36 are reserved but unused.
HILL_OFFSET_LOCAL = NUM_BASE_CELL_STATES           # +12
POWERUP_OFFSET_LOCAL = 2 * NUM_BASE_CELL_STATES    # +24
HILL_POWERUP_OFFSET_LOCAL = 3 * NUM_BASE_CELL_STATES  # +36

# Max local cell token: CELL_P2_BEACON (11) + HILL_POWERUP_OFFSET_LOCAL (36) = 47
NUM_CELL_TOKENS = CELL_P2_BEACON + HILL_POWERUP_OFFSET_LOCAL + 1  # 48

VOCAB_SIZE = CELL_OFFSET + NUM_CELL_TOKENS  # 2462


def _encode_cell(paint_value: int, beacon_parity: int, is_wall: bool,
                 hill: bool, powerup: bool) -> int:
    """Map one board cell to its global token ID.

    Walls short-circuit to the wall token and never take a hill/powerup
    variant. Otherwise the base state is chosen with this precedence:
    beacon (parity 1 -> P1, -1 -> P2), then paint (positive -> P1,
    negative -> P2, magnitude selects the level), then empty. The hill and
    powerup flags pick which band of NUM_BASE_CELL_STATES the base lands in.
    """
    if is_wall:
        return CELL_OFFSET + CELL_WALL

    # NOTE(review): assumes |paint_value| <= 4; a larger magnitude would alias
    # into another state's ID (e.g. +5 -> P2 paint 1). Confirm against the game.
    if beacon_parity == 1:
        state = CELL_P1_BEACON
    elif beacon_parity == -1:
        state = CELL_P2_BEACON
    elif paint_value > 0:
        state = CELL_P1_PAINT_1 - 1 + paint_value
    elif paint_value < 0:
        state = CELL_P2_PAINT_1 - 1 - paint_value
    else:
        state = CELL_EMPTY

    if hill:
        band = HILL_POWERUP_OFFSET_LOCAL if powerup else HILL_OFFSET_LOCAL
    elif powerup:
        band = POWERUP_OFFSET_LOCAL
    else:
        band = 0

    return CELL_OFFSET + state + band


@dataclass
class TokenizedState:
    """A single tokenized game state paired with its training label.

    Attributes:
        tokens: The token sequence for one turn (length SEQ_LEN, int32 when
            produced by ``_build_tokens``).
        label: Outcome for the player occupying the "my" slots. It is 1.0 for
            a win, 0.0 for a loss, and 0.5 for any result that names neither
            player (see ``_parse_label``). ``tokenize_match`` emits P1's view;
            ``flip_perspective`` mirrors it to P2's as ``1.0 - label``.
    """
    tokens: np.ndarray   # (971,) int32
    label: float         # 1.0 = "my" player wins, 0.0 = loses, 0.5 = other result


def _parse_map(map_string: str):
    """Parse static map info (size, walls, hills) from a map string."""
    parts = map_string.split("#")
    size_r, size_c = int(parts[0].split(",")[0]), int(parts[0].split(",")[1])

    walls = set()
    for i, ch in enumerate(parts[3]):
        if ch == "1":
            walls.add((i // size_c, i % size_c))

    hill_cells = set()
    hill_ids_str = parts[4]
    hill_sets_str = parts[5]
    if hill_ids_str.strip():
        hill_id_list = [x for x in hill_ids_str.split(",") if x.strip()]
        hill_set_list = hill_sets_str.split("_")
        for i, _ in enumerate(hill_id_list):
            if i < len(hill_set_list) and hill_set_list[i]:
                coords = hill_set_list[i].split(",")
                for j in range(len(coords) // 2):
                    r, c = int(coords[2 * j]), int(coords[2 * j + 1])
                    hill_cells.add((r, c))

    return size_r, size_c, walls, hill_cells


def _replay_deltas(gl: dict, size_r: int, size_c: int, up_to: int):
    """Replay paint/beacon/powerup deltas from turn 0 up to (inclusive)."""
    paint = np.zeros((size_r, size_c), dtype=np.int8)
    beacon = np.zeros((size_r, size_c), dtype=np.int8)
    powerup = np.zeros((size_r, size_c), dtype=np.bool_)

    for t in range(up_to + 1):
        for cell_key, value in gl["paint_updates"][t].items():
            idx = int(cell_key)
            paint[idx // size_c, idx % size_c] = value
        for cell_key, value in gl["beacon_updates"][t].items():
            idx = int(cell_key)
            beacon[idx // size_c, idx % size_c] = value
        for cell_key, value in gl["powerup_updates"][t].items():
            idx = int(cell_key)
            powerup[idx // size_c, idx % size_c] = value

    return paint, beacon, powerup


def _build_tokens(gl: dict, turn_idx: int, size_r: int, size_c: int,
                  walls: set, hill_cells: set,
                  paint: np.ndarray, beacon: np.ndarray,
                  powerup: np.ndarray) -> np.ndarray:
    """Assemble the SEQ_LEN-token sequence describing one turn.

    Slot 0 is CLS. Slots 1-4 hold stamina and max stamina for P1 and then P2,
    each clamped to [0, MAX_STAMINA]. Slots 5-8 hold (row, col) for P1 and
    then P2. Slot 9 holds the turn index, capped at MAX_TURN. Slots 10
    onward hold the MAX_BOARD_SIZE x MAX_BOARD_SIZE grid in row-major
    order. Any grid position outside the ``size_r`` x ``size_c`` board is
    padded with the wall token.
    """
    def stamina_token(log_key: str) -> int:
        # Clamp so the ID can never leave the stamina band.
        return STAMINA_OFFSET + min(max(gl[log_key][turn_idx], 0), MAX_STAMINA)

    seq = np.zeros(SEQ_LEN, dtype=np.int32)
    seq[0] = CLS_TOKEN

    stamina_keys = ("p1_stamina", "p1_max_stamina", "p2_stamina", "p2_max_stamina")
    for slot, log_key in enumerate(stamina_keys, start=1):
        seq[slot] = stamina_token(log_key)

    # NOTE(review): coordinates are not range-checked; a value outside
    # [0, MAX_BOARD_SIZE) would yield an ID outside the position band.
    for slot, log_key in ((5, "p1_loc"), (7, "p2_loc")):
        loc = gl[log_key][turn_idx]
        seq[slot] = POSITION_OFFSET + loc[0]
        seq[slot + 1] = POSITION_OFFSET + loc[1]

    seq[9] = TURN_OFFSET + min(turn_idx, MAX_TURN)

    # Start from an all-wall board so padding needs no special case, then
    # encode only the cells that exist on this map.
    board = np.full((MAX_BOARD_SIZE, MAX_BOARD_SIZE), CELL_OFFSET + CELL_WALL,
                    dtype=np.int32)
    for r in range(min(size_r, MAX_BOARD_SIZE)):
        for c in range(min(size_c, MAX_BOARD_SIZE)):
            board[r, c] = _encode_cell(
                paint_value=int(paint[r, c]),
                beacon_parity=int(beacon[r, c]),
                is_wall=(r, c) in walls,
                hill=(r, c) in hill_cells,
                powerup=bool(powerup[r, c]),
            )
    seq[10:] = board.ravel()

    return seq


def _parse_label(result: str) -> float:
    if result == "PLAYER_1":
        return 1.0
    elif result == "PLAYER_2":
        return 0.0
    return 0.5


def tokenize_turn(gl: dict, map_string: str, turn_idx: int) -> np.ndarray:
    """Tokenize the board as it stands after turn ``turn_idx``.

    Each call replays every delta from turn 0 through ``turn_idx`` inclusive;
    use tokenize_match to tokenize a whole game incrementally. Returns the
    (971,) int32 array built by ``_build_tokens``.
    """
    rows, cols, wall_set, hill_set = _parse_map(map_string)
    layers = _replay_deltas(gl, rows, cols, turn_idx)
    return _build_tokens(gl, turn_idx, rows, cols, wall_set, hill_set, *layers)


def tokenize_match(match_path: str | Path) -> list[TokenizedState]:
    """
    Tokenize all turns of a match into training examples.

    Reads the JSON file at ``match_path`` and uses its ``"game_log"`` entry
    (map string, result, per-turn player scalars and per-turn
    paint/beacon/powerup deltas). It produces one example per entry of
    ``game_log["p1_stamina"]``, and every example carries the final match label.

    Returns examples from P1's perspective. To get P2's perspective,
    the caller can use flip_perspective().
    """
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale-dependent default encoding.
    with open(match_path, encoding="utf-8") as f:
        data = json.load(f)

    gl = data["game_log"]
    size_r, size_c, walls, hill_cells = _parse_map(gl["map_string"])
    label = _parse_label(gl["result"])
    num_turns = len(gl["p1_stamina"])

    # Incrementally replay deltas (more efficient than replaying from 0 each
    # time); the state after applying turn t matches _replay_deltas(..., t).
    paint = np.zeros((size_r, size_c), dtype=np.int8)
    beacon = np.zeros((size_r, size_c), dtype=np.int8)
    powerup = np.zeros((size_r, size_c), dtype=np.bool_)
    layers = (
        ("paint_updates", paint),
        ("beacon_updates", beacon),
        ("powerup_updates", powerup),
    )

    examples = []
    for t in range(num_turns):
        for log_key, grid in layers:
            for cell_key, value in gl[log_key][t].items():
                row, col = divmod(int(cell_key), size_c)
                grid[row, col] = value

        tokens = _build_tokens(gl, t, size_r, size_c, walls, hill_cells,
                               paint, beacon, powerup)
        examples.append(TokenizedState(tokens=tokens, label=label))

    return examples


def flip_perspective(state: TokenizedState) -> TokenizedState:
    """
    Re-express a tokenized state from the other player's point of view.

    - The stamina and position slots of the two players are exchanged.
    - Every owned grid cell changes hands while keeping its hill/powerup
      variant: P1 paint level k <-> P2 paint level k, P1 beacon <-> P2 beacon.
    - The label is mirrored as ``1.0 - label``.

    Walls, empty cells and the turn slot are unchanged. The input state is
    not modified.
    """
    flipped = state.tokens.copy()

    # Scalar slots: (1, 2) <-> (3, 4) stamina, (5, 6) <-> (7, 8) position.
    for mine, theirs in ((1, 3), (2, 4), (5, 7), (6, 8)):
        flipped[mine], flipped[theirs] = flipped[theirs], flipped[mine]

    # Base-state ownership swap table; states not listed have no owner.
    paint_shift = CELL_P2_PAINT_1 - CELL_P1_PAINT_1
    owner_swap = {CELL_P1_BEACON: CELL_P2_BEACON, CELL_P2_BEACON: CELL_P1_BEACON}
    for p1_state in range(CELL_P1_PAINT_1, CELL_P1_PAINT_4 + 1):
        owner_swap[p1_state] = p1_state + paint_shift
        owner_swap[p1_state + paint_shift] = p1_state

    for pos in range(10, SEQ_LEN):
        local = flipped[pos] - CELL_OFFSET
        if local == CELL_WALL:
            continue

        # Split the local ID into its hill/powerup band and base state.
        if local >= HILL_POWERUP_OFFSET_LOCAL:
            band = HILL_POWERUP_OFFSET_LOCAL
        elif local >= POWERUP_OFFSET_LOCAL:
            band = POWERUP_OFFSET_LOCAL
        elif local >= HILL_OFFSET_LOCAL:
            band = HILL_OFFSET_LOCAL
        else:
            band = 0
        base = local - band

        flipped[pos] = CELL_OFFSET + owner_swap.get(base, base) + band

    return TokenizedState(tokens=flipped, label=1.0 - state.label)