maiChartGen / targets.py
Goldgom's picture
Upload MaiGenerator model (epoch 10) and inference code
8061544
Raw
History Blame Contribute Delete
3.42 kB
from __future__ import annotations
import torch
from model import DIV_MAP
from tokenizer import (
PAD, RST, SIM_BEG, SIM_END, SLD_BEG, DUR, SIM_COUNT_2,
BRK_BASE, BRK_END, HLD_BASE, HLD_END, ID_TO_DIV,
DUR_NUM_TO_ID, DUR_DEN_TO_ID,
SLD_BASE, SLD_END, TAP_BASE, TAP_END, TCH_BASE, TCH_END,
CONFIG_BASE,
)
# Class ids for the "type" label built by build_aux_targets; CONTROL is the
# catch-all for any token that is not a rest or a note.
(
    TYPE_REST,
    TYPE_TAP,
    TYPE_HOLD,
    TYPE_SLIDE,
    TYPE_BREAK,
    TYPE_TOUCH,
    TYPE_CONTROL,
) = range(7)
NUM_TOKEN_TYPES = TYPE_CONTROL + 1

# Known division values; the class id of a value is its index in this list.
DIV_VALUES = [1, 2, 4, 8, 16, 32, 48, 64, 128, 192, 384]
NUM_DIV_CLASSES = len(DIV_VALUES)
DIV_VALUE_TO_CLASS = dict(zip(DIV_VALUES, range(NUM_DIV_CLASSES)))
def build_aux_targets(tokens: torch.Tensor) -> dict[str, torch.Tensor]:
    """Build auxiliary labels from next-token targets.

    Every returned tensor has shape ``[B, T]`` and dtype ``torch.long``.
    The value ``-100`` marks positions a label does not apply to, and it is
    written into *every* label wherever ``tokens == PAD``. (``-100`` is the
    default ``ignore_index`` of ``torch.nn.functional.cross_entropy``;
    presumably the training loss relies on that -- confirm with the caller.)

    Args:
        tokens: [B, T] target token ids, usually chart_tokens[:, 1:].

    Returns:
        Dict with these keys, in this order:

        * ``"presence"`` -- 1 for note tokens (tap / hold / slide / break /
          touch id ranges), ``SIM_BEG``, ``SLD_BEG`` and any token
          ``>= CONFIG_BASE``; 0 otherwise.
        * ``"type"`` -- ``TYPE_REST`` for ``RST``, the matching ``TYPE_*``
          class for note tokens, ``TYPE_CONTROL`` for everything else.
        * ``"position"`` -- ``(token - range_start) % 8`` for note tokens;
          -100 otherwise.
        * ``"division"`` -- index into ``DIV_VALUES`` for tokens listed in
          ``ID_TO_DIV``; -100 otherwise.
        * ``"sim"`` -- 1 for ``SIM_BEG`` / ``SIM_COUNT_2`` / ``SIM_END``;
          0 otherwise.
        * ``"duration"`` -- 1 for ``DUR`` and the duration numerator /
          denominator tokens (``DUR_NUM_TO_ID`` / ``DUR_DEN_TO_ID`` values);
          0 otherwise.

    Raises:
        ValueError: if ``tokens`` is not 2-D.
        KeyError: if ``ID_TO_DIV`` holds a division value that is not in
            ``DIV_VALUES``.
    """
    ignore_index = -100
    device = tokens.device
    B, T = tokens.shape

    presence = torch.zeros(B, T, dtype=torch.long, device=device)
    token_type = torch.full((B, T), TYPE_CONTROL, dtype=torch.long, device=device)
    position = torch.full((B, T), ignore_index, dtype=torch.long, device=device)
    division = torch.full((B, T), ignore_index, dtype=torch.long, device=device)

    # Rests get their own type class but do not count as "present" events.
    token_type = token_type.masked_fill(tokens == RST, TYPE_REST)

    # Each note family occupies a contiguous id range [start, end); the
    # offset inside that range, modulo 8, is used as the position label.
    note_ranges = (
        (TAP_BASE, TAP_END, TYPE_TAP),
        (HLD_BASE, HLD_END, TYPE_HOLD),
        (SLD_BASE, SLD_END, TYPE_SLIDE),
        (BRK_BASE, BRK_END, TYPE_BREAK),
        (TCH_BASE, TCH_END, TYPE_TOUCH),
    )
    for start, end, note_type in note_ranges:
        in_range = (tokens >= start) & (tokens < end)
        token_type = token_type.masked_fill(in_range, note_type)
        presence = presence.masked_fill(in_range, 1)
        position = torch.where(in_range, (tokens - start) % 8, position)

    # Group openers and config tokens also count as events, though their
    # type label stays TYPE_CONTROL.
    extra_events = (tokens == SIM_BEG) | (tokens == SLD_BEG) | (tokens >= CONFIG_BASE)
    presence = presence.masked_fill(extra_events, 1)

    is_sim = ((tokens == SIM_BEG) | (tokens == SIM_COUNT_2) | (tokens == SIM_END)).long()

    duration_ids = {DUR, *DUR_NUM_TO_ID.values(), *DUR_DEN_TO_ID.values()}
    duration_event = torch.zeros_like(tokens, dtype=torch.bool)
    for tok in duration_ids:
        duration_event |= tokens == tok
    needs_duration = duration_event.long()

    for div_id, div_value in ID_TO_DIV.items():
        division = division.masked_fill(tokens == div_id, DIV_VALUE_TO_CLASS[div_value])

    labels = {
        "presence": presence,
        "type": token_type,
        "position": position,
        "division": division,
        "sim": is_sim,
        "duration": needs_duration,
    }
    # Padding carries no supervision for any head, so mask every label
    # (not only the ones initialised to 0) with the ignore marker.
    is_pad = tokens == PAD
    return {name: label.masked_fill(is_pad, ignore_index) for name, label in labels.items()}