# osu_mapper2/osuT5/tokenizer/tokenizer.py
# Uploaded by Tiger14n via huggingface_hub (commit 7ef7abb, verified)
import json
import pickle
from pathlib import Path
import numpy as np
from omegaconf import DictConfig
from tqdm import tqdm
from .event import Event, EventType, EventRange
MILISECONDS_PER_SECOND = 1000
MILISECONDS_PER_STEP = 10
class Tokenizer:
    """Fixed-vocabulary tokenizer mapping osu! beatmap ``Event`` objects to token ids.

    Token-id layout: ``[PAD]=0``, ``[SOS]=1``, ``[EOS]=2``, then one contiguous run
    of ids per ``EventRange`` in ``event_ranges`` (the decoder/output vocabulary),
    followed by the ranges in ``input_event_ranges`` (input-only STYLE/DIFFICULTY
    tokens, present only when enabled in the config).
    """

    __slots__ = [
        "_offset",
        "event_ranges",
        "input_event_ranges",
        "num_classes",
        "num_diff_classes",
        "max_difficulty",
        "event_range",
        "event_start",
        "event_end",
        "vocab_size_out",
        "vocab_size_in",
        "beatmap_idx",
    ]

    def __init__(self, args: DictConfig = None):
        """Fixed vocabulary tokenizer.

        Args:
            args: Optional training config. When given, the TIME_SHIFT range is
                derived from the spectrogram sequence length, optional STYLE /
                DIFFICULTY input token ranges are added, and the beatmap index
                is loaded (or built and cached) from
                ``args.data.train_dataset_path``. When ``None``, a default
                TIME_SHIFT range of [-512, 512] is used and no input tokens exist.
        """
        # Ids 0-2 are reserved for the special [PAD]/[SOS]/[EOS] tokens.
        self._offset = 3
        self.beatmap_idx: dict[int, int] = {}

        if args is not None:
            # Audio duration (in ms) covered by one model input sequence.
            miliseconds_per_sequence = ((args.data.src_seq_len - 1) * args.model.spectrogram.hop_length *
                                        MILISECONDS_PER_SECOND / args.model.spectrogram.sample_rate)
            max_time_shift = int(miliseconds_per_sequence / MILISECONDS_PER_STEP)
            # Negative time shifts are only needed when pre-tokens are enabled.
            min_time_shift = -max_time_shift if args.data.add_pre_tokens or args.data.add_pre_tokens_at_step >= 0 else 0
            self.event_ranges = [EventRange(EventType.TIME_SHIFT, min_time_shift, max_time_shift)]
            self.input_event_ranges: list[EventRange] = []
            if args.data.style_token_index >= 0:
                # Range max == num_classes: the extra top value is the "unknown style" sentinel.
                self.input_event_ranges.append(EventRange(EventType.STYLE, 0, args.data.num_classes))
            if args.data.diff_token_index >= 0:
                # Range max == num_diff_classes: the extra top value is the "unknown difficulty" sentinel.
                self.input_event_ranges.append(EventRange(EventType.DIFFICULTY, 0, args.data.num_diff_classes))
            self.num_classes = args.data.num_classes
            self.num_diff_classes = args.data.num_diff_classes
            self.max_difficulty = args.data.max_diff
            self._init_beatmap_idx(args)
        else:
            self.event_ranges = [EventRange(EventType.TIME_SHIFT, -512, 512)]
            self.input_event_ranges = []
            self.num_classes = 0
            self.num_diff_classes = 0
            self.max_difficulty = 0

        # Output (decoder) event types shared by both configurations.
        self.event_ranges: list[EventRange] = self.event_ranges + [
            EventRange(EventType.DISTANCE, 0, 640),
            EventRange(EventType.NEW_COMBO, 0, 0),
            EventRange(EventType.CIRCLE, 0, 0),
            EventRange(EventType.SPINNER, 0, 0),
            EventRange(EventType.SPINNER_END, 0, 0),
            EventRange(EventType.SLIDER_HEAD, 0, 0),
            EventRange(EventType.BEZIER_ANCHOR, 0, 0),
            EventRange(EventType.PERFECT_ANCHOR, 0, 0),
            EventRange(EventType.CATMULL_ANCHOR, 0, 0),
            EventRange(EventType.RED_ANCHOR, 0, 0),
            EventRange(EventType.LAST_ANCHOR, 0, 0),
            EventRange(EventType.SLIDER_END, 0, 0),
        ]

        # Lookup from event type to its range, over both output and input ranges.
        self.event_range: dict[EventType, EventRange] = {er.type: er for er in self.event_ranges} | {er.type: er for er in self.input_event_ranges}

        # Precompute the [start, end) token-id interval for each event type;
        # output ranges come first so input-only tokens sit above vocab_size_out.
        self.event_start: dict[EventType, int] = {}
        self.event_end: dict[EventType, int] = {}
        offset = self._offset
        for er in self.event_ranges:
            self.event_start[er.type] = offset
            offset += er.max_value - er.min_value + 1
            self.event_end[er.type] = offset
        for er in self.input_event_ranges:
            self.event_start[er.type] = offset
            offset += er.max_value - er.min_value + 1
            self.event_end[er.type] = offset

        # Decoder vocabulary excludes the input-only (STYLE/DIFFICULTY) tokens.
        self.vocab_size_out: int = self._offset + sum(
            er.max_value - er.min_value + 1 for er in self.event_ranges
        )
        self.vocab_size_in: int = self.vocab_size_out + sum(
            er.max_value - er.min_value + 1 for er in self.input_event_ranges
        )

    @property
    def pad_id(self) -> int:
        """[PAD] token for padding."""
        return 0

    @property
    def sos_id(self) -> int:
        """[SOS] token for start-of-sequence."""
        return 1

    @property
    def eos_id(self) -> int:
        """[EOS] token for end-of-sequence."""
        return 2

    def decode(self, token_id: int) -> Event:
        """Converts token ids into Event objects.

        Raises:
            ValueError: If ``token_id`` falls outside every event range
                (including the three special-token ids below ``self._offset``).
        """
        offset = self._offset
        for er in self.event_ranges:
            if offset <= token_id <= offset + er.max_value - er.min_value:
                return Event(type=er.type, value=er.min_value + token_id - offset)
            offset += er.max_value - er.min_value + 1
        for er in self.input_event_ranges:
            if offset <= token_id <= offset + er.max_value - er.min_value:
                return Event(type=er.type, value=er.min_value + token_id - offset)
            offset += er.max_value - er.min_value + 1
        raise ValueError(f"id {token_id} is not mapped to any event")

    def encode(self, event: Event) -> int:
        """Converts Event objects into token ids.

        Raises:
            ValueError: If the event type is unknown or its value lies outside
                the type's declared [min_value, max_value] range.
        """
        if event.type not in self.event_range:
            raise ValueError(f"unknown event type: {event.type}")

        er = self.event_range[event.type]
        offset = self.event_start[event.type]

        if not er.min_value <= event.value <= er.max_value:
            raise ValueError(
                f"event value {event.value} in {event} is not within range "
                f"[{er.min_value}, {er.max_value}] for event type {event.type}"
            )

        return offset + event.value - er.min_value

    def event_type_range(self, event_type: EventType) -> tuple[int, int]:
        """Get the (first, last) token id of an Event type, both inclusive.

        Raises:
            ValueError: If the event type is not part of the vocabulary.
        """
        if event_type not in self.event_range:
            raise ValueError(f"unknown event type: {event_type}")

        er = self.event_range[event_type]
        offset = self.event_start[event_type]
        return offset, offset + (er.max_value - er.min_value)

    def encode_diff_event(self, diff: float) -> Event:
        """Converts a difficulty value into a DIFFICULTY event.

        The value is bucketed linearly over [0, max_difficulty) into
        num_diff_classes classes and clipped to the valid class range.
        """
        return Event(type=EventType.DIFFICULTY, value=np.clip(
            int(diff * self.num_diff_classes / self.max_difficulty), 0, self.num_diff_classes - 1))

    def encode_diff(self, diff: float) -> int:
        """Converts difficulty value into token id."""
        return self.encode(self.encode_diff_event(diff))

    @property
    def diff_unk(self) -> int:
        """Gets the unknown difficulty value token id (the sentinel at num_diff_classes)."""
        return self.encode(Event(type=EventType.DIFFICULTY, value=self.num_diff_classes))

    def encode_style_event(self, beatmap_id: int) -> Event:
        """Converts beatmap id into style event.

        Unknown beatmap ids map to the "unknown style" sentinel (num_classes).
        """
        style_idx = self.beatmap_idx.get(beatmap_id, self.num_classes)
        return Event(type=EventType.STYLE, value=style_idx)

    def encode_style(self, beatmap_id: int) -> int:
        """Converts beatmap id into token id."""
        return self.encode(self.encode_style_event(beatmap_id))

    def encode_style_idx(self, beatmap_idx: int) -> int:
        """Converts beatmap idx into token id."""
        return self.encode(Event(type=EventType.STYLE, value=beatmap_idx))

    @property
    def style_unk(self) -> int:
        """Gets the unknown style value token id (the sentinel at num_classes)."""
        return self.encode(Event(type=EventType.STYLE, value=self.num_classes))

    def _init_beatmap_idx(self, args: DictConfig) -> None:
        """Initializes the beatmap-id -> class-index mapping, caching it as a pickle.

        Loads ``beatmap_idx.pickle`` from the train dataset directory when present;
        otherwise scans every track directory's ``metadata.json`` and writes the cache.
        """
        if args is None or "train_dataset_path" not in args.data:
            return

        path = Path(args.data.train_dataset_path)
        cache_path = path / "beatmap_idx.pickle"

        if cache_path.exists():
            # Reuse cache_path instead of rebuilding the same path literal.
            with open(cache_path, "rb") as f:
                self.beatmap_idx = pickle.load(f)
            return

        print("Caching beatmap index...")

        for track in tqdm(path.iterdir()):
            if not track.is_dir():
                continue
            metadata_file = track / "metadata.json"
            # Explicit encoding: JSON is UTF-8 by spec; don't rely on the platform default.
            with open(metadata_file, encoding="utf-8") as f:
                metadata = json.load(f)
            for beatmap_metadata in metadata["Beatmaps"].values():
                self.beatmap_idx[beatmap_metadata["BeatmapId"]] = beatmap_metadata["Index"]

        with open(cache_path, "wb") as f:
            pickle.dump(self.beatmap_idx, f)

    def state_dict(self):
        """Returns the tokenizer's full state as a plain dict (for checkpointing)."""
        return {
            "event_ranges": self.event_ranges,
            "input_event_ranges": self.input_event_ranges,
            "num_classes": self.num_classes,
            "num_diff_classes": self.num_diff_classes,
            "max_difficulty": self.max_difficulty,
            "event_range": self.event_range,
            "event_start": self.event_start,
            "event_end": self.event_end,
            "vocab_size_out": self.vocab_size_out,
            "vocab_size_in": self.vocab_size_in,
            "beatmap_idx": self.beatmap_idx,
        }

    def load_state_dict(self, state_dict):
        """Restores the tokenizer's state from a dict produced by ``state_dict``."""
        self.event_ranges = state_dict["event_ranges"]
        self.input_event_ranges = state_dict["input_event_ranges"]
        self.num_classes = state_dict["num_classes"]
        self.num_diff_classes = state_dict["num_diff_classes"]
        self.max_difficulty = state_dict["max_difficulty"]
        self.event_range = state_dict["event_range"]
        self.event_start = state_dict["event_start"]
        self.event_end = state_dict["event_end"]
        self.vocab_size_out = state_dict["vocab_size_out"]
        self.vocab_size_in = state_dict["vocab_size_in"]
        self.beatmap_idx = state_dict["beatmap_idx"]