| | import json
|
| | import pickle
|
| | from pathlib import Path
|
| |
|
| | import numpy as np
|
| | from omegaconf import DictConfig
|
| | from tqdm import tqdm
|
| |
|
| | from .event import Event, EventType, EventRange
|
| |
|
# Time-unit constants used when sizing the TIME_SHIFT token range.
# NOTE: the "MILISECONDS" spelling is kept on purpose; these names are
# referenced by other code and renaming them would break it.

# Number of milliseconds in one second.
MILISECONDS_PER_SECOND = 1000

# Duration of one time-shift step, in milliseconds.
MILISECONDS_PER_STEP = 10
|
| |
|
| |
|
class Tokenizer:
    """Fixed-vocabulary tokenizer that maps events to integer token ids.

    Ids 0, 1 and 2 are reserved for the [PAD], [SOS] and [EOS] tokens. After
    those, every event range receives one contiguous run of ids: first the
    ranges in ``event_ranges`` (the output vocabulary), then the ranges in
    ``input_event_ranges`` (tokens that only exist in the input vocabulary).
    """

    __slots__ = [
        "_offset",
        "event_ranges",
        "input_event_ranges",
        "num_classes",
        "num_diff_classes",
        "max_difficulty",
        "event_range",
        "event_start",
        "event_end",
        "vocab_size_out",
        "vocab_size_in",
        "beatmap_idx",
    ]

    # Attributes captured by state_dict() and restored by load_state_dict().
    # ``_offset`` is not part of the state; it is always set by __init__.
    _STATE_KEYS = (
        "event_ranges",
        "input_event_ranges",
        "num_classes",
        "num_diff_classes",
        "max_difficulty",
        "event_range",
        "event_start",
        "event_end",
        "vocab_size_out",
        "vocab_size_in",
        "beatmap_idx",
    )

    def __init__(self, args: "DictConfig" = None):
        """Build the vocabulary layout.

        Args:
            args: Optional configuration. When given, the TIME_SHIFT range is
                derived from ``args.data.src_seq_len`` and the spectrogram
                settings under ``args.model.spectrogram``, the STYLE and
                DIFFICULTY input ranges are enabled by
                ``args.data.style_token_index`` / ``args.data.diff_token_index``,
                and the beatmap index is loaded. When ``None``, a default
                TIME_SHIFT range of [-512, 512] is used, there are no input
                ranges and all class counts are 0.
        """
        # Ids below this value are the reserved [PAD], [SOS], [EOS] tokens.
        self._offset = 3
        self.beatmap_idx: dict[int, int] = {}

        if args is not None:
            data = args.data
            spectrogram = args.model.spectrogram
            miliseconds_per_sequence = (
                (data.src_seq_len - 1) * spectrogram.hop_length
                * MILISECONDS_PER_SECOND / spectrogram.sample_rate
            )
            max_time_shift = int(miliseconds_per_sequence / MILISECONDS_PER_STEP)
            # Negative time shifts are only part of the vocabulary when
            # pre-tokens are enabled in the config.
            allow_negative_shift = data.add_pre_tokens or data.add_pre_tokens_at_step >= 0
            min_time_shift = -max_time_shift if allow_negative_shift else 0
            self.event_ranges = [EventRange(EventType.TIME_SHIFT, min_time_shift, max_time_shift)]

            self.input_event_ranges: list[EventRange] = []
            # Both input ranges are inclusive of the class count itself; that
            # extra top value is what style_unk / diff_unk encode.
            if data.style_token_index >= 0:
                self.input_event_ranges.append(EventRange(EventType.STYLE, 0, data.num_classes))
            if data.diff_token_index >= 0:
                self.input_event_ranges.append(EventRange(EventType.DIFFICULTY, 0, data.num_diff_classes))

            self.num_classes = data.num_classes
            self.num_diff_classes = data.num_diff_classes
            self.max_difficulty = data.max_diff

            self._init_beatmap_idx(args)
        else:
            self.event_ranges = [EventRange(EventType.TIME_SHIFT, -512, 512)]
            self.input_event_ranges = []
            self.num_classes = 0
            self.num_diff_classes = 0
            self.max_difficulty = 0

        self.event_ranges: list[EventRange] = self.event_ranges + [
            EventRange(EventType.DISTANCE, 0, 640),
            EventRange(EventType.NEW_COMBO, 0, 0),
            EventRange(EventType.CIRCLE, 0, 0),
            EventRange(EventType.SPINNER, 0, 0),
            EventRange(EventType.SPINNER_END, 0, 0),
            EventRange(EventType.SLIDER_HEAD, 0, 0),
            EventRange(EventType.BEZIER_ANCHOR, 0, 0),
            EventRange(EventType.PERFECT_ANCHOR, 0, 0),
            EventRange(EventType.CATMULL_ANCHOR, 0, 0),
            EventRange(EventType.RED_ANCHOR, 0, 0),
            EventRange(EventType.LAST_ANCHOR, 0, 0),
            EventRange(EventType.SLIDER_END, 0, 0),
        ]

        # Lookup of the range for every known event type (output, then input).
        self.event_range: dict[EventType, EventRange] = {
            er.type: er for er in (*self.event_ranges, *self.input_event_ranges)
        }

        # Lay the ranges out back to back. ``event_start`` is the first id of
        # a range and ``event_end`` is one past its last id.
        self.event_start: dict[EventType, int] = {}
        self.event_end: dict[EventType, int] = {}
        offset = self._offset
        for er in self.event_ranges:
            self.event_start[er.type] = offset
            offset += er.max_value - er.min_value + 1
            self.event_end[er.type] = offset
        # The output vocabulary ends where the output ranges end.
        self.vocab_size_out: int = offset

        for er in self.input_event_ranges:
            self.event_start[er.type] = offset
            offset += er.max_value - er.min_value + 1
            self.event_end[er.type] = offset
        # The input vocabulary additionally contains the input-only ranges.
        self.vocab_size_in: int = offset

    @property
    def pad_id(self) -> int:
        """[PAD] token for padding."""
        return 0

    @property
    def sos_id(self) -> int:
        """[SOS] token for start-of-sequence."""
        return 1

    @property
    def eos_id(self) -> int:
        """[EOS] token for end-of-sequence."""
        return 2

    def decode(self, token_id: int) -> "Event":
        """Convert a token id into an Event object.

        Raises:
            ValueError: If ``token_id`` does not fall inside any event range
                (this includes the reserved ids 0-2).
        """
        offset = self._offset
        # Walk the ranges in the same order the ids were assigned in.
        for er in (*self.event_ranges, *self.input_event_ranges):
            span = er.max_value - er.min_value
            if offset <= token_id <= offset + span:
                return Event(type=er.type, value=er.min_value + token_id - offset)
            offset += span + 1

        raise ValueError(f"id {token_id} is not mapped to any event")

    def encode(self, event: "Event") -> int:
        """Convert an Event object into a token id.

        Raises:
            ValueError: If the event type is unknown to this tokenizer or the
                event value lies outside the range of its type.
        """
        if event.type not in self.event_range:
            raise ValueError(f"unknown event type: {event.type}")

        er = self.event_range[event.type]
        offset = self.event_start[event.type]

        if not er.min_value <= event.value <= er.max_value:
            raise ValueError(
                f"event value {event.value} in {event} is not within range "
                f"[{er.min_value}, {er.max_value}] for event type {event.type}"
            )

        return offset + event.value - er.min_value

    def event_type_range(self, event_type: "EventType") -> tuple[int, int]:
        """Return the inclusive ``(first, last)`` token ids of an event type.

        Raises:
            ValueError: If the event type is unknown to this tokenizer.
        """
        if event_type not in self.event_range:
            raise ValueError(f"unknown event type: {event_type}")

        er = self.event_range[event_type]
        offset = self.event_start[event_type]
        return offset, offset + (er.max_value - er.min_value)

    def _quantize_difficulty(self, diff: float) -> int:
        """Map a difficulty value to a class index clamped to [0, num_diff_classes - 1]."""
        scaled = int(diff * self.num_diff_classes / self.max_difficulty)
        # Clamp with built-ins so the result is a plain ``int`` rather than a
        # NumPy scalar that would leak into Event values and token ids.
        return min(max(scaled, 0), self.num_diff_classes - 1)

    def encode_diff_event(self, diff: float) -> "Event":
        """Convert a difficulty value into a DIFFICULTY event.

        The value is scaled by ``num_diff_classes / max_difficulty``,
        truncated to an integer and clamped to ``[0, num_diff_classes - 1]``.
        """
        return Event(type=EventType.DIFFICULTY, value=self._quantize_difficulty(diff))

    def encode_diff(self, diff: float) -> int:
        """Convert a difficulty value into a token id."""
        return self.encode(self.encode_diff_event(diff))

    @property
    def diff_unk(self) -> int:
        """Token id of the unknown-difficulty value (``num_diff_classes``)."""
        return self.encode(Event(type=EventType.DIFFICULTY, value=self.num_diff_classes))

    def encode_style_event(self, beatmap_id: int) -> "Event":
        """Convert a beatmap id into a STYLE event.

        Beatmap ids missing from ``beatmap_idx`` map to ``num_classes``, the
        same value ``style_unk`` encodes.
        """
        style_idx = self.beatmap_idx.get(beatmap_id, self.num_classes)
        return Event(type=EventType.STYLE, value=style_idx)

    def encode_style(self, beatmap_id: int) -> int:
        """Convert a beatmap id into a token id."""
        return self.encode(self.encode_style_event(beatmap_id))

    def encode_style_idx(self, beatmap_idx: int) -> int:
        """Convert a beatmap index into a token id."""
        return self.encode(Event(type=EventType.STYLE, value=beatmap_idx))

    @property
    def style_unk(self) -> int:
        """Token id of the unknown-style value (``num_classes``)."""
        return self.encode(Event(type=EventType.STYLE, value=self.num_classes))

    def _init_beatmap_idx(self, args: "DictConfig") -> None:
        """Load the beatmap-id -> index mapping, building and caching it if needed.

        Does nothing when ``args`` is ``None`` or ``args.data`` has no
        ``train_dataset_path``. Otherwise the mapping is read from
        ``beatmap_idx.pickle`` inside the dataset directory when that file
        exists; if not, it is collected from the ``metadata.json`` of every
        sub-directory and written to that pickle file.
        """
        if args is None or "train_dataset_path" not in args.data:
            return

        path = Path(args.data.train_dataset_path)
        cache_path = path / "beatmap_idx.pickle"

        if cache_path.exists():
            # NOTE(security): unpickling can execute arbitrary code. Only
            # point train_dataset_path at directories you trust.
            with open(cache_path, "rb") as f:
                self.beatmap_idx = pickle.load(f)
            return

        print("Caching beatmap index...")

        for track in tqdm(path.iterdir()):
            if not track.is_dir():
                continue
            metadata_file = track / "metadata.json"
            # Decode explicitly as UTF-8 instead of the platform default.
            with open(metadata_file, encoding="utf-8") as f:
                metadata = json.load(f)
            for beatmap_metadata in metadata["Beatmaps"].values():
                self.beatmap_idx[beatmap_metadata["BeatmapId"]] = beatmap_metadata["Index"]

        with open(cache_path, "wb") as f:
            pickle.dump(self.beatmap_idx, f)

    def state_dict(self) -> dict:
        """Return the tokenizer state as a dict keyed by attribute name.

        The values are the live attribute objects, not copies.
        """
        return {key: getattr(self, key) for key in self._STATE_KEYS}

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the tokenizer state produced by ``state_dict``.

        Raises:
            KeyError: If ``state_dict`` lacks one of the expected keys.
        """
        for key in self._STATE_KEYS:
            setattr(self, key, state_dict[key])
|
| |
|