# bytefight-policy / action_tokenizer.py
# Broyojo's picture
# Upload action_tokenizer.py with huggingface_hub
# 1938d3a verified
"""
Action tokenizer for ByteFight policy model.
Action vocab (21 tokens):
0-3: Move REGULAR (UP, DOWN, LEFT, RIGHT)
4-7: Move REGULAR + PLACE_BEACON
8-11: Move ERASE
12-15: Move BEACON_TRAVEL
16-19: Paint direction (UP, DOWN, LEFT, RIGHT)
20: EOS (end of turn)
Board vocab: 2462 tokens (from existing tokenizer)
Action tokens offset by BOARD_VOCAB_SIZE in the shared embedding.
"""
# Vocabulary layout: action token ids are offset by BOARD_VOCAB_SIZE in the
# shared embedding.
BOARD_VOCAB_SIZE = 2462
NUM_ACTION_TOKENS = 21
TOTAL_VOCAB_SIZE = NUM_ACTION_TOKENS + BOARD_VOCAB_SIZE  # 2483

# Sequence layout: board tokens followed by the action slots of one turn.
BOARD_SEQ_LEN = 9 + 961  # 9 scalars + 961 cells (no CLS) = 970
MAX_ACTIONS = 10  # max actions per turn (from data analysis)
SEQ_LEN = MAX_ACTIONS + BOARD_SEQ_LEN  # 980

# End-of-turn marker, as a local (action-only) id and as a global id.
EOS_TOKEN = 20  # local
EOS_GLOBAL = EOS_TOKEN + BOARD_VOCAB_SIZE  # 2482

IGNORE_INDEX = -100  # for loss masking

# Row / column deltas per direction index (same order as DIR_MAP).
DIR_DR = [-1, 1, 0, 0]
DIR_DC = [0, 0, -1, 1]

# Name -> index lookup tables for directions and move types.
DIR_MAP = {
    name: idx for idx, name in enumerate(('UP', 'DOWN', 'LEFT', 'RIGHT'))
}
MT_MAP = {
    name: idx for idx, name in enumerate(('REGULAR', 'ERASE', 'BEACON_TRAVEL'))
}
def encode_action(action_dict, player_r=0, player_c=0):
    """Encode one action dict as a local action token id.

    Move actions encode to ``base + direction``, where ``base`` is 12 for
    move type 2 (BEACON_TRAVEL), 8 for move type 1 (ERASE), 4 for any other
    move type with ``place_beacon`` set, and 0 otherwise. Paint actions
    encode to ``16 + direction``, where the direction is the one whose
    (DIR_DR, DIR_DC) delta equals the offset from the player position to
    the painted cell.

    Args:
        action_dict: Mapping describing the action. The kind is read from
            ``'type'``, falling back to ``'name'``. Moves read
            ``'direction'`` and ``'move_type'`` (an int, or a name looked up
            in DIR_MAP / MT_MAP; unknown names fall back to 0) and
            ``'place_beacon'``. Paints read ``'location'`` (indexable, row
            first) or, when that is missing or empty, the ``'r'`` and
            ``'c'`` keys.
        player_r: Player row, used only to derive the paint direction.
        player_c: Player column, used only to derive the paint direction.

    Returns:
        The local token id, or -1 when the action cannot be encoded:
        unknown kind, a move direction outside 0-3, or a paint target that
        is not exactly one step away from the player.
    """
    kind = action_dict.get('type', action_dict.get('name', ''))

    if kind in ('move', 'Move'):
        direction = action_dict.get('direction', 0)
        if isinstance(direction, str):
            direction = DIR_MAP.get(direction, 0)
        # An out-of-range direction would otherwise spill into another
        # token group (e.g. direction 8 + BEACON_TRAVEL would yield EOS).
        if direction not in range(4):
            return -1

        move_type = action_dict.get('move_type', 0)
        if isinstance(move_type, str):
            move_type = MT_MAP.get(move_type, 0)

        if move_type == 2:
            return 12 + direction
        if move_type == 1:
            return 8 + direction
        if action_dict.get('place_beacon', False):
            return 4 + direction
        return direction

    if kind in ('paint', 'Paint'):
        loc = action_dict.get('location', None)
        # Explicit None/length check instead of truthiness, so array-like
        # locations with ambiguous truth values do not raise here.
        if loc is not None and len(loc) > 0:
            target_r, target_c = int(loc[0]), int(loc[1])
        else:
            target_r = int(action_dict.get('r', 0))
            target_c = int(action_dict.get('c', 0))

        delta = (target_r - player_r, target_c - player_c)
        for idx, step in enumerate(zip(DIR_DR, DIR_DC)):
            if step == delta:
                return 16 + idx
        return -1

    return -1
def decode_action(token_id, player_r=0, player_c=0):
    """Decode a local action token id into an action dict.

    Args:
        token_id: Local action token id. It is coerced with ``int()`` so
            that integer-like scalars produce plain ``int`` fields in the
            result.
        player_r: Player row, used to compute the painted cell for paint
            tokens.
        player_c: Player column, used to compute the painted cell for
            paint tokens.

    Returns:
        ``{'type': 'eos'}`` for EOS_TOKEN and for any id outside 0-19;
        a ``'move'`` dict with ``direction``, ``move_type`` and
        ``place_beacon`` for ids 0-15; a ``'paint'`` dict with the target
        ``r`` / ``c`` for ids 16-19.
    """
    token_id = int(token_id)

    # Ids outside the action range decode to EOS. This includes negative
    # ids, which would otherwise pass the `< 4` test below and decode as a
    # move with a negative direction.
    if token_id == EOS_TOKEN or not 0 <= token_id < 20:
        return {'type': 'eos'}

    if token_id < 4:
        return {'type': 'move', 'direction': token_id,
                'move_type': 0, 'place_beacon': False}
    if token_id < 8:
        return {'type': 'move', 'direction': token_id - 4,
                'move_type': 0, 'place_beacon': True}
    if token_id < 12:
        return {'type': 'move', 'direction': token_id - 8,
                'move_type': 1, 'place_beacon': False}
    if token_id < 16:
        return {'type': 'move', 'direction': token_id - 12,
                'move_type': 2, 'place_beacon': False}

    # Remaining ids 16-19 are paint tokens: step one cell from the player.
    direction = token_id - 16
    return {'type': 'paint',
            'r': player_r + DIR_DR[direction],
            'c': player_c + DIR_DC[direction]}
def action_to_global(local):
    """Return the global id for a local action id (adds BOARD_VOCAB_SIZE)."""
    shifted = BOARD_VOCAB_SIZE + local
    return shifted
def global_to_action(glob):
    """Return the local action id for a global id (subtracts BOARD_VOCAB_SIZE)."""
    local_id = glob - BOARD_VOCAB_SIZE
    return local_id