File size: 2,727 Bytes
1938d3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Action tokenizer for the ByteFight policy model.

Local action vocabulary (21 tokens). Within every group of four the
direction order is UP, DOWN, LEFT, RIGHT:
  0-3:   move, REGULAR
  4-7:   move, REGULAR + PLACE_BEACON
  8-11:  move, ERASE
  12-15: move, BEACON_TRAVEL
  16-19: paint the cell adjacent to the player in that direction
  20:    EOS (end of turn)

Board vocabulary: 2462 tokens (from the existing board tokenizer).
Board and action tokens share one embedding table; action tokens are placed
right after the board tokens, so local action id ``a`` has global id
``BOARD_VOCAB_SIZE + a``.
"""
# ---- vocabulary sizes -------------------------------------------------
BOARD_VOCAB_SIZE = 2462
NUM_ACTION_TOKENS = 21
TOTAL_VOCAB_SIZE = BOARD_VOCAB_SIZE + NUM_ACTION_TOKENS   # == 2483

# ---- sequence lengths -------------------------------------------------
BOARD_SEQ_LEN = 970   # 9 scalar tokens + 961 cell tokens (no CLS token)
MAX_ACTIONS = 10      # cap on actions per turn, chosen from data analysis
SEQ_LEN = BOARD_SEQ_LEN + MAX_ACTIONS                     # == 980

# ---- special ids ------------------------------------------------------
EOS_TOKEN = NUM_ACTION_TOKENS - 1                         # == 20, local id
EOS_GLOBAL = BOARD_VOCAB_SIZE + EOS_TOKEN                 # == 2482
IGNORE_INDEX = -100   # label value used to mask positions out of the loss

# ---- direction / move-type tables -------------------------------------
# Row and column deltas, indexed by direction id (UP, DOWN, LEFT, RIGHT).
DIR_DR = [-1, 1, 0, 0]
DIR_DC = [0, 0, -1, 1]
DIR_MAP = dict(UP=0, DOWN=1, LEFT=2, RIGHT=3)
MT_MAP = dict(REGULAR=0, ERASE=1, BEACON_TRAVEL=2)


def _resolve_index(value, names, count):
    """Resolve a direction / move-type field to an index in ``range(count)``.

    Strings are looked up case-insensitively (after stripping whitespace) in
    *names*, whose keys are upper-case. Any other value is converted with
    ``int()`` (so ints and IntEnum members work) and must fall in
    ``[0, count)``. Returns None when the value cannot be resolved, so that a
    bad field never silently becomes a valid-looking index.
    """
    if isinstance(value, str):
        return names.get(value.strip().upper())
    try:
        idx = int(value)
    except (TypeError, ValueError):
        return None
    return idx if 0 <= idx < count else None


def encode_action(action_dict, player_r=0, player_c=0):
    """Encode a single action dict as a local action token id.

    Args:
        action_dict: mapping describing one action. Its kind is read from
            ``'type'`` (falling back to ``'name'``) and matched
            case-insensitively against ``'move'``, ``'paint'`` and ``'eos'``.
            Moves read ``'direction'`` (int 0-3 or a ``DIR_MAP`` name,
            default 0), ``'move_type'`` (int 0-2 or an ``MT_MAP`` name,
            default 0) and ``'place_beacon'`` (only honoured for REGULAR
            moves). Paints read the target cell from ``'location'`` (an
            ``(r, c)`` pair) or, failing that, from ``'r'`` / ``'c'``.
        player_r, player_c: player position. Used only for paints, which are
            encoded by the direction from the player to the painted cell.

    Returns:
        Local token id in ``[0, NUM_ACTION_TOKENS)``, or -1 when the action
        cannot be encoded: unknown kind, unrecognised/out-of-range direction
        or move type, or a paint target not orthogonally adjacent to the
        player.
    """
    atype = action_dict.get('type', action_dict.get('name', ''))
    kind = atype.lower() if isinstance(atype, str) else atype

    if kind == 'eos':
        # decode_action emits {'type': 'eos'}; encode it back to EOS_TOKEN
        # so that encode/decode round-trip.
        return EOS_TOKEN

    if kind == 'move':
        d = _resolve_index(action_dict.get('direction', 0), DIR_MAP, len(DIR_DR))
        mt = _resolve_index(action_dict.get('move_type', 0), MT_MAP, len(MT_MAP))
        if d is None or mt is None:
            # An unresolved field must not be coerced to 0 (UP / REGULAR) or
            # allowed to spill into a neighbouring token group.
            return -1
        if mt == MT_MAP['BEACON_TRAVEL']:
            return 12 + d
        if mt == MT_MAP['ERASE']:
            return 8 + d
        return 4 + d if action_dict.get('place_beacon', False) else d

    if kind == 'paint':
        loc = action_dict.get('location')
        # Explicit None/len test: plain truthiness raises on numpy arrays.
        if loc is not None and len(loc) > 0:
            pr, pc = int(loc[0]), int(loc[1])
        else:
            pr, pc = int(action_dict.get('r', 0)), int(action_dict.get('c', 0))
        delta = (pr - player_r, pc - player_c)
        for i, step in enumerate(zip(DIR_DR, DIR_DC)):
            if step == delta:
                return 16 + i
        return -1

    return -1


def decode_action(token_id, player_r=0, player_c=0):
    """Decode a local action token id into an action dict.

    Args:
        token_id: local action id; anything accepted by ``int()`` (e.g. an
            int, a numpy integer, or a single-element tensor).
        player_r, player_c: player position, used to turn a paint direction
            back into the absolute ``'r'`` / ``'c'`` of the painted cell.

    Returns:
        For ids 0-15: ``{'type': 'move', 'direction', 'move_type',
        'place_beacon'}``.
        For ids 16-19: ``{'type': 'paint', 'r', 'c'}``.
        For EOS_TOKEN or any id outside ``[0, EOS_TOKEN)``, including negative
        sentinels such as -1 or IGNORE_INDEX: ``{'type': 'eos'}``.
    """
    tok = int(token_id)
    # Reject out-of-range ids up front. Otherwise a negative id (e.g. the -1
    # returned by encode_action) would decode as a move with a negative
    # direction.
    if not 0 <= tok < EOS_TOKEN:
        return {'type': 'eos'}

    # Tokens come in groups of four, one per direction (UP, DOWN, LEFT, RIGHT).
    group, d = divmod(tok, 4)
    if group == 4:
        return {'type': 'paint', 'r': player_r + DIR_DR[d], 'c': player_c + DIR_DC[d]}
    return {
        'type': 'move',
        'direction': d,
        # Groups 0 and 1 are both REGULAR moves; group 1 also places a beacon.
        'move_type': (0, 0, 1, 2)[group],
        'place_beacon': group == 1,
    }


def action_to_global(local):
    """Map a local action id to its id in the shared board+action vocabulary."""
    shifted = BOARD_VOCAB_SIZE + local
    return shifted
def global_to_action(glob):
    """Map a shared-vocabulary id back to its local action id (no range check)."""
    local = glob - BOARD_VOCAB_SIZE
    return local