# osu_mapper2 / osudiffusion / data_loading.py
# Uploaded by Tiger14n via huggingface_hub ("Upload folder using huggingface_hub"), commit 7ef7abb (verified).
import math
import os.path
import pickle
import random
from collections.abc import Callable
from datetime import timedelta
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Optional
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import IterableDataset
import tqdm
from .positional_embedding import offset_sequence_embedding
from .positional_embedding import position_sequence_embedding
from .positional_embedding import timestep_embedding
from slider import Position
from slider.beatmap import Beatmap
from slider.beatmap import HitObject
from slider.beatmap import Slider
from slider.beatmap import Spinner
from slider.curve import Catmull
from slider.curve import Linear
from slider.curve import MultiBezier
from slider.curve import Perfect
# Width (x) and height (y) of the playfield; positions are mirrored against and
# normalized by these extents elsewhere in this module.
playfield_size = torch.tensor([512, 384])
# Length of one feature row: x, y, time, followed by 16 one-hot event-type slots.
feature_size = 19
def create_datapoint(time: timedelta, pos: Position, datatype: int) -> torch.Tensor:
    """Build one feature row describing a single hit-object event.

    Row layout (length ``feature_size``):
      [0] x position, [1] y position, [2] time in milliseconds,
      [3 + datatype] one-hot flag for the event type (all other slots are 0).

    Args:
        time: event time as a ``timedelta``.
        pos: object exposing ``x`` and ``y`` attributes.
        datatype: event-type code, ``0 <= datatype < feature_size - 3``.

    Returns:
        A 1-D float tensor of length ``feature_size``.

    Raises:
        IndexError: if ``datatype`` is outside the valid range. Negative codes
            are rejected explicitly because tensor indexing would otherwise
            wrap around and silently overwrite the position/time slots.
    """
    type_slots = feature_size - 3
    if not 0 <= datatype < type_slots:
        raise IndexError(
            f"datatype {datatype} is out of range [0, {type_slots}) "
            f"for a feature row of size {feature_size}"
        )
    features = torch.zeros(feature_size)
    features[0] = pos.x
    features[1] = pos.y
    features[2] = time.total_seconds() * 1000
    features[datatype + 3] = 1
    return features
def repeat_type(repeat: int) -> int:
    """Map a slider repeat count to a small class index.

    1, 2 and 3 map to 0, 1 and 2. Larger counts collapse by parity:
    even counts map to 3 and odd counts map to 4. Values below 1 are not
    rejected; they map to ``repeat - 1``.
    """
    if repeat >= 4:
        # Only parity is kept for long repeat chains
        return 3 if repeat % 2 == 0 else 4
    return repeat - 1
def append_control_points(
    datapoints: list[torch.Tensor],
    slider: Slider,
    datatype: int,
    duration: timedelta,
):
    """Extend ``datapoints`` in place with the interior control points of a slider curve.

    The first and last curve points are skipped. With ``n`` curve points, the
    interior point ``i`` is stamped at ``slider.time + i / (n - 1) * duration``,
    spreading the points evenly over ``duration``. Every point gets the same
    ``datatype`` code.
    """
    points = slider.curve.points
    last_index = len(points) - 1
    datapoints.extend(
        create_datapoint(slider.time + i / last_index * duration, points[i], datatype)
        for i in range(1, last_index)
    )
def get_data(hitobj: HitObject) -> torch.Tensor:
    """Encode one hit object as a stack of feature rows.

    Every row comes from ``create_datapoint`` (x, y, time in ms, one-hot type).
    The type codes emitted here are:

    * 0 / 1   - circle-style object (1 when ``new_combo`` is set)
    * 2 / 3   - spinner start / spinner end (both at the spinner's position)
    * 4 / 5   - slider head (5 when ``new_combo`` is set)
    * 6       - MultiBezier control point
    * 7       - Perfect-curve control point
    * 8       - Catmull control point
    * 9       - Linear control point; also a MultiBezier point that equals the
                point after it (its duplicate is skipped)
    * 10      - last curve point, stamped at the end of the first span
    * 11..15  - slider end, ``11 + repeat_type(hitobj.repeat)``

    Sliders whose curve has 100 or more points are not encoded as sliders; they
    fall through to the final circle-style encoding (the reason for the cap of
    100 is not visible here).

    Args:
        hitobj: a hit object from a parsed beatmap.

    Returns:
        A tensor with one row per event: shape ``(rows, 19)`` for sliders and
        spinners, ``(1, 19)`` for everything else.
    """
    if isinstance(hitobj, Slider) and len(hitobj.curve.points) < 100:
        datapoints = [
            create_datapoint(
                hitobj.time,
                hitobj.position,
                5 if hitobj.new_combo else 4,
            ),
        ]
        # NOTE(review): assert is stripped under ``python -O``; a repeat of 0
        # would then divide by zero below.
        assert hitobj.repeat >= 1
        # Duration of a single pass along the slider path.
        duration: timedelta = (hitobj.end_time - hitobj.time) / hitobj.repeat
        if isinstance(hitobj.curve, Linear):
            append_control_points(datapoints, hitobj, 9, duration)
        elif isinstance(hitobj.curve, Catmull):
            append_control_points(datapoints, hitobj, 8, duration)
        elif isinstance(hitobj.curve, Perfect):
            append_control_points(datapoints, hitobj, 7, duration)
        elif isinstance(hitobj.curve, MultiBezier):
            control_point_count = len(hitobj.curve.points)
            for i in range(1, control_point_count - 1):
                time = hitobj.time + i / (control_point_count - 1) * duration
                pos = hitobj.curve.points[i]
                # A point repeated twice in a row is emitted once with the
                # linear code 9 (presumably an osu! "red" anchor — confirm);
                # the second copy is skipped by the elif below.
                if pos == hitobj.curve.points[i + 1]:
                    datapoints.append(create_datapoint(time, pos, 9))
                elif pos != hitobj.curve.points[i - 1]:
                    datapoints.append(create_datapoint(time, pos, 6))
        # Final curve point, at the time the first pass along the path ends.
        datapoints.append(
            create_datapoint(hitobj.time + duration, hitobj.curve.points[-1], 10),
        )
        # NOTE(review): curve(1) is presumably the far end of the path no matter
        # how many repeats there are; for even repeat counts the slider
        # presumably finishes at its head. Confirm this encoding is intended.
        slider_end_pos = hitobj.curve(1)
        datapoints.append(
            create_datapoint(
                hitobj.end_time,
                slider_end_pos,
                11 + repeat_type(hitobj.repeat),
            ),
        )
        return torch.stack(datapoints, 0)
    if isinstance(hitobj, Spinner):
        return torch.stack(
            (
                create_datapoint(hitobj.time, hitobj.position, 2),
                create_datapoint(hitobj.end_time, hitobj.position, 3),
            ),
            0,
        )
    # Circles, and sliders with too many curve points, become a single row.
    return create_datapoint(
        hitobj.time,
        hitobj.position,
        1 if hitobj.new_combo else 0,
    ).unsqueeze(0)
def beatmap_to_sequence(beatmap: Beatmap) -> torch.Tensor:
    """Encode all hit objects of ``beatmap`` into one feature matrix.

    Hit objects are read without stacking and encoded with ``get_data``. The
    per-object rows are concatenated in order and transposed, so features are
    rows and events are columns. The result is returned as a float tensor.
    """
    per_object_rows = [get_data(obj) for obj in beatmap.hit_objects(stacking=False)]
    events_by_features = torch.concatenate(per_object_rows, 0)
    return events_by_features.transpose(0, 1).float()
def random_flip(seq: torch.Tensor) -> torch.Tensor:
    """Randomly mirror positions horizontally and/or vertically, each with probability 0.5.

    Row 0 of ``seq`` is treated as x and mirrored as ``512 - x``. Row 1 is
    treated as y and mirrored as ``384 - y``. Any further rows are copied
    unchanged.

    The input is left untouched and a (possibly flipped) copy is returned.
    Callers pass slices such as ``seq[:2, :]``, which are views, so writing in
    place would silently modify the caller's tensor.

    Args:
        seq: tensor with at least two rows (x and y).

    Returns:
        A new tensor of the same shape as ``seq``.
    """
    flipped = seq.clone()
    if random.random() < 0.5:
        flipped[0] = 512 - flipped[0]
    if random.random() < 0.5:
        flipped[1] = 384 - flipped[1]
    return flipped
def calc_distances(seq: torch.Tensor) -> torch.Tensor:
    """Euclidean distance from each event's position to the previous event's position.

    Rows 0 and 1 of ``seq`` are the x and y coordinates, and columns are events.
    The first event has no predecessor, so it is measured from the point
    (256, 192), which is half of ``playfield_size``.

    Returns:
        A 1-D tensor with one distance per column of ``seq``.
    """
    positions = seq[:2, :]
    # torch.roll returns a new tensor, so writing into it leaves seq untouched
    previous = torch.roll(positions, shifts=1, dims=1)
    previous[0, 0], previous[1, 0] = 256, 192
    return torch.linalg.vector_norm(positions - previous, ord=2, dim=0)
def split_and_process_sequence(
    seq: torch.Tensor,
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], int]:
    """Split a feature matrix into diffusion inputs, with random flip augmentation.

    Args:
        seq: feature matrix with features as rows and events as columns, as
            built by ``beatmap_to_sequence``: rows 0-1 positions, row 2 time,
            rows 3+ one-hot event types.

    Returns:
        ``((seq_x, seq_o, seq_c), length)`` where:

        * ``seq_x`` holds the positions (rows 0-1), randomly mirrored by
          ``random_flip`` and divided by ``playfield_size``;
        * ``seq_o`` is row 2 (time), unchanged;
        * ``seq_c`` is the transposed ``timestep_embedding`` (dimension 128) of
          each event's distance to the previous event, stacked on top of
          rows 3+ of ``seq``;
        * ``length`` is the number of columns (events) in ``seq``.
    """
    seq_d = calc_distances(seq)
    # Augment, then normalize positions for diffusion. playfield_size lives on
    # the CPU; move it to seq's device so GPU tensors work as in the
    # no-augment variant.
    seq_x = random_flip(seq[:2, :]) / playfield_size.to(seq.device).unsqueeze(1)
    seq_o = seq[2, :]
    seq_c = torch.concatenate(
        [
            timestep_embedding(seq_d, 128).T,
            seq[3:, :],
        ],
        0,
    )
    return (seq_x, seq_o, seq_c), seq.shape[1]
def split_and_process_sequence_no_augment(
    seq: torch.Tensor,
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], int]:
    """Deterministic variant of ``split_and_process_sequence`` without random flipping.

    Returns ``((seq_x, seq_o, seq_c), length)``:

    * ``seq_x`` holds the positions (rows 0-1) divided by ``playfield_size``
      (moved to ``seq``'s device);
    * ``seq_o`` is row 2 (time), unchanged;
    * ``seq_c`` is the transposed ``timestep_embedding`` (dimension 128) of the
      distances from ``calc_distances``, stacked on top of rows 3+;
    * ``length`` is the number of columns in ``seq``.
    """
    distances = calc_distances(seq)
    # Normalize positions only; no augmentation happens here.
    scale = playfield_size.to(seq.device).unsqueeze(1)
    normalized_positions = seq[:2, :] / scale
    times = seq[2, :]
    distance_features = timestep_embedding(distances, 128).T
    conditioning = torch.concatenate([distance_features, seq[3:, :]], 0)
    return (normalized_positions, times, conditioning), seq.shape[1]
def load_and_process_beatmap(beatmap: Beatmap):
    """Encode ``beatmap`` and split it into augmented diffusion inputs.

    Returns whatever ``split_and_process_sequence`` returns for the
    ``beatmap_to_sequence`` encoding of ``beatmap``.
    """
    return split_and_process_sequence(beatmap_to_sequence(beatmap))
def window_and_relative_time(seq, s, e):
    """Cut columns ``[s, e)`` out of a ``(seq_x, seq_o, seq_c)`` triple.

    Positions and conditioning are sliced as-is. Times are shifted so the
    window starts at 0, then a random offset in ``[0, 100000)`` is added.
    """
    positions, times, conditioning = seq
    # Obscure the absolute time: re-base to zero, then add a random offset up to
    # the max period, so the offset embedding sees its full input range, as it
    # does when sampling the model.
    window_start_time = times[s]
    rebased = times[s:e] - window_start_time
    return (
        positions[:, s:e],
        rebased + random.random() * 100000,
        conditioning[:, s:e],
    )
class BeatmapDatasetIterable:
    """Iterator over fixed-length windows cut from a list of beatmap files.

    Files are loaded lazily, in list order. Each file is parsed with
    ``Beatmap.from_path`` and turned into a sequence by ``seq_func``, which must
    return ``(sequence, length)``. Windows of ``seq_len`` columns are produced
    by ``win_func(sequence, start, start + seq_len)``. The first window of a
    file starts at a random offset in ``[0, stride - 1]``, and each following
    window starts ``stride`` columns later. A file whose sequence has no room
    for a full window from that offset yields nothing.

    Each item is ``(window, beatmap_id)``, where ``beatmap_id`` is the integer
    formed by the first six characters of the file's base name.

    Raises:
        ValueError: from ``__init__`` when ``stride < 1``.
    """

    __slots__ = (
        "beatmap_files",
        "beatmap_idx",  # not used by this class
        "seq_len",
        "stride",
        "index",
        "current_idx",
        "current_seq",
        "current_seq_len",
        "seq_index",
        "seq_func",
        "win_func",
    )

    def __init__(
        self,
        beatmap_files: list[str],
        seq_len: int,
        stride: int,
        seq_func: Callable,
        win_func: Callable,
    ):
        # A non-positive stride would otherwise only fail later, inside
        # random.randint(0, stride - 1), when the first file is loaded.
        if stride < 1:
            raise ValueError(f"stride must be at least 1, got {stride}")
        self.beatmap_files = beatmap_files
        self.seq_len = seq_len
        self.stride = stride
        self.index = 0  # position of the next file to load
        self.current_idx = 0  # id parsed from the current file's name
        self.current_seq = None
        self.current_seq_len = -1
        self.seq_index = 0  # start column of the next window
        self.seq_func = seq_func
        self.win_func = win_func

    def __iter__(self) -> "BeatmapDatasetIterable":
        return self

    def __next__(self) -> tuple[object, int]:
        """Return the next ``(window, beatmap_id)`` pair.

        Raises:
            StopIteration: once every file has been consumed.
        """
        # Load files until one has room for a full window at seq_index.
        while (
            self.current_seq is None
            or self.seq_index + self.seq_len > self.current_seq_len
        ):
            if self.index >= len(self.beatmap_files):
                raise StopIteration
            beatmap_path = self.beatmap_files[self.index]
            beatmap = Beatmap.from_path(beatmap_path)
            self.current_idx = int(os.path.basename(beatmap_path)[:6])
            self.current_seq, self.current_seq_len = self.seq_func(beatmap)
            # Start each file's windows at a random offset in [0, stride - 1].
            self.seq_index = random.randint(0, self.stride - 1)
            self.index += 1
        # Return the preprocessed hit objects as a sequence of overlapping windows.
        window = self.win_func(
            self.current_seq,
            self.seq_index,
            self.seq_index + self.seq_len,
        )
        self.seq_index += self.stride
        return window, self.current_idx
class InterleavingBeatmapDatasetIterable:
    """Round-robin over several iterators, each built from a contiguous shard of files.

    ``beatmap_files`` is cut into ``cycle_length`` contiguous shards of
    ``ceil(len(beatmap_files) / cycle_length)`` files each; trailing shards may
    be shorter or empty. ``iterable_factory`` is called once per shard.
    ``__next__`` takes one item from each live iterator in turn and drops an
    iterator as soon as it is exhausted.

    Raises:
        ValueError: from ``__init__`` when ``cycle_length < 1``.
    """

    __slots__ = ("workers", "cycle_length", "index")

    def __init__(
        self,
        beatmap_files: list[str],
        iterable_factory: Callable,
        cycle_length: int,
    ):
        if cycle_length < 1:
            raise ValueError(f"cycle_length must be at least 1, got {cycle_length}")
        per_worker = math.ceil(len(beatmap_files) / cycle_length)
        # Slicing clamps at the end of the list, so short or empty tail shards are fine.
        self.workers = [
            iterable_factory(beatmap_files[i * per_worker:(i + 1) * per_worker])
            for i in range(cycle_length)
        ]
        self.cycle_length = cycle_length
        self.index = 0  # worker to pull from next

    def __iter__(self) -> "InterleavingBeatmapDatasetIterable":
        return self

    def __next__(self) -> tuple[object, int]:
        # Every failed attempt removes one worker, so len(workers) attempts are
        # enough to either find an item or exhaust all workers.
        for _ in range(len(self.workers)):
            self.index %= len(self.workers)
            try:
                item = next(self.workers[self.index])
            except StopIteration:
                # Delete by position rather than list.remove(), which compares
                # with == and could drop a different worker. The next worker
                # slides into self.index, so the index is not advanced.
                del self.workers[self.index]
                continue
            self.index += 1
            return item
        raise StopIteration
class BeatmapDataset(IterableDataset):
    """Iterable dataset over beatmap files.

    The file list is ``beatmap_files`` when given. Otherwise it is every file
    in ``<dataset_path>/TrackNNNNN/beatmaps`` for track numbers in
    ``[start, end)`` (zero-padded to five digits), sorted by name within each
    track. Each ``__iter__`` call builds a fresh iterator. The file list is
    optionally shuffled and then handed to ``iterable_factory`` directly or,
    when ``cycle_length > 1``, to an ``InterleavingBeatmapDatasetIterable``
    that round-robins over ``cycle_length`` factory-built iterators.
    """

    def __init__(
        self,
        dataset_path: str,
        start: int,
        end: int,
        iterable_factory: Callable,
        cycle_length: int = 1,
        shuffle: bool = False,
        beatmap_files: Optional[list[str]] = None,
    ):
        # Zero-argument super() actually runs the base-class initializer;
        # super(BeatmapDataset) alone is an unbound super object.
        super().__init__()
        self.dataset_path = dataset_path
        self.start = start  # first track number (inclusive)
        self.end = end  # last track number (exclusive)
        self.iterable_factory = iterable_factory
        self.cycle_length = cycle_length
        self.shuffle = shuffle
        self.beatmap_files = beatmap_files

    def _get_beatmap_files(self) -> list[str]:
        """Return a new list of beatmap file paths that the caller may reorder freely."""
        if self.beatmap_files is not None:
            # Copy so that shuffling in __iter__ never reorders the caller's list.
            return list(self.beatmap_files)
        # Collect every beatmap file of the tracks in [start, end).
        beatmap_files = []
        for track in range(self.start, self.end):
            beatmaps_dir = os.path.join(
                self.dataset_path,
                "Track" + str(track).zfill(5),
                "beatmaps",
            )
            # os.listdir order is arbitrary; sort so the file order is reproducible.
            for beatmap_file in sorted(os.listdir(beatmaps_dir)):
                beatmap_files.append(os.path.join(beatmaps_dir, beatmap_file))
        return beatmap_files

    def __iter__(self) -> "InterleavingBeatmapDatasetIterable | BeatmapDatasetIterable":
        beatmap_files = self._get_beatmap_files()
        if self.shuffle:
            random.shuffle(beatmap_files)
        if self.cycle_length > 1:
            return InterleavingBeatmapDatasetIterable(
                beatmap_files,
                self.iterable_factory,
                self.cycle_length,
            )
        return self.iterable_factory(beatmap_files)
# Define a `worker_init_fn` that configures each dataset copy differently
def worker_init_fn(worker_id: int) -> None:
    """Restrict this DataLoader worker's dataset copy to its own share of the work.

    Called by ``DataLoader`` inside each worker process, where
    ``worker_info.dataset`` is that worker's own copy. The track range
    ``[start, end)`` is split into contiguous chunks of
    ``ceil((end - start) / num_workers)`` tracks. An explicit ``beatmap_files``
    list, if present, is split into contiguous chunks the same way; without
    this, every worker would iterate the full list and yield duplicate data.

    Does nothing when called outside a DataLoader worker.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        return
    dataset = worker_info.dataset  # the dataset copy in this worker process
    num_workers = worker_info.num_workers

    # Configure the dataset to only process its part of the track range.
    overall_start = dataset.start
    overall_end = dataset.end
    per_worker = int(math.ceil((overall_end - overall_start) / float(num_workers)))
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)

    # An explicit file list takes precedence over the track range, so shard it too.
    files = getattr(dataset, "beatmap_files", None)
    if files is not None:
        files_per_worker = int(math.ceil(len(files) / float(num_workers)))
        dataset.beatmap_files = files[
            worker_id * files_per_worker:(worker_id + 1) * files_per_worker
        ]
def get_beatmap_idx(name) -> dict[int, int]:
    """Unpickle and return the object stored in file ``name`` next to this module.

    ``name`` must be a bare file name; it replaces this module's file name via
    ``Path.with_name``. pickle can execute code while loading, so only use this
    with trusted files.
    """
    index_file = Path(__file__).with_name(name)
    with index_file.open("rb") as handle:
        loaded = pickle.load(handle)
    return loaded
def get_beatmap_files(name: str, data_path: str) -> list[PurePosixPath]:
    """Load a pickled list of Windows-style relative paths and root them under ``data_path``.

    ``name`` is opened as given, relative to the current working directory and
    not to this module. Each stored path is split on Windows separators and
    rebuilt as a POSIX path below ``data_path``. pickle can execute code while
    loading, so only use this with trusted listing files.
    """
    with Path(name).open("rb") as listing:
        relative_paths = pickle.load(listing)
    rooted = []
    for relative in relative_paths:
        components = PureWindowsPath(relative).parts
        rooted.append(PurePosixPath(data_path, *components))
    return rooted
class BeatmapDatasetIterableFactory:
    """Callable that builds a ``BeatmapDatasetIterable`` for a given list of files.

    It stores the windowing settings once, so dataset code can create one
    iterable per file list by calling ``factory(files)``.
    """

    __slots__ = ("seq_len", "stride", "seq_func", "win_func")

    def __init__(self, seq_len, stride, seq_func, win_func):
        self.seq_len, self.stride = seq_len, stride
        self.seq_func, self.win_func = seq_func, win_func

    def __call__(self, *args, **kwargs):
        # Only the first positional argument (the file list) is used; kwargs are ignored.
        files = args[0]
        return BeatmapDatasetIterable(
            beatmap_files=files,
            seq_len=self.seq_len,
            stride=self.stride,
            seq_func=self.seq_func,
            win_func=self.win_func,
        )
class CachedDataset(Dataset):
    """Map-style dataset over an in-memory, indexable collection of precomputed items."""

    __slots__ = ("cached_data",)

    def __init__(self, cached_data):
        self.cached_data = cached_data

    def __len__(self):
        return len(self.cached_data)

    def __getitem__(self, index):
        return self.cached_data[index]
def cache_dataset(
    out_path: str,
    dataset_path: str,
    start: int,
    end: int,
    iterable_factory: Callable,
    cycle_length=1,
    beatmap_files: Optional[list[str]] = None,
):
    """Materialize one non-shuffled pass over a ``BeatmapDataset`` and save it.

    All items are collected into a list, held in memory, and written to
    ``out_path`` with ``torch.save``. A tqdm progress bar is shown while
    iterating.
    """
    source = BeatmapDataset(
        dataset_path=dataset_path,
        start=start,
        end=end,
        iterable_factory=iterable_factory,
        cycle_length=cycle_length,
        shuffle=False,
        beatmap_files=beatmap_files,
    )
    print("Caching dataset...")
    items = [item for item in tqdm.tqdm(source)]
    torch.save(items, out_path)
def get_cached_data_loader(
    data_path: str,
    batch_size: int = 1,
    num_workers: int = 0,
    shuffle: bool = False,
    pin_memory: bool = False,
    drop_last: bool = False,
):
    """Wrap a list saved by ``cache_dataset`` in a map-style ``DataLoader``.

    The whole cache is loaded into memory with ``torch.load``. Shuffling is
    done by the DataLoader itself. Workers are kept alive between epochs when
    ``num_workers > 0``.
    """
    # torch.load is pickle-based: only point this at cache files you produced yourself.
    items = torch.load(data_path)
    return DataLoader(
        CachedDataset(items),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
def get_data_loader(
    dataset_path: str,
    start: int,
    end: int,
    iterable_factory: Callable,
    cycle_length=1,
    batch_size: int = 1,
    num_workers: int = 0,
    shuffle: bool = False,
    pin_memory: bool = False,
    drop_last: bool = False,
    beatmap_files: Optional[list[str]] = None,
) -> DataLoader:
    """Build a ``DataLoader`` that streams windows from a ``BeatmapDataset``.

    ``shuffle`` goes to the dataset, which shuffles its file list on every
    ``__iter__``; it is not passed to the DataLoader. ``worker_init_fn`` gives
    each worker process its share of the work. Workers are kept alive between
    epochs when ``num_workers > 0``.
    """
    beatmap_dataset = BeatmapDataset(
        dataset_path=dataset_path,
        start=start,
        end=end,
        iterable_factory=iterable_factory,
        cycle_length=cycle_length,
        shuffle=shuffle,
        beatmap_files=beatmap_files,
    )
    return DataLoader(
        beatmap_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
        persistent_workers=num_workers > 0,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
def main(args):
    """Exercise the data pipeline from the command line.

    ``args.mode`` selects the action:

    * ``"plotfirst"``: print the shapes of the first batch, then show each
      sample's position embedding, offset embedding and conditioning features;
    * ``"benchmark"``: iterate the loader under a tqdm progress bar.

    Any other mode only builds the loader.
    """
    factory = BeatmapDatasetIterableFactory(
        128,
        16,
        load_and_process_beatmap,
        window_and_relative_time,
    )
    dataloader = get_data_loader(
        dataset_path=args.data_path,
        start=0,
        end=16291,
        iterable_factory=factory,
        cycle_length=1,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=False,
        drop_last=True,
    )

    if args.mode == "benchmark":
        for _ in tqdm.tqdm(dataloader, total=7000, smoothing=0.01):
            pass
        return
    if args.mode != "plotfirst":
        return

    import matplotlib.pyplot as plt

    # Only the first batch is inspected.
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return
    (x, o, c), y = first_batch
    x = torch.swapaxes(x, 1, 2)  # (N, T, C)
    c = torch.swapaxes(c, 1, 2)  # (N, T, E)
    print(x.shape, o.shape, c.shape, y.shape)
    position_emb = position_sequence_embedding(x * playfield_size, 128)
    print(position_emb.shape)
    offset_emb = offset_sequence_embedding(o / 10, 128)
    print(offset_emb.shape)
    print(y)
    for sample in range(args.batch_size):
        _, axes = plt.subplots(3, figsize=(5, 20))
        axes[0].imshow(position_emb[sample])
        axes[1].imshow(offset_emb[sample])
        axes[2].imshow(c[sample])
        print(y[sample])
        plt.show()
if __name__ == "__main__":
    # Command-line entry point; see main() for the supported --mode values.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--data-path", type=str, required=True)
    cli.add_argument("--mode", type=str, required=True)
    cli.add_argument("--batch-size", type=int, default=1)
    cli.add_argument("--num-workers", type=int, default=0)
    main(cli.parse_args())