| """ |
| Action tokenizer for ByteFight policy model. |
| |
| Action vocab (21 tokens): |
| 0-3: Move REGULAR (UP, DOWN, LEFT, RIGHT) |
| 4-7: Move REGULAR + PLACE_BEACON |
| 8-11: Move ERASE |
| 12-15: Move BEACON_TRAVEL |
| 16-19: Paint direction (UP, DOWN, LEFT, RIGHT) |
| 20: EOS (end of turn) |
| |
| Board vocab: 2462 tokens (from existing tokenizer) |
| Action tokens offset by BOARD_VOCAB_SIZE in the shared embedding. |
| """ |
# --- Vocabulary layout ------------------------------------------------------
# Action tokens are appended after the board tokens in one shared vocabulary,
# so a local action id is shifted by BOARD_VOCAB_SIZE to get its global id.
BOARD_VOCAB_SIZE = 2462
NUM_ACTION_TOKENS = 21
TOTAL_VOCAB_SIZE = sum((BOARD_VOCAB_SIZE, NUM_ACTION_TOKENS))

# --- Sequence layout --------------------------------------------------------
# A full sequence is BOARD_SEQ_LEN positions followed by MAX_ACTIONS positions
# (presumably board tokens first, then one turn's action tokens -- confirm
# against the dataset/model code).
BOARD_SEQ_LEN = 970
MAX_ACTIONS = 10
SEQ_LEN = sum((BOARD_SEQ_LEN, MAX_ACTIONS))

# --- Special tokens ---------------------------------------------------------
EOS_TOKEN = 20  # local id of the end-of-turn token
EOS_GLOBAL = sum((BOARD_VOCAB_SIZE, EOS_TOKEN))  # same token, global id
# NOTE(review): -100 looks like the usual "ignore this label" value for a
# cross-entropy loss; confirm against the training code.
IGNORE_INDEX = -100

# --- Direction / move-type tables -------------------------------------------
# Row and column deltas indexed by direction id (0=UP, 1=DOWN, 2=LEFT,
# 3=RIGHT, as in DIR_MAP). UP decreases the row index.
DIR_DR = list((-1, 1, 0, 0))
DIR_DC = list((0, 0, -1, 1))
# Name -> integer id lookups used when actions arrive with string fields.
DIR_MAP = {
    label: index
    for index, label in enumerate(('UP', 'DOWN', 'LEFT', 'RIGHT'))
}
MT_MAP = {
    label: index
    for index, label in enumerate(('REGULAR', 'ERASE', 'BEACON_TRAVEL'))
}
|
|
|
|
def encode_action(action_dict, player_r=0, player_c=0):
    """Encode one action dict as a local action-token id.

    Args:
        action_dict: Mapping describing the action. The action kind is read
            from ``'type'``, falling back to ``'name'``.

            * ``'move'`` / ``'Move'``: reads ``'direction'`` (an int, or a
              name found in ``DIR_MAP``), ``'move_type'`` (an int, or a name
              found in ``MT_MAP``) and ``'place_beacon'``.
            * ``'paint'`` / ``'Paint'``: reads the target cell from
              ``'location'`` (indexable as ``(row, col)``) or, when that is
              missing or empty, from ``'r'`` and ``'c'``.
        player_r: Player row; used to turn a paint target into a direction.
        player_c: Player column; used the same way.

    Returns:
        A token id in ``0..19``, or ``-1`` when the action cannot be encoded:
        unknown action kind, a direction outside ``0..3``, or a paint target
        that is not exactly one step UP/DOWN/LEFT/RIGHT from the player.
    """
    kind = action_dict.get('type', action_dict.get('name', ''))

    if kind in ('move', 'Move'):
        direction = action_dict.get('direction', 0)
        if isinstance(direction, str):
            # Unrecognised direction names keep the existing fallback to UP.
            direction = DIR_MAP.get(direction, 0)

        # The direction is added to a group base below, so a value outside
        # 0..3 would silently land in a different action group (or outside
        # the action vocabulary). Report it as unencodable instead.
        if not 0 <= direction <= 3:
            return -1

        move_type = action_dict.get('move_type', 0)
        if isinstance(move_type, str):
            move_type = MT_MAP.get(move_type, 0)
        place_beacon = action_dict.get('place_beacon', False)

        if move_type == 2:  # BEACON_TRAVEL
            return 12 + direction
        if move_type == 1:  # ERASE
            return 8 + direction
        # Any other move_type is treated as REGULAR; the beacon flag only
        # matters for regular moves.
        if place_beacon:
            return 4 + direction
        return direction

    if kind in ('paint', 'Paint'):
        location = action_dict.get('location', None)
        if location:
            target_r, target_c = int(location[0]), int(location[1])
        else:
            target_r = int(action_dict.get('r', 0))
            target_c = int(action_dict.get('c', 0))

        # Paint tokens encode the direction from the player to the target.
        delta_r = target_r - player_r
        delta_c = target_c - player_c
        for index, (step_r, step_c) in enumerate(zip(DIR_DR, DIR_DC)):
            if step_r == delta_r and step_c == delta_c:
                return 16 + index
        return -1

    return -1
|
|
|
|
def decode_action(token_id, player_r=0, player_c=0):
    """Decode a local action-token id into an action dict.

    Args:
        token_id: Local action id (not offset by ``BOARD_VOCAB_SIZE``).
        player_r: Player row; paint targets are computed relative to it.
        player_c: Player column; paint targets are computed relative to it.

    Returns:
        * ``{'type': 'move', 'direction', 'move_type', 'place_beacon'}`` for
          ids 0-15, with ``direction`` in ``0..3`` and ``move_type`` using
          the integer codes of ``MT_MAP``.
        * ``{'type': 'paint', 'r', 'c'}`` for ids 16-19, where ``r``/``c`` is
          the cell one step from the player in the encoded direction.
        * ``{'type': 'eos'}`` for ``EOS_TOKEN`` and for any id outside the
          action ranges, including negative ids.
    """
    # Negative ids (e.g. the -1 "unencodable" result of encode_action) used to
    # match the ``< 4`` branch and decode to a move with a negative
    # direction. Treat them like every other out-of-range id: end of turn.
    if token_id == EOS_TOKEN or token_id < 0:
        return {'type': 'eos'}

    if token_id < 4:
        return {'type': 'move', 'direction': token_id,
                'move_type': 0, 'place_beacon': False}
    if token_id < 8:
        return {'type': 'move', 'direction': token_id - 4,
                'move_type': 0, 'place_beacon': True}
    if token_id < 12:
        return {'type': 'move', 'direction': token_id - 8,
                'move_type': 1, 'place_beacon': False}
    if token_id < 16:
        return {'type': 'move', 'direction': token_id - 12,
                'move_type': 2, 'place_beacon': False}
    if token_id < 20:
        direction = token_id - 16
        return {'type': 'paint',
                'r': player_r + DIR_DR[direction],
                'c': player_c + DIR_DC[direction]}

    # Ids past the last action token also decode to end of turn.
    return {'type': 'eos'}
|
|
|
|
def action_to_global(local):
    """Return the shared-vocabulary id for a local action-token id.

    The local id is shifted up by ``BOARD_VOCAB_SIZE``.
    """
    shifted = BOARD_VOCAB_SIZE + local
    return shifted
def global_to_action(glob):
    """Return the local action-token id for a shared-vocabulary id.

    Inverse of ``action_to_global``: subtracts ``BOARD_VOCAB_SIZE``. No range
    check is performed, so a board-token id yields a negative result.
    """
    local_id = glob - BOARD_VOCAB_SIZE
    return local_id
|
|