# fourmansyah's picture
# Duplicate from hongminh54/BeatHeritage-v1
# 12a8e0f
import json
import pickle
from pathlib import Path
import numpy as np
from tqdm import tqdm
from osu_diffusion.config import DiffusionTrainConfig
class Tokenizer:
    """Fixed-vocabulary tokenizer for beatmap conditioning classes.

    Token ids are laid out as consecutive groups, in this order: style
    (beatmap), difficulty, mapper, descriptor, circle size. The last id of
    every group is reserved for the "unknown" value of that group (see the
    ``*_unk`` properties).
    """

    __slots__ = [
        "num_classes",
        "num_diff_classes",
        "max_difficulty",
        "beatmap_idx",
        "mapper_idx",
        "beatmap_mapper",
        "num_mapper_classes",
        "beatmap_descriptors",
        "descriptor_idx",
        "num_descriptor_classes",
        "num_cs_classes",
    ]

    def __init__(self, args: "DiffusionTrainConfig | None" = None):
        """Creates an empty tokenizer and optionally fills it from a train config.

        Args:
            args: Training configuration. When ``None`` every class count stays
                at 0 and the state is expected to come from ``load_state_dict``.
                Otherwise each ``args.data.*_class`` flag enables its group.
        """
        self.beatmap_idx: dict[int, int] = {}  # beatmap_id -> beatmap_idx
        self.num_classes = 0
        self.num_diff_classes = 0
        self.max_difficulty = 0
        self.beatmap_mapper: dict[int, int] = {}  # beatmap_id -> mapper_id
        self.mapper_idx: dict[int, int] = {}  # mapper_id -> mapper_idx
        self.num_mapper_classes = 0
        self.beatmap_descriptors: dict[int, list[int]] = {}  # beatmap_id -> [descriptor_idx]
        self.descriptor_idx: dict[str, int] = {}  # descriptor_name -> descriptor_idx
        self.num_descriptor_classes = 0
        self.num_cs_classes = 0

        if args is None:
            return

        if args.data.beatmap_class:
            self._init_beatmap_idx(args)

        if args.data.difficulty_class:
            self.num_diff_classes = args.data.num_diff_classes
            self.max_difficulty = args.data.max_diff

        if args.data.mapper_class:
            self._init_mapper_idx(args)

        if args.data.descriptor_class:
            self._init_descriptor_idx(args)

        if args.data.circle_size_class:
            self.num_cs_classes = args.data.num_cs_classes

    def encode_style(self, beatmap_id: int) -> int:
        """Converts beatmap id into token id (``style_unk`` for unseen ids)."""
        return self.beatmap_idx.get(beatmap_id, self.num_classes - 1)

    @property
    def style_unk(self) -> int:
        """Gets the unknown style value token id."""
        return self.num_classes - 1

    def encode_diff(self, diff: float) -> int:
        """Converts difficulty value into token id.

        The value is scaled linearly so that ``max_difficulty`` lands on the
        last regular bucket (``num_diff_classes - 2``) and is clamped to the
        regular bucket range; the final id of the group stays reserved for
        ``diff_unk``.

        Raises:
            ZeroDivisionError: If ``max_difficulty`` is 0.
        """
        last_bucket = self.num_diff_classes - 2
        bucket = int(diff * last_bucket / self.max_difficulty)
        # np.clip returns a NumPy scalar; cast so the declared ``int`` is honoured.
        return self.num_classes + int(np.clip(bucket, 0, last_bucket))

    @property
    def diff_unk(self) -> int:
        """Gets the unknown difficulty value token id."""
        return self.num_classes + self.num_diff_classes - 1

    def encode_mapper(self, beatmap_id: int) -> int:
        """Converts beatmap id into the token id of its mapper."""
        # -1 is used as a user id that is never present in mapper_idx lookups
        # built from real data, so unknown beatmaps fall through to mapper_unk.
        user_id = self.beatmap_mapper.get(beatmap_id, -1)
        return self.encode_mapper_id(user_id)

    def encode_mapper_id(self, user_id: int) -> int:
        """Converts user id into token id (``mapper_unk`` for unseen users)."""
        mapper_idx = self.mapper_idx.get(user_id, self.num_mapper_classes - 1)
        return self.num_classes + self.num_diff_classes + mapper_idx

    @property
    def mapper_unk(self) -> int:
        """Gets the unknown mapper value token id."""
        return self.num_classes + self.num_diff_classes + self.num_mapper_classes - 1

    def encode_descriptor(self, beatmap_id: int) -> list[int]:
        """Converts beatmap id into the token ids of all its descriptors.

        Beatmaps without descriptor data yield ``[descriptor_unk]``.
        """
        unknown = [self.num_descriptor_classes - 1]
        return [
            self.encode_descriptor_idx(descriptor_idx)
            for descriptor_idx in self.beatmap_descriptors.get(beatmap_id, unknown)
        ]

    def encode_descriptor_name(self, descriptor: str) -> int:
        """Converts descriptor name into token id (``descriptor_unk`` if unseen)."""
        # The unknown slot is the LAST index of the group (num - 1); using
        # num_descriptor_classes itself would spill into the circle-size group.
        descriptor_idx = self.descriptor_idx.get(descriptor, self.num_descriptor_classes - 1)
        return self.encode_descriptor_idx(descriptor_idx)

    def encode_descriptor_idx(self, descriptor_idx: int) -> int:
        """Converts descriptor idx into token id."""
        return self.num_classes + self.num_diff_classes + self.num_mapper_classes + descriptor_idx

    @property
    def descriptor_unk(self) -> int:
        """Gets the unknown descriptor value token id."""
        return (self.num_classes + self.num_diff_classes + self.num_mapper_classes
                + self.num_descriptor_classes - 1)

    def encode_cs(self, cs: float) -> int:
        """Converts circle size value into token id.

        The value is scaled so that a circle size of 10 lands on the last
        regular bucket (``num_cs_classes - 2``) and is clamped to the regular
        bucket range.
        """
        last_bucket = self.num_cs_classes - 2
        bucket = int(cs * last_bucket / 10)
        offset = (self.num_classes + self.num_diff_classes + self.num_mapper_classes
                  + self.num_descriptor_classes)
        # np.clip returns a NumPy scalar; cast so the declared ``int`` is honoured.
        return offset + int(np.clip(bucket, 0, last_bucket))

    @property
    def cs_unk(self) -> int:
        """Gets the unknown circle size value token id."""
        return (self.num_classes + self.num_diff_classes + self.num_mapper_classes
                + self.num_descriptor_classes + self.num_cs_classes - 1)

    @property
    def num_tokens(self) -> int:
        """Gets the number of tokens."""
        return (self.num_classes + self.num_diff_classes + self.num_mapper_classes
                + self.num_descriptor_classes + self.num_cs_classes)

    def _init_beatmap_idx(self, args: "DiffusionTrainConfig") -> None:
        """Initializes and caches the beatmap index.

        Loads ``beatmap_idx.pickle`` from the training dataset directory when
        it exists; otherwise scans every track directory's ``metadata.json``
        and writes the cache. Does nothing when no ``train_dataset_path`` is
        configured.
        """
        if args is None or "train_dataset_path" not in args.data:
            return

        path = Path(args.data.train_dataset_path)
        cache_path = path / "beatmap_idx.pickle"

        if cache_path.exists():
            # NOTE(security): pickle.load can execute arbitrary code. Only use
            # dataset directories whose cache file comes from a trusted source.
            with open(cache_path, "rb") as f:
                self.beatmap_idx = pickle.load(f)
            # +2: indices are 0-based (+1) and one extra slot is the unknown
            # class (+1). default=-1 keeps an empty cache from raising.
            self.num_classes = max(self.beatmap_idx.values(), default=-1) + 2
            return

        print("Caching beatmap index...")
        highest_index = -1
        for track in tqdm(path.iterdir()):
            if not track.is_dir():
                continue
            with open(track / "metadata.json", encoding="utf-8") as f:
                metadata = json.load(f)
            for beatmap_metadata in metadata["Beatmaps"].values():
                index = beatmap_metadata["Index"]
                self.beatmap_idx[beatmap_metadata["BeatmapId"]] = index
                highest_index = max(highest_index, index)

        self.num_classes = highest_index + 2

        with open(cache_path, "wb") as f:
            pickle.dump(self.beatmap_idx, f)

    def _init_mapper_idx(self, args: "DiffusionTrainConfig") -> None:
        """Indexes beatmap mappers and mapper idx.

        Reads a JSON list of objects with ``id`` (beatmap id) and ``user_id``
        keys from ``args.data.mappers_path``.

        Raises:
            ValueError: If ``mappers_path`` is not configured or does not exist.
        """
        if args is None or "mappers_path" not in args.data:
            raise ValueError("mappers_path not found in args")

        path = Path(args.data.mappers_path)
        if not path.exists():
            raise ValueError(f"mappers_path {path} not found")

        with open(path, "r", encoding="utf-8") as file:
            data = json.load(file)

        for item in data:
            self.beatmap_mapper[item["id"]] = item["user_id"]

        # Mapper indices follow the iteration order of the set of user ids.
        unique_user_ids = list(set(self.beatmap_mapper.values()))
        self.mapper_idx = {user_id: idx for idx, user_id in enumerate(unique_user_ids)}
        # One extra class is reserved for the unknown mapper.
        self.num_mapper_classes = len(unique_user_ids) + 1

    def _init_descriptor_idx(self, args: "DiffusionTrainConfig") -> None:
        """Indexes beatmap descriptors and descriptor idx.

        The descriptors file is a CSV file with one ``beatmap_id,descriptor_name``
        pair per line. Blank lines are skipped, and only the first comma
        separates the two fields.

        Raises:
            ValueError: If ``descriptors_path`` is not configured, does not
                exist, or a non-blank line is malformed.
        """
        if args is None or "descriptors_path" not in args.data:
            raise ValueError("descriptors_path not found in args")

        path = Path(args.data.descriptors_path)
        if not path.exists():
            raise ValueError(f"descriptors_path {path} not found")

        with open(path, "r", encoding="utf-8") as file:
            for line_number, raw_line in enumerate(file, start=1):
                line = raw_line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines

                beatmap_id_str, separator, descriptor_name = line.partition(",")
                if not separator:
                    raise ValueError(
                        f"descriptors_path {path} line {line_number}: "
                        f"expected 'beatmap_id,descriptor_name', got {line!r}"
                    )
                beatmap_id = int(beatmap_id_str)

                # Descriptor indices are assigned in order of first appearance.
                descriptor_idx = self.descriptor_idx.setdefault(
                    descriptor_name, len(self.descriptor_idx)
                )
                self.beatmap_descriptors.setdefault(beatmap_id, []).append(descriptor_idx)

        # One extra class is reserved for the unknown descriptor.
        self.num_descriptor_classes = len(self.descriptor_idx) + 1

    def state_dict(self):
        """Returns the complete tokenizer state as a plain dict."""
        return {
            "beatmap_idx": self.beatmap_idx,
            "num_classes": self.num_classes,
            "num_diff_classes": self.num_diff_classes,
            "max_difficulty": self.max_difficulty,
            "beatmap_mapper": self.beatmap_mapper,
            "mapper_idx": self.mapper_idx,
            "num_mapper_classes": self.num_mapper_classes,
            "beatmap_descriptors": self.beatmap_descriptors,
            "descriptor_idx": self.descriptor_idx,
            "num_descriptor_classes": self.num_descriptor_classes,
            "num_cs_classes": self.num_cs_classes,
        }

    def load_state_dict(self, state_dict):
        """Restores tokenizer state produced by ``state_dict``.

        The style and difficulty keys are required. The mapper, descriptor and
        circle-size keys are optional; attributes whose key is absent keep
        their current value.
        """
        self.beatmap_idx = state_dict["beatmap_idx"]
        self.num_classes = state_dict["num_classes"]
        self.num_diff_classes = state_dict["num_diff_classes"]
        self.max_difficulty = state_dict["max_difficulty"]

        optional_keys = (
            "beatmap_mapper",
            "mapper_idx",
            "num_mapper_classes",
            "beatmap_descriptors",
            "descriptor_idx",
            "num_descriptor_classes",
            "num_cs_classes",
        )
        for key in optional_keys:
            if key in state_dict:
                setattr(self, key, state_dict[key])