| from __future__ import annotations |
|
|
| import torch |
|
|
| from model import DIV_MAP |
| from tokenizer import ( |
| PAD, RST, SIM_BEG, SIM_END, SLD_BEG, DUR, SIM_COUNT_2, |
| BRK_BASE, BRK_END, HLD_BASE, HLD_END, ID_TO_DIV, |
| DUR_NUM_TO_ID, DUR_DEN_TO_ID, |
| SLD_BASE, SLD_END, TAP_BASE, TAP_END, TCH_BASE, TCH_END, |
| CONFIG_BASE, |
| ) |
|
|
# Coarse token-type classes. build_aux_targets() emits these ids as the labels
# of its "type" output; TYPE_CONTROL is the fallback for every token that is
# neither a rest nor one of the note kinds.
(
    TYPE_REST,
    TYPE_TAP,
    TYPE_HOLD,
    TYPE_SLIDE,
    TYPE_BREAK,
    TYPE_TOUCH,
    TYPE_CONTROL,
) = range(7)
# Number of token-type classes listed above.
NUM_TOKEN_TYPES = 7
|
|
# Division values that have a class id. A value's index in DIV_VALUES is the
# class id that build_aux_targets() writes into its "division" output.
DIV_VALUES = [1, 2, 4, 8, 16, 32, 48, 64, 128, 192, 384]
DIV_VALUE_TO_CLASS = dict(zip(DIV_VALUES, range(len(DIV_VALUES))))
NUM_DIV_CLASSES = len(DIV_VALUE_TO_CLASS)
|
|
|
|
def build_aux_targets(tokens: torch.Tensor) -> dict[str, torch.Tensor]:
    """Build auxiliary labels from next-token targets.

    Every returned tensor is ``torch.long`` with the same shape and device as
    ``tokens``. The value ``-100`` marks entries that carry no label for a
    given output (presumably the loss ``ignore_index`` -- confirm against the
    training loop).

    Args:
        tokens: Target token ids, usually ``chart_tokens[:, 1:]`` with shape
            [B, T]. All operations are elementwise, so any shape is accepted.

    Returns:
        A dict with these keys:

        * ``"presence"``: 1 for tokens in a TAP/HLD/SLD/BRK/TCH id range, for
          ``SIM_BEG`` / ``SLD_BEG`` and for ids ``>= CONFIG_BASE``; else 0.
        * ``"type"``: ``TYPE_REST`` for ``RST``, the matching ``TYPE_*`` for
          tokens in a note id range, ``TYPE_CONTROL`` for everything else.
        * ``"position"``: ``(token - range_start) % 8`` for tokens in a note
          id range, ``-100`` otherwise.
        * ``"division"``: index into ``DIV_VALUES`` for ids listed in
          ``ID_TO_DIV``, ``-100`` otherwise.
        * ``"sim"``: 1 for ``SIM_BEG`` / ``SIM_COUNT_2`` / ``SIM_END``; else 0.
        * ``"duration"``: 1 for ``DUR`` and for every id in the values of
          ``DUR_NUM_TO_ID`` / ``DUR_DEN_TO_ID``; else 0.

        Entries where ``tokens == PAD`` are ``-100`` in every tensor.

    Raises:
        KeyError: If ``ID_TO_DIV`` maps to a value missing from ``DIV_VALUES``.
    """
    device = tokens.device
    is_pad = tokens == PAD

    # Freshly allocated label tensors; they are updated in place below, which
    # avoids materialising a full constant tensor for every torch.where call.
    presence = torch.zeros_like(tokens, dtype=torch.long)
    token_type = torch.full_like(tokens, TYPE_CONTROL, dtype=torch.long)
    position = torch.full_like(tokens, -100, dtype=torch.long)
    division = torch.full_like(tokens, -100, dtype=torch.long)

    token_type.masked_fill_(tokens == RST, TYPE_REST)

    # Half-open [start, end) id ranges, one per note kind.
    note_ranges = (
        (TAP_BASE, TAP_END, TYPE_TAP),
        (HLD_BASE, HLD_END, TYPE_HOLD),
        (SLD_BASE, SLD_END, TYPE_SLIDE),
        (BRK_BASE, BRK_END, TYPE_BREAK),
        (TCH_BASE, TCH_END, TYPE_TOUCH),
    )
    for start, end, typ in note_ranges:
        in_range = (tokens >= start) & (tokens < end)
        token_type.masked_fill_(in_range, typ)
        presence.masked_fill_(in_range, 1)
        # The offset inside the range, folded into 8 position classes.
        offset = ((tokens - start) % 8).to(torch.long)
        position = torch.where(in_range, offset, position)

    # Group openers and config tokens also count as "something is present".
    opens_group = (tokens == SIM_BEG) | (tokens == SLD_BEG)
    presence.masked_fill_(opens_group | (tokens >= CONFIG_BASE), 1)

    sim_event = (tokens == SIM_BEG) | (tokens == SIM_COUNT_2) | (tokens == SIM_END)
    is_sim = sim_event.to(torch.long)

    # One vectorised membership test instead of one comparison per token id.
    duration_ids = sorted({DUR, *DUR_NUM_TO_ID.values(), *DUR_DEN_TO_ID.values()})
    duration_lookup = torch.tensor(duration_ids, dtype=tokens.dtype, device=device)
    needs_duration = torch.isin(tokens, duration_lookup).to(torch.long)

    for div_id, div_value in ID_TO_DIV.items():
        division.masked_fill_(tokens == div_id, DIV_VALUE_TO_CLASS[div_value])

    targets = {
        "presence": presence,
        "type": token_type,
        "position": position,
        "division": division,
        "sim": is_sim,
        "duration": needs_duration,
    }
    # Mask padding uniformly on every output, so "position" and "division" do
    # not depend on PAD never colliding with a note or division id.
    for labels in targets.values():
        labels.masked_fill_(is_pad, -100)
    return targets
|
|