""" Action tokenizer for ByteFight policy model. Action vocab (21 tokens): 0-3: Move REGULAR (UP, DOWN, LEFT, RIGHT) 4-7: Move REGULAR + PLACE_BEACON 8-11: Move ERASE 12-15: Move BEACON_TRAVEL 16-19: Paint direction (UP, DOWN, LEFT, RIGHT) 20: EOS (end of turn) Board vocab: 2462 tokens (from existing tokenizer) Action tokens offset by BOARD_VOCAB_SIZE in the shared embedding. """ BOARD_VOCAB_SIZE = 2462 NUM_ACTION_TOKENS = 21 TOTAL_VOCAB_SIZE = BOARD_VOCAB_SIZE + NUM_ACTION_TOKENS # 2483 BOARD_SEQ_LEN = 970 # 9 scalars + 961 cells (no CLS) MAX_ACTIONS = 10 # max actions per turn (from data analysis) SEQ_LEN = BOARD_SEQ_LEN + MAX_ACTIONS # 980 EOS_TOKEN = 20 # local EOS_GLOBAL = BOARD_VOCAB_SIZE + EOS_TOKEN # 2482 IGNORE_INDEX = -100 # for loss masking DIR_DR = [-1, 1, 0, 0] DIR_DC = [0, 0, -1, 1] DIR_MAP = {'UP': 0, 'DOWN': 1, 'LEFT': 2, 'RIGHT': 3} MT_MAP = {'REGULAR': 0, 'ERASE': 1, 'BEACON_TRAVEL': 2} def encode_action(action_dict, player_r=0, player_c=0): atype = action_dict.get('type', action_dict.get('name', '')) if atype in ('move', 'Move'): d = action_dict.get('direction', 0) if isinstance(d, str): d = DIR_MAP.get(d, 0) mt = action_dict.get('move_type', 0) if isinstance(mt, str): mt = MT_MAP.get(mt, 0) pb = action_dict.get('place_beacon', False) if mt == 2: return 12 + d elif mt == 1: return 8 + d elif pb: return 4 + d else: return d elif atype in ('paint', 'Paint'): loc = action_dict.get('location', None) if loc: pr, pc = int(loc[0]), int(loc[1]) else: pr, pc = int(action_dict.get('r', 0)), int(action_dict.get('c', 0)) dr, dc = pr - player_r, pc - player_c for i in range(4): if DIR_DR[i] == dr and DIR_DC[i] == dc: return 16 + i return -1 return -1 def decode_action(token_id, player_r=0, player_c=0): if token_id == EOS_TOKEN: return {'type': 'eos'} elif token_id < 4: return {'type': 'move', 'direction': token_id, 'move_type': 0, 'place_beacon': False} elif token_id < 8: return {'type': 'move', 'direction': token_id - 4, 'move_type': 0, 'place_beacon': True} elif token_id < 12: return {'type': 'move', 'direction': token_id - 8, 'move_type': 1, 'place_beacon': False} elif token_id < 16: return {'type': 'move', 'direction': token_id - 12, 'move_type': 2, 'place_beacon': False} elif token_id < 20: d = token_id - 16; return {'type': 'paint', 'r': player_r + DIR_DR[d], 'c': player_c + DIR_DC[d]} return {'type': 'eos'} def action_to_global(local): return BOARD_VOCAB_SIZE + local def global_to_action(glob): return glob - BOARD_VOCAB_SIZE