| import json |
| from pathlib import Path |
|
|
| from omegaconf import DictConfig |
|
|
| from .event import Event, EventType, EventRange |
|
|
| MILISECONDS_PER_SECOND = 1000 |
| MILISECONDS_PER_STEP = 10 |
|
|
|
|
| class Tokenizer: |
| __slots__ = [ |
| "offset", |
| "event_ranges", |
| "input_event_ranges", |
| "num_classes", |
| "num_diff_classes", |
| "max_difficulty", |
| "event_range", |
| "event_start", |
| "event_end", |
| "vocab_size", |
| "beatmap_idx", |
| "mapper_idx", |
| "beatmap_mapper", |
| "num_mapper_classes", |
| "beatmap_descriptors", |
| "descriptor_idx", |
| "num_descriptor_classes", |
| "num_cs_classes", |
| ] |
|
|
| def __init__(self, args: DictConfig = None): |
| """Fixed vocabulary tokenizer.""" |
| self.offset = 1 |
| self.event_ranges: list[EventRange] = [ |
| EventRange(EventType.TIME_SHIFT, 0, 1024), |
| EventRange(EventType.SNAPPING, 0, 16), |
| EventRange(EventType.DISTANCE, 0, 640), |
| ] |
| self.num_classes = 0 |
| self.beatmap_mapper: dict[int, int] = {} |
| self.mapper_idx: dict[int, int] = {} |
|
|
| if args is not None: |
| miliseconds_per_sequence = ((args.data.src_seq_len - 1) * args.model.spectrogram.hop_length * |
| MILISECONDS_PER_SECOND / args.model.spectrogram.sample_rate) |
| max_time_shift = int(miliseconds_per_sequence / MILISECONDS_PER_STEP) |
| min_time_shift = 0 |
|
|
| self.event_ranges = [ |
| EventRange(EventType.TIME_SHIFT, min_time_shift, max_time_shift), |
| EventRange(EventType.SNAPPING, 0, 16), |
| ] |
|
|
| self._init_mapper_idx(args) |
|
|
| if args.data.add_distances: |
| self.event_ranges.append(EventRange(EventType.DISTANCE, 0, 640)) |
|
|
| if args.data.add_positions: |
| p = args.data.position_precision |
| x_min, x_max, y_min, y_max = args.data.position_range |
| x_min, x_max, y_min, y_max = x_min // p, x_max // p, y_min // p, y_max // p |
|
|
| if args.data.position_split_axes: |
| self.event_ranges.append(EventRange(EventType.POS_X, x_min, x_max)) |
| self.event_ranges.append(EventRange(EventType.POS_Y, y_min, y_max)) |
| else: |
| x_count = x_max - x_min + 1 |
| y_count = y_max - y_min + 1 |
| self.event_ranges.append(EventRange(EventType.POS, 0, x_count * y_count - 1)) |
|
|
| self.event_ranges: list[EventRange] = self.event_ranges + [ |
| EventRange(EventType.NEW_COMBO, 0, 0), |
| EventRange(EventType.HITSOUND, 0, 2 ** 3 * 3 * 3), |
| EventRange(EventType.VOLUME, 0, 100), |
| EventRange(EventType.CIRCLE, 0, 0), |
| EventRange(EventType.SPINNER, 0, 0), |
| EventRange(EventType.SPINNER_END, 0, 0), |
| EventRange(EventType.SLIDER_HEAD, 0, 0), |
| EventRange(EventType.BEZIER_ANCHOR, 0, 0), |
| EventRange(EventType.PERFECT_ANCHOR, 0, 0), |
| EventRange(EventType.CATMULL_ANCHOR, 0, 0), |
| EventRange(EventType.RED_ANCHOR, 0, 0), |
| EventRange(EventType.LAST_ANCHOR, 0, 0), |
| EventRange(EventType.SLIDER_END, 0, 0), |
| EventRange(EventType.BEAT, 0, 0), |
| EventRange(EventType.MEASURE, 0, 0), |
| ] |
|
|
| if args is not None and args.data.add_timing_points: |
| self.event_ranges.append(EventRange(EventType.TIMING_POINT, 0, 0)) |
|
|
| self.event_range: dict[EventType, EventRange] = {er.type: er for er in self.event_ranges} |
|
|
| self.event_start: dict[EventType, int] = {} |
| self.event_end: dict[EventType, int] = {} |
| offset = self.offset |
| for er in self.event_ranges: |
| self.event_start[er.type] = offset |
| offset += er.max_value - er.min_value + 1 |
| self.event_end[er.type] = offset |
|
|
| self.vocab_size: int = self.offset + sum( |
| er.max_value - er.min_value + 1 for er in self.event_ranges |
| ) |
|
|
| @property |
| def pad_id(self) -> int: |
| """[PAD] token for padding.""" |
| return 0 |
|
|
| def decode(self, token_id: int) -> Event: |
| """Converts token ids into Event objects.""" |
| offset = self.offset |
| for er in self.event_ranges: |
| if offset <= token_id <= offset + er.max_value - er.min_value: |
| return Event(type=er.type, value=er.min_value + token_id - offset) |
| offset += er.max_value - er.min_value + 1 |
| for er in self.input_event_ranges: |
| if offset <= token_id <= offset + er.max_value - er.min_value: |
| return Event(type=er.type, value=er.min_value + token_id - offset) |
| offset += er.max_value - er.min_value + 1 |
|
|
| raise ValueError(f"id {token_id} is not mapped to any event") |
|
|
| def encode(self, event: Event) -> int: |
| """Converts Event objects into token ids.""" |
| if event.type not in self.event_range: |
| raise ValueError(f"unknown event type: {event.type}") |
|
|
| er = self.event_range[event.type] |
| offset = self.event_start[event.type] |
|
|
| if not er.min_value <= event.value <= er.max_value: |
| raise ValueError( |
| f"event value {event.value} is not within range " |
| f"[{er.min_value}, {er.max_value}] for event type {event.type}" |
| ) |
|
|
| return offset + event.value - er.min_value |
|
|
| def event_type_range(self, event_type: EventType) -> tuple[int, int]: |
| """Get the token id range of each Event type.""" |
| if event_type not in self.event_range: |
| raise ValueError(f"unknown event type: {event_type}") |
|
|
| er = self.event_range[event_type] |
| offset = self.event_start[event_type] |
| return offset, offset + (er.max_value - er.min_value) |
|
|
| def _init_mapper_idx(self, args): |
| """"Indexes beatmap mappers and mapper idx.""" |
| if args is None or "mappers_path" not in args.data: |
| raise ValueError("mappers_path not found in args") |
|
|
| path = Path(args.data.mappers_path) |
|
|
| if not path.exists(): |
| raise ValueError(f"mappers_path {path} not found") |
|
|
| |
| with open(path, 'r') as file: |
| data = json.load(file) |
|
|
| |
| for item in data: |
| self.beatmap_mapper[item['id']] = item['user_id'] |
|
|
| |
| unique_user_ids = list(set(self.beatmap_mapper.values())) |
|
|
| |
| self.mapper_idx = {user_id: idx for idx, user_id in enumerate(unique_user_ids)} |
| self.num_classes = len(unique_user_ids) |
|
|
| def state_dict(self): |
| return { |
| "offset": self.offset, |
| "event_ranges": self.event_ranges, |
| "num_classes": self.num_classes, |
| "event_range": self.event_range, |
| "event_start": self.event_start, |
| "event_end": self.event_end, |
| "vocab_size": self.vocab_size, |
| "beatmap_mapper": self.beatmap_mapper, |
| "mapper_idx": self.mapper_idx, |
| } |
|
|
| def load_state_dict(self, state_dict): |
| self.offset = state_dict["offset"] |
| self.event_ranges = state_dict["event_ranges"] |
| self.num_classes = state_dict["num_classes"] |
| self.event_range = state_dict["event_range"] |
| self.event_start = state_dict["event_start"] |
| self.event_end = state_dict["event_end"] |
| self.vocab_size = state_dict["vocab_size"] |
| self.beatmap_mapper = state_dict["beatmap_mapper"] |
| self.mapper_idx = state_dict["mapper_idx"] |
|
|