from __future__ import annotations

import torch

from model import DIV_MAP
from tokenizer import (
    PAD,
    RST,
    SIM_BEG,
    SIM_END,
    SLD_BEG,
    DUR,
    SIM_COUNT_2,
    BRK_BASE,
    BRK_END,
    HLD_BASE,
    HLD_END,
    ID_TO_DIV,
    DUR_NUM_TO_ID,
    DUR_DEN_TO_ID,
    SLD_BASE,
    SLD_END,
    TAP_BASE,
    TAP_END,
    TCH_BASE,
    TCH_END,
    CONFIG_BASE,
)

# Class ids for the auxiliary "type" target. RST maps to TYPE_REST; tokens in a
# note range map to their note type; everything else stays TYPE_CONTROL.
TYPE_REST = 0
TYPE_TAP = 1
TYPE_HOLD = 2
TYPE_SLIDE = 3
TYPE_BREAK = 4
TYPE_TOUCH = 5
TYPE_CONTROL = 6
NUM_TOKEN_TYPES = 7

# Division values that receive a class in the auxiliary "division" target; the
# class id is the index into this list. Every value in tokenizer.ID_TO_DIV must
# appear here (build_aux_targets raises KeyError otherwise).
DIV_VALUES = [1, 2, 4, 8, 16, 32, 48, 64, 128, 192, 384]
DIV_VALUE_TO_CLASS = {v: i for i, v in enumerate(DIV_VALUES)}
NUM_DIV_CLASSES = len(DIV_VALUES)

# Label for positions that carry no target. It equals the default
# ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
IGNORE_INDEX = -100


def build_aux_targets(tokens: torch.Tensor) -> dict[str, torch.Tensor]:
    """Build per-position auxiliary labels from next-token targets.

    Args:
        tokens: [B, T] integer target token ids, usually ``chart_tokens[:, 1:]``.

    Returns:
        A dict of [B, T] ``torch.long`` tensors on ``tokens.device``:

        - ``"presence"``: 1 for tokens in any note range (tap/hold/slide/
          break/touch), for ``SIM_BEG``, ``SLD_BEG`` and for tokens
          ``>= CONFIG_BASE``; 0 otherwise.
        - ``"type"``: a ``TYPE_*`` id -- ``TYPE_REST`` for ``RST``, the note
          type for tokens in a note range, ``TYPE_CONTROL`` otherwise.
        - ``"position"``: ``(token - range_start) % 8`` for tokens in a note
          range, ``IGNORE_INDEX`` elsewhere.
        - ``"division"``: index into ``DIV_VALUES`` for tokens that are keys of
          ``ID_TO_DIV``, ``IGNORE_INDEX`` elsewhere.
        - ``"sim"``: 1 for ``SIM_BEG`` / ``SIM_COUNT_2`` / ``SIM_END``, else 0.
        - ``"duration"``: 1 for ``DUR`` and the duration numerator/denominator
          tokens (values of ``DUR_NUM_TO_ID`` / ``DUR_DEN_TO_ID``), else 0.

        ``PAD`` positions are overwritten with ``IGNORE_INDEX`` in
        ``presence``, ``type``, ``sim`` and ``duration``. ``position`` and
        ``division`` are not re-masked for PAD; they are ``IGNORE_INDEX``
        wherever no label was assigned.

    Raises:
        ValueError: if ``tokens`` is not 2-D (from unpacking its shape).
        KeyError: if ``ID_TO_DIV`` maps a token to a division value that is not
            listed in ``DIV_VALUES``.
    """
    device = tokens.device
    B, T = tokens.shape
    long = torch.long

    presence = torch.zeros(B, T, dtype=long, device=device)
    token_type = torch.full((B, T), TYPE_CONTROL, dtype=long, device=device)
    position = torch.full((B, T), IGNORE_INDEX, dtype=long, device=device)
    division = torch.full((B, T), IGNORE_INDEX, dtype=long, device=device)

    token_type = token_type.masked_fill(tokens == RST, TYPE_REST)

    # Half-open [start, end) id ranges per note type. If ranges overlapped,
    # later entries would overwrite earlier ones.
    ranges = [
        (TAP_BASE, TAP_END, TYPE_TAP),
        (HLD_BASE, HLD_END, TYPE_HOLD),
        (SLD_BASE, SLD_END, TYPE_SLIDE),
        (BRK_BASE, BRK_END, TYPE_BREAK),
        (TCH_BASE, TCH_END, TYPE_TOUCH),
    ]
    for start, end, typ in ranges:
        mask = (tokens >= start) & (tokens < end)
        token_type = token_type.masked_fill(mask, typ)
        presence = presence.masked_fill(mask, 1)
        # Cast explicitly so both where() operands are int64 even when the
        # incoming token tensor uses a narrower integer dtype.
        offset = ((tokens - start) % 8).to(long)
        position = torch.where(mask, offset, position)

    # Group openers and config tokens also count as "present".
    group_or_config = (tokens == SIM_BEG) | (tokens == SLD_BEG) | (tokens >= CONFIG_BASE)
    presence = presence.masked_fill(group_or_config, 1)

    sim_event = (tokens == SIM_BEG) | (tokens == SIM_COUNT_2) | (tokens == SIM_END)
    is_sim = sim_event.to(long)

    duration_tokens = {DUR} | set(DUR_NUM_TO_ID.values()) | set(DUR_DEN_TO_ID.values())
    duration_event = torch.zeros_like(tokens, dtype=torch.bool)
    for tok in duration_tokens:
        duration_event |= tokens == tok
    needs_duration = duration_event.to(long)

    for div_id, div_value in ID_TO_DIV.items():
        div_class = DIV_VALUE_TO_CLASS.get(div_value)
        if div_class is None:
            raise KeyError(
                f"ID_TO_DIV maps token {div_id} to division {div_value}, "
                f"which is not in DIV_VALUES"
            )
        division = division.masked_fill(tokens == div_id, div_class)

    # Padding carries no supervision for these heads.
    is_pad = tokens == PAD
    token_type = token_type.masked_fill(is_pad, IGNORE_INDEX)
    presence = presence.masked_fill(is_pad, IGNORE_INDEX)
    is_sim = is_sim.masked_fill(is_pad, IGNORE_INDEX)
    needs_duration = needs_duration.masked_fill(is_pad, IGNORE_INDEX)

    return {
        "presence": presence,
        "type": token_type,
        "position": position,
        "division": division,
        "sim": is_sim,
        "duration": needs_duration,
    }