diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..68b4bc682299b9d6e3ce4d821f2517dcedb01175 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,36 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +osuT5/inference/vale.mp3 filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7be5fc7f47d5db027d120b8024982df93db95b74 --- /dev/null +++ b/README.md @@ -0,0 
+1,3 @@
+---
+license: mit
+---
diff --git a/checkpoint/custom_checkpoint_0.pkl b/checkpoint/custom_checkpoint_0.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..988605f8e7389646283ed995474e8bc4b8dcf8be
--- /dev/null
+++ b/checkpoint/custom_checkpoint_0.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0494fdd396142b4a2919c0ab913502c9335746959156a08e82c2647235e07853
+size 564880
diff --git a/checkpoint/pytorch_model.bin b/checkpoint/pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..fcaad97057fa0182ce2abb233829b603fdd86015
--- /dev/null
+++ b/checkpoint/pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a12b6c312590efbdf5d7acaff6d8537e8ad1728737eebb43d0a43d5a4b3b5a3a
+size 377860126
diff --git a/configs/inference.yaml b/configs/inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6f25db862e41c338097301f559788a7e0c1e183e
--- /dev/null
+++ b/configs/inference.yaml
@@ -0,0 +1,65 @@
+model:
+  name: 'google/t5-v1_1-small'
+  spectrogram:
+    sample_rate: 16000
+    hop_length: 128
+    n_fft: 1024
+    n_mels: 388
+  do_style_embed: false
+  input_features: false
+
+model_path: './checkpoint'
+audio_path: '' # Path to input audio
+total_duration_ms: 0 # Total duration of audio in milliseconds, 0 for full audio
+output_path: '' # Path to output directory
+bpm: 120 # Beats per minute of input audio
+offset: 0 # Start of beat, in milliseconds, from the beginning of input audio
+resnap_objects: false # Resnap objects to beat timing ticks, requires accurate BPM and offset
+slider_multiplier: 1.7 # Multiplier for slider velocity
+title: '' # Song title
+artist: '' # Song artist
+beatmap_path: '' # Path to .osu file which will be remapped
+other_beatmap_path: '' # Path to .osu file of other beatmap in the mapset to use as reference
+beatmap_id: -1 # Beatmap ID to use as style
+difficulty: -1 # Difficulty star rating to map
+creator: '' # Beatmap creator
+version: '' # Beatmap version
+full_set: true # Generate full mapset
+set_difficulties: 5 # Number of difficulties to generate.
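+# With full_set enabled, inference.py spreads the target difficulties evenly over
+# the 3-7 star range, e.g. set_difficulties: 5 yields 3.0, 4.0, 5.0, 6.0, 7.0.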
+ +# Diffusion settings +generate_positions: true # Use diffusion to generate object positions +diff_ckpt: './osudiffusion/DiT-B-0700000.pt' # Path to checkpoint for diffusion model +diff_refine_ckpt: '' # Path to checkpoint for refining diffusion model + +diffusion: + style_id: 1451282 # Style ID to use for diffusion + num_sampling_steps: 100 # Number of sampling steps + cfg_scale: 1 # Scale of classifier-free guidance + num_classes: 52670 # Number of classes stored in the model + beatmap_idx: 'osudiffusion/beatmap_idx.pickle' # Path to beatmap index + use_amp: true # Use automatic mixed precision + refine_iters: 10 # Number of refinement iterations + seq_len: 128 # Sequence length + model: 'DiT-B' # Model architecture + + +data: # Data settings + src_seq_len: 640 + tgt_seq_len: 480 + sample_rate: ${model.spectrogram.sample_rate} + hop_length: ${model.spectrogram.hop_length} + sequence_stride: 1 # Fraction of audio sequence length to shift inference window + center_pad_decoder: false # Center pad decoder input + add_pre_tokens: true + special_token_len: 2 + diff_token_index: 0 + style_token_index: -1 + max_pre_token_len: 4 + add_gd_context: false # Prefix the decoder with tokens of another beatmap in the mapset + +hydra: + job: + chdir: False + run: + dir: ./logs/${now:%Y-%m-%d}/${now:%H-%M-%S} \ No newline at end of file diff --git a/configs/model/model.yaml b/configs/model/model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bf4032f450537b817313e0adbaa851db84d11341 --- /dev/null +++ b/configs/model/model.yaml @@ -0,0 +1,8 @@ +input_features: false +do_style_embed: true + +spectrogram: + sample_rate: 16000 + hop_length: 128 + n_fft: 1024 + n_mels: 388 \ No newline at end of file diff --git a/configs/model/t5_base.yaml b/configs/model/t5_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..11df27bdd5051b1fd22237bafd3faeabd2f277cf --- /dev/null +++ b/configs/model/t5_base.yaml @@ -0,0 +1,15 @@ +defaults: + - model + - _self_ + +name: 'google/t5-v1_1-base' +overwrite: + dropout_rate: 0.0 + +spectrogram: + sample_rate: 16000 + hop_length: 128 + n_fft: 1024 + n_mels: 388 +do_style_embed: false +input_features: false \ No newline at end of file diff --git a/configs/model/t5_small.yaml b/configs/model/t5_small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7153b5e17b6303896051504352060711ffad07e8 --- /dev/null +++ b/configs/model/t5_small.yaml @@ -0,0 +1,10 @@ +defaults: + - model + - _self_ + +name: 'google/t5-v1_1-small' +overwrite: + dropout_rate: 0.0 + +spectrogram: + n_mels: 512 \ No newline at end of file diff --git a/configs/model/t5_small_v4.yaml b/configs/model/t5_small_v4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ad99de2d58f2fcf74524cdb7d4a5ac5303f81610 --- /dev/null +++ b/configs/model/t5_small_v4.yaml @@ -0,0 +1,7 @@ +defaults: + - model + - _self_ + +name: 'google/t5-v1_1-small' +overwrite: + dropout_rate: 0.0 \ No newline at end of file diff --git a/configs/model/t5_small_v9.yaml b/configs/model/t5_small_v9.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b1d231f901130901a72f1f08679daa8bf44b9a2d --- /dev/null +++ b/configs/model/t5_small_v9.yaml @@ -0,0 +1,9 @@ +defaults: + - model + - _self_ + +do_style_embed: false + +name: 'google/t5-v1_1-small' +overwrite: + dropout_rate: 0.0 \ No newline at end of file diff --git a/configs/model/whisper_base.yaml b/configs/model/whisper_base.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..1ddc035d1e2fecb78974f56df286323ef4d34c71
--- /dev/null
+++ b/configs/model/whisper_base.yaml
@@ -0,0 +1,6 @@
+defaults:
+  - model
+  - _self_
+
+name: 'openai/whisper-base'
+input_features: true
\ No newline at end of file
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b34cbb7059673540129b206fcf37a6e0bd65d22
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,117 @@
+from pathlib import Path
+
+import hydra
+import torch
+from omegaconf import DictConfig
+from slider import Beatmap
+
+from osudiffusion import DiT_models
+from osuT5.inference import Preprocessor, Pipeline, Postprocessor, DiffusionPipeline
+from osuT5.tokenizer import Tokenizer
+from osuT5.utils import get_model
+
+
+def get_args_from_beatmap(args: DictConfig):
+    if args.beatmap_path is None or args.beatmap_path == "":
+        return
+
+    beatmap_path = Path(args.beatmap_path)
+
+    if not beatmap_path.is_file():
+        raise FileNotFoundError(f"Beatmap file {beatmap_path} not found.")
+
+    beatmap = Beatmap.from_path(beatmap_path)
+    args.audio_path = beatmap_path.parent / beatmap.audio_filename
+    args.output_path = beatmap_path.parent
+    args.bpm = beatmap.bpm_max()
+    args.offset = min(tp.offset.total_seconds() * 1000 for tp in beatmap.timing_points)
+    args.slider_multiplier = beatmap.slider_multiplier
+    args.title = beatmap.title
+    args.artist = beatmap.artist
+    args.beatmap_id = beatmap.beatmap_id if args.beatmap_id == -1 else args.beatmap_id
+    args.diffusion.style_id = beatmap.beatmap_id if args.diffusion.style_id == -1 else args.diffusion.style_id
+    args.difficulty = float(beatmap.stars()) if args.difficulty == -1 else args.difficulty
+
+
+def find_model(ckpt_path, args: DictConfig, device):
+    assert Path(ckpt_path).exists(), f"Could not find DiT checkpoint at {ckpt_path}"
+    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
+    if "ema" in checkpoint:  # supports checkpoints from train.py
+        checkpoint = checkpoint["ema"]
+
+    model = DiT_models[args.diffusion.model](
+        num_classes=args.diffusion.num_classes,
+        context_size=19 - 3 + 128,
+    ).to(device)
+    model.load_state_dict(checkpoint)
+    model.eval()  # important!
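+    # When the checkpoint was written by train.py, the EMA weights loaded above
+    # are preferred; eval() additionally disables dropout during sampling.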
+    return model
+
+
+@hydra.main(config_path="configs", config_name="inference", version_base="1.1")
+def main(args: DictConfig):
+    get_args_from_beatmap(args)
+
+    torch.set_grad_enabled(False)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    ckpt_path = Path(args.model_path)
+    model_state = torch.load(ckpt_path / "pytorch_model.bin", map_location=device)
+    tokenizer_state = torch.load(ckpt_path / "custom_checkpoint_0.pkl")
+
+    tokenizer = Tokenizer()
+    tokenizer.load_state_dict(tokenizer_state)
+
+    model = get_model(args, tokenizer)
+    model.load_state_dict(model_state)
+    model.eval()
+    model.to(device)
+
+    preprocessor = Preprocessor(args)
+    audio = preprocessor.load(args.audio_path)
+    sequences = preprocessor.segment(audio)
+    # The audio was resampled to the configured rate, so derive the duration from it
+    total_duration_ms = len(audio) / args.model.spectrogram.sample_rate * 1000
+    args.total_duration_ms = total_duration_ms
+
+    generated_maps = []
+    generated_positions = []
+    diffs = []
+
+    if args.full_set:
+        # Spread the target difficulties evenly over the 3-7 star range
+        for i in range(args.set_difficulties):
+            diffs.append(3 + i * (7 - 3) / (args.set_difficulties - 1))
+
+        print(diffs)
+        for diff in diffs:
+            print(f"Generating difficulty {diff}")
+            args.difficulty = diff
+            pipeline = Pipeline(args, tokenizer)
+            events = pipeline.generate(model, sequences)
+            generated_maps.append(events)
+    else:
+        pipeline = Pipeline(args, tokenizer)
+        events = pipeline.generate(model, sequences)
+        generated_maps.append(events)
+
+    if args.generate_positions:
+        model = find_model(args.diff_ckpt, args, device)
+        refine_model = find_model(args.diff_refine_ckpt, args, device) if len(args.diff_refine_ckpt) > 0 else None
+        diffusion_pipeline = DiffusionPipeline(args.diffusion)
+        for events in generated_maps:
+            events = diffusion_pipeline.generate(model, events, refine_model)
+            generated_positions.append(events)
+    else:
+        generated_positions = generated_maps
+
+    postprocessor = Postprocessor(args)
+    postprocessor.generate(generated_positions)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/osuT5/__init__.py b/osuT5/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/osuT5/__pycache__/__init__.cpython-311.pyc b/osuT5/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbd93948de7b49aa95e33d094c5c3b55cd5eaa89
Binary files /dev/null and b/osuT5/__pycache__/__init__.cpython-311.pyc differ
diff --git a/osuT5/__pycache__/__init__.cpython-39.pyc b/osuT5/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eab382ceefa4e7cbf09c8332b4420137d6f4be29
Binary files /dev/null and b/osuT5/__pycache__/__init__.cpython-39.pyc differ
diff --git a/osuT5/dataset/__init__.py b/osuT5/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7645c8e15e0f331784b037affdd215a899e855cc
--- /dev/null
+++ b/osuT5/dataset/__init__.py
@@ -0,0 +1 @@
+from .osu_parser import OsuParser
diff --git a/osuT5/dataset/__pycache__/__init__.cpython-311.pyc b/osuT5/dataset/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45a2d1e0de60113a35ec69efc1503081c60feaa9
Binary files /dev/null and b/osuT5/dataset/__pycache__/__init__.cpython-311.pyc differ
diff --git a/osuT5/dataset/__pycache__/__init__.cpython-39.pyc b/osuT5/dataset/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f9b0039447edf04ac256fe40f1c4823c7b802dcd
Binary files /dev/null and
b/osuT5/dataset/__pycache__/__init__.cpython-39.pyc differ diff --git a/osuT5/dataset/__pycache__/data_utils.cpython-311.pyc b/osuT5/dataset/__pycache__/data_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e98b5345db6a9a01c26b8f3c75ab2352b463edc Binary files /dev/null and b/osuT5/dataset/__pycache__/data_utils.cpython-311.pyc differ diff --git a/osuT5/dataset/__pycache__/data_utils.cpython-39.pyc b/osuT5/dataset/__pycache__/data_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a0f814cb04607c88143e85f6d28a0f47a9c8b81 Binary files /dev/null and b/osuT5/dataset/__pycache__/data_utils.cpython-39.pyc differ diff --git a/osuT5/dataset/__pycache__/ors_dataset.cpython-311.pyc b/osuT5/dataset/__pycache__/ors_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..720ef0bf6ba5e31b7b4b45811d6b877d79bc9398 Binary files /dev/null and b/osuT5/dataset/__pycache__/ors_dataset.cpython-311.pyc differ diff --git a/osuT5/dataset/__pycache__/ors_dataset.cpython-39.pyc b/osuT5/dataset/__pycache__/ors_dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e558376212a629f2743e01d4d0eb84e38e10532 Binary files /dev/null and b/osuT5/dataset/__pycache__/ors_dataset.cpython-39.pyc differ diff --git a/osuT5/dataset/__pycache__/osu_parser.cpython-311.pyc b/osuT5/dataset/__pycache__/osu_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24535e071030c33e999a6291304cadb82c53c3a3 Binary files /dev/null and b/osuT5/dataset/__pycache__/osu_parser.cpython-311.pyc differ diff --git a/osuT5/dataset/__pycache__/osu_parser.cpython-39.pyc b/osuT5/dataset/__pycache__/osu_parser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b028a91791fce823b4352ea6ae600080797a6f61 Binary files /dev/null and b/osuT5/dataset/__pycache__/osu_parser.cpython-39.pyc differ diff --git a/osuT5/dataset/data_utils.py b/osuT5/dataset/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..80b5dc5710fbec246f589a5943deb0f19f34ce75 --- /dev/null +++ b/osuT5/dataset/data_utils.py @@ -0,0 +1,100 @@ +from pathlib import Path +from typing import Optional + +import numpy as np +from pydub import AudioSegment + +import numpy.typing as npt + +from osuT5.tokenizer import Event, EventType + +MILISECONDS_PER_SECOND = 1000 + + +def load_audio_file(file: Path, sample_rate: int) -> npt.NDArray: + """Load an audio file as a numpy time-series array + + The signals are resampled, converted to mono channel, and normalized. + + Args: + file: Path to audio file. + sample_rate: Sample rate to resample the audio. + + Returns: + samples: Audio time series. 
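+            Samples are peak-normalized, so values lie in [-1, 1].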
+ """ + print(file) + audio = AudioSegment.from_file(file, format="mp3") + audio = audio.set_frame_rate(sample_rate) + audio = audio.set_channels(1) + samples = np.array(audio.get_array_of_samples()).astype(np.float32) + samples *= 1.0 / np.max(np.abs(samples)) + return samples + + +def update_event_times(events: list[Event], event_times: list[float], end_time: Optional[float] = None): + non_timed_events = [ + EventType.BEZIER_ANCHOR, + EventType.PERFECT_ANCHOR, + EventType.CATMULL_ANCHOR, + EventType.RED_ANCHOR, + ] + timed_events = [ + EventType.CIRCLE, + EventType.SPINNER, + EventType.SPINNER_END, + EventType.SLIDER_HEAD, + EventType.LAST_ANCHOR, + EventType.SLIDER_END, + ] + + start_index = len(event_times) + end_index = len(events) + ct = 0 if len(event_times) == 0 else event_times[-1] + for i in range(start_index, end_index): + event = events[i] + if event.type == EventType.TIME_SHIFT: + ct = event.value + event_times.append(ct) + + # Interpolate time for control point events + # T-D-Start-D-CP-D-CP-T-D-LCP-T-D-End + # 1-1-1-----1-1--1-1--7-7--7--9-9-9-- + # 1-1-1-----3-3--5-5--7-7--7--9-9-9-- + ct = end_time if end_time is not None else event_times[-1] + interpolate = False + for i in range(end_index - 1, start_index - 1, -1): + event = events[i] + + if event.type in timed_events: + interpolate = False + + if event.type in non_timed_events: + interpolate = True + + if not interpolate: + ct = event_times[i] + continue + + if event.type not in non_timed_events: + event_times[i] = ct + continue + + # Find the time of the first timed event and the number of control points between + j = i + count = 0 + t = ct + while j >= 0: + event2 = events[j] + if event2.type == EventType.TIME_SHIFT: + t = event_times[j] + break + if event2.type in non_timed_events: + count += 1 + j -= 1 + if i < 0: + t = 0 + + # Interpolate the time + ct = (ct - t) / (count + 1) * count + t + event_times[i] = ct diff --git a/osuT5/dataset/osu_parser.py b/osuT5/dataset/osu_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..54d30bf390f9ed83ab92428693361429fcde785c --- /dev/null +++ b/osuT5/dataset/osu_parser.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from datetime import timedelta + +import numpy as np +import numpy.typing as npt +from slider import Beatmap, Circle, Slider, Spinner +from slider.curve import Linear, Catmull, Perfect, MultiBezier + +from osuT5.tokenizer import Event, EventType, Tokenizer + + +class OsuParser: + def __init__(self, tokenizer: Tokenizer) -> None: + dist_range = tokenizer.event_range[EventType.DISTANCE] + self.dist_min = dist_range.min_value + self.dist_max = dist_range.max_value + + def parse(self, beatmap: Beatmap) -> list[Event]: + # noinspection PyUnresolvedReferences + """Parse an .osu beatmap. + + Each hit object is parsed into a list of Event objects, in order of its + appearance in the beatmap. In other words, in ascending order of time. + + Args: + beatmap: Beatmap object parsed from an .osu file. + + Returns: + events: List of Event object lists. 
+
+        Example::
+            >>> beatmap = [
+                "64,80,11000,1,0",
+                "100,100,16000,2,0,B|200:200|250:200|250:200|300:150,2"
+            ]
+            >>> events = parse(beatmap)
+            >>> print(events)
+            [
+                Event(EventType.TIME_SHIFT, 11000), Event(EventType.DISTANCE, 36), Event(EventType.CIRCLE),
+                Event(EventType.TIME_SHIFT, 16000), Event(EventType.DISTANCE, 42), Event(EventType.SLIDER_HEAD),
+                Event(EventType.TIME_SHIFT, 16500), Event(EventType.DISTANCE, 141), Event(EventType.BEZIER_ANCHOR),
+                Event(EventType.TIME_SHIFT, 17000), Event(EventType.DISTANCE, 50), Event(EventType.BEZIER_ANCHOR),
+                Event(EventType.TIME_SHIFT, 17500), Event(EventType.DISTANCE, 10), Event(EventType.BEZIER_ANCHOR),
+                Event(EventType.TIME_SHIFT, 18000), Event(EventType.DISTANCE, 64), Event(EventType.LAST_ANCHOR),
+                Event(EventType.TIME_SHIFT, 20000), Event(EventType.DISTANCE, 11), Event(EventType.SLIDER_END)
+            ]
+        """
+        hit_objects = beatmap.hit_objects(stacking=False)
+        last_pos = np.array((256, 192))
+        events = []
+
+        for hit_object in hit_objects:
+            if isinstance(hit_object, Circle):
+                last_pos = self._parse_circle(hit_object, events, last_pos)
+            elif isinstance(hit_object, Slider):
+                last_pos = self._parse_slider(hit_object, events, last_pos)
+            elif isinstance(hit_object, Spinner):
+                last_pos = self._parse_spinner(hit_object, events)
+
+        return events
+
+    def _clip_dist(self, dist: int) -> int:
+        """Clip distance to valid range."""
+        return int(np.clip(dist, self.dist_min, self.dist_max))
+
+    def _parse_circle(self, circle: Circle, events: list[Event], last_pos: npt.NDArray) -> npt.NDArray:
+        """Parse a circle hit object.
+
+        Args:
+            circle: Circle object.
+            events: List of events to add to.
+            last_pos: Last position of the hit objects.
+
+        Returns:
+            pos: Position of the circle.
+        """
+        time = int(circle.time.total_seconds() * 1000)
+        pos = np.array(circle.position)
+        dist = self._clip_dist(np.linalg.norm(pos - last_pos))
+
+        events.append(Event(EventType.TIME_SHIFT, time))
+        events.append(Event(EventType.DISTANCE, dist))
+        if circle.new_combo:
+            events.append(Event(EventType.NEW_COMBO))
+        events.append(Event(EventType.CIRCLE))
+
+        return pos
+
+    def _parse_slider(self, slider: Slider, events: list[Event], last_pos: npt.NDArray) -> npt.NDArray:
+        """Parse a slider hit object.
+
+        Args:
+            slider: Slider object.
+            events: List of events to add to.
+            last_pos: Last position of the hit objects.
+
+        Returns:
+            pos: Last position of the slider.
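+            Sliders whose curves have 100 or more control points are skipped, in
+            which case last_pos is returned unchanged.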
+ """ + # Ignore sliders which are too big + if len(slider.curve.points) >= 100: + return last_pos + + time = int(slider.time.total_seconds() * 1000) + pos = np.array(slider.position) + dist = self._clip_dist(np.linalg.norm(pos - last_pos)) + last_pos = pos + + events.append(Event(EventType.TIME_SHIFT, time)) + events.append(Event(EventType.DISTANCE, dist)) + if slider.new_combo: + events.append(Event(EventType.NEW_COMBO)) + events.append(Event(EventType.SLIDER_HEAD)) + + duration: timedelta = (slider.end_time - slider.time) / slider.repeat + control_point_count = len(slider.curve.points) + + def append_control_points(event_type: EventType, last_pos: npt.NDArray = last_pos) -> npt.NDArray: + for i in range(1, control_point_count - 1): + last_pos = add_anchor_time_dist(i, last_pos) + events.append(Event(event_type)) + + return last_pos + + def add_anchor_time_dist(i: int, last_pos: npt.NDArray) -> npt.NDArray: + time = int((slider.time + i / (control_point_count - 1) * duration).total_seconds() * 1000) + pos = np.array(slider.curve.points[i]) + dist = self._clip_dist(np.linalg.norm(pos - last_pos)) + last_pos = pos + + events.append(Event(EventType.TIME_SHIFT, time)) + events.append(Event(EventType.DISTANCE, dist)) + + return last_pos + + if isinstance(slider.curve, Linear): + last_pos = append_control_points(EventType.RED_ANCHOR, last_pos) + elif isinstance(slider.curve, Catmull): + last_pos = append_control_points(EventType.CATMULL_ANCHOR, last_pos) + elif isinstance(slider.curve, Perfect): + last_pos = append_control_points(EventType.PERFECT_ANCHOR, last_pos) + elif isinstance(slider.curve, MultiBezier): + for i in range(1, control_point_count - 1): + if slider.curve.points[i] == slider.curve.points[i + 1]: + last_pos = add_anchor_time_dist(i, last_pos) + events.append(Event(EventType.RED_ANCHOR)) + elif slider.curve.points[i] != slider.curve.points[i - 1]: + last_pos = add_anchor_time_dist(i, last_pos) + events.append(Event(EventType.BEZIER_ANCHOR)) + + last_pos = add_anchor_time_dist(control_point_count - 1, last_pos) + events.append(Event(EventType.LAST_ANCHOR)) + + time = int(slider.end_time.total_seconds() * 1000) + pos = np.array(slider.curve(1)) + dist = self._clip_dist(np.linalg.norm(pos - last_pos)) + last_pos = pos + + events.append(Event(EventType.TIME_SHIFT, time)) + events.append(Event(EventType.DISTANCE, dist)) + events.append(Event(EventType.SLIDER_END)) + + return last_pos + + def _parse_spinner(self, spinner: Spinner, events: list[Event]) -> npt.NDArray: + """Parse a spinner hit object. + + Args: + spinner: Spinner object. + events: List of events to add to. + + Returns: + pos: Last position of the spinner. 
+ """ + time = int(spinner.time.total_seconds() * 1000) + events.append(Event(EventType.TIME_SHIFT, time)) + events.append(Event(EventType.SPINNER)) + + time = int(spinner.end_time.total_seconds() * 1000) + events.append(Event(EventType.TIME_SHIFT, time)) + events.append(Event(EventType.SPINNER_END)) + + return np.array((256, 192)) diff --git a/osuT5/inference/__init__.py b/osuT5/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e01497b63524a7efb18fc182acaa95ed22ac2a3 --- /dev/null +++ b/osuT5/inference/__init__.py @@ -0,0 +1,4 @@ +from .pipeline import * +from .preprocessor import * +from .postprocessor import * +from .diffusion_pipeline import * diff --git a/osuT5/inference/__pycache__/__init__.cpython-311.pyc b/osuT5/inference/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b42ae821160fe8b0183e22113ec2bcd51e1397c0 Binary files /dev/null and b/osuT5/inference/__pycache__/__init__.cpython-311.pyc differ diff --git a/osuT5/inference/__pycache__/__init__.cpython-39.pyc b/osuT5/inference/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac3306f7b9f1e5c759558846d2b0ab8feab70236 Binary files /dev/null and b/osuT5/inference/__pycache__/__init__.cpython-39.pyc differ diff --git a/osuT5/inference/__pycache__/diffusion_pipeline.cpython-311.pyc b/osuT5/inference/__pycache__/diffusion_pipeline.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dff5511ceb609443382d53268c65de1854b547d5 Binary files /dev/null and b/osuT5/inference/__pycache__/diffusion_pipeline.cpython-311.pyc differ diff --git a/osuT5/inference/__pycache__/path_approximator.cpython-311.pyc b/osuT5/inference/__pycache__/path_approximator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6ebb1c2884a8878f36b6f1b4cd420c2ac357128 Binary files /dev/null and b/osuT5/inference/__pycache__/path_approximator.cpython-311.pyc differ diff --git a/osuT5/inference/__pycache__/path_approximator.cpython-39.pyc b/osuT5/inference/__pycache__/path_approximator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5aedf44f9357c3cc29e9edebdb66e7f945284368 Binary files /dev/null and b/osuT5/inference/__pycache__/path_approximator.cpython-39.pyc differ diff --git a/osuT5/inference/__pycache__/pipeline.cpython-311.pyc b/osuT5/inference/__pycache__/pipeline.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ce6d65b9b949641039126952c62bc71f8912859 Binary files /dev/null and b/osuT5/inference/__pycache__/pipeline.cpython-311.pyc differ diff --git a/osuT5/inference/__pycache__/pipeline.cpython-39.pyc b/osuT5/inference/__pycache__/pipeline.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4952a1ca01ec8aa6f924f2c7ba9e9705deb67ad8 Binary files /dev/null and b/osuT5/inference/__pycache__/pipeline.cpython-39.pyc differ diff --git a/osuT5/inference/__pycache__/postprocessor.cpython-311.pyc b/osuT5/inference/__pycache__/postprocessor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d39d18688fa7043f9d0b9486bec8db287e18388c Binary files /dev/null and b/osuT5/inference/__pycache__/postprocessor.cpython-311.pyc differ diff --git a/osuT5/inference/__pycache__/postprocessor.cpython-39.pyc b/osuT5/inference/__pycache__/postprocessor.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..4682f3f0e976f5bd758a4ee85cc203795cfc2247
Binary files /dev/null and b/osuT5/inference/__pycache__/postprocessor.cpython-39.pyc differ
diff --git a/osuT5/inference/__pycache__/preprocessor.cpython-311.pyc b/osuT5/inference/__pycache__/preprocessor.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ca92fdfda43f609862398c96ba5abb9294693b0
Binary files /dev/null and b/osuT5/inference/__pycache__/preprocessor.cpython-311.pyc differ
diff --git a/osuT5/inference/__pycache__/preprocessor.cpython-39.pyc b/osuT5/inference/__pycache__/preprocessor.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79a7364f36898ba91e34d4e801e4f8baa8392fc1
Binary files /dev/null and b/osuT5/inference/__pycache__/preprocessor.cpython-39.pyc differ
diff --git a/osuT5/inference/__pycache__/slider_path.cpython-311.pyc b/osuT5/inference/__pycache__/slider_path.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5f907b13d336f9c3f9dd223968e45e1d20398bf
Binary files /dev/null and b/osuT5/inference/__pycache__/slider_path.cpython-311.pyc differ
diff --git a/osuT5/inference/__pycache__/slider_path.cpython-39.pyc b/osuT5/inference/__pycache__/slider_path.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3ed9476b08e27eaf6099a02805a4ea24f7e921e
Binary files /dev/null and b/osuT5/inference/__pycache__/slider_path.cpython-39.pyc differ
diff --git a/osuT5/inference/diffusion_pipeline.py b/osuT5/inference/diffusion_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..957a22759c8884f00e66c46393fdb1e208c40b14
--- /dev/null
+++ b/osuT5/inference/diffusion_pipeline.py
@@ -0,0 +1,214 @@
+import pickle
+from pathlib import Path
+
+import torch
+from omegaconf import DictConfig
+from tqdm import tqdm
+
+from osudiffusion import timestep_embedding
+from osudiffusion import repeat_type
+from osudiffusion import create_diffusion
+from osudiffusion import DiT
+from osuT5.dataset.data_utils import update_event_times
+from osuT5.tokenizer import Event, EventType
+
+
+def get_beatmap_idx(path) -> dict[int, int]:
+    p = Path(path)
+    with p.open("rb") as f:
+        beatmap_idx = pickle.load(f)
+    return beatmap_idx
+
+
+class DiffusionPipeline(object):
+    def __init__(self, args: DictConfig):
+        """Model inference stage that generates positions for distance events."""
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.num_sampling_steps = args.num_sampling_steps
+        self.cfg_scale = args.cfg_scale
+        self.seq_len = args.seq_len
+        self.num_classes = args.num_classes
+        self.beatmap_idx = get_beatmap_idx(args.beatmap_idx)
+        self.style_id = args.style_id
+        self.refine_iters = args.refine_iters
+        self.use_amp = args.use_amp
+
+        if self.style_id in self.beatmap_idx:
+            self.class_label = self.beatmap_idx[self.style_id]
+        else:
+            print(f"Beatmap ID {self.style_id} not found in dataset, using default style.")
+            self.class_label = self.num_classes
+
+    def generate(self, model: DiT, events: list[Event], refine_model: DiT = None) -> list[Event]:
+        """Generate position events for distance events in the Event list.
+
+        Args:
+            model: Trained model to use for inference.
+            events: List of Event objects with distance events.
+            refine_model: Optional model to refine the generated positions.
+
+        Returns:
+            events: List of Event objects with position events.
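+            Sampling uses classifier-free guidance: all inputs are duplicated with
+            a null-class batch and the two halves are combined according to cfg_scale.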
+ """ + + seq_o, seq_c, seq_len, seq_indices = self.events_to_sequence(events) + + seq_o = seq_o - seq_o[0] # Normalize to relative time + print(f"seq len {seq_len}") + + diffusion = create_diffusion( + str(self.num_sampling_steps), + noise_schedule="squaredcos_cap_v2", + ) + + # Create banded matrix attention mask for increased sequence length + attn_mask = torch.full((seq_len, seq_len), True, dtype=torch.bool, device=self.device) + for i in range(seq_len): + attn_mask[max(0, i - self.seq_len): min(seq_len, i + self.seq_len), i] = False + + class_labels = [self.class_label] + + # Create sampling noise: + n = len(class_labels) + z = torch.randn(n, 2, seq_len, device=self.device) + o = seq_o.repeat(n, 1).to(self.device) + c = seq_c.repeat(n, 1, 1).to(self.device) + y = torch.tensor(class_labels, device=self.device) + + # Setup classifier-free guidance: + z = torch.cat([z, z], 0) + o = torch.cat([o, o], 0) + c = torch.cat([c, c], 0) + y_null = torch.tensor([self.num_classes] * n, device=self.device) + y = torch.cat([y, y_null], 0) + model_kwargs = dict(o=o, c=c, y=y, cfg_scale=self.cfg_scale, attn_mask=attn_mask) + + def to_positions(samples): + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + samples *= torch.tensor((512, 384), device=self.device).repeat(n, 1).unsqueeze(2) + return samples.cpu() + + # Sample images: + samples = diffusion.p_sample_loop( + model.forward_with_cfg, + z.shape, + z, + clip_denoised=True, + model_kwargs=model_kwargs, + progress=True, + device=self.device, + ) + + if refine_model is not None: + # Refine result with refine model + for _ in tqdm(range(self.refine_iters)): + t = torch.tensor([0] * samples.shape[0], device=self.device) + with torch.no_grad(): + out = diffusion.p_sample( + model.forward_with_cfg, + samples, + t, + clip_denoised=True, + model_kwargs=model_kwargs, + ) + samples = out["sample"] + + positions = to_positions(samples) + return self.events_with_pos(events, positions.squeeze(0), seq_indices) + + @staticmethod + def events_to_sequence(events: list[Event]) -> tuple[torch.Tensor, torch.Tensor, int, dict[int, int]]: + # Calculate the time of every event and interpolate time for control point events + event_times = [] + update_event_times(events, event_times) + + # Calculate the number of repeats for each slider end event + # Convert to vectorized form for osu-diffusion + nc_types = [EventType.CIRCLE, EventType.SLIDER_HEAD] + event_index = { + EventType.CIRCLE: 0, + EventType.SPINNER: 2, + EventType.SPINNER_END: 3, + EventType.SLIDER_HEAD: 4, + EventType.BEZIER_ANCHOR: 6, + EventType.PERFECT_ANCHOR: 7, + EventType.CATMULL_ANCHOR: 8, + EventType.RED_ANCHOR: 9, + EventType.LAST_ANCHOR: 10, + EventType.SLIDER_END: 11, + } + + seq_indices = {} + indices = [] + data_chunks = [] + distance = 0 + new_combo = False + head_time = 0 + last_anchor_time = 0 + for i, event in enumerate(events): + indices.append(i) + if event.type == EventType.DISTANCE: + distance = event.value + elif event.type == EventType.NEW_COMBO: + new_combo = True + elif event.type in event_index: + time = event_times[i] + index = event_index[event.type] + + # Handle NC index offset + if event.type in nc_types and new_combo: + index += 1 + new_combo = False + + # Add slider end repeats index offset + if event.type == EventType.SLIDER_END: + span_duration = last_anchor_time - head_time + total_duration = time - head_time + repeats = max(int(round(total_duration / span_duration)), 1) if span_duration > 0 else 1 + index += repeat_type(repeats) + elif event.type == 
EventType.SLIDER_HEAD: + head_time = time + elif event.type == EventType.LAST_ANCHOR: + last_anchor_time = time + + features = torch.zeros(18) + features[0] = time + features[1] = distance + features[index + 2] = 1 + data_chunks.append(features) + + for j in indices: + seq_indices[j] = len(data_chunks) - 1 + indices = [] + + seq = torch.stack(data_chunks, 0) + seq = torch.swapaxes(seq, 0, 1) + seq_o = seq[0, :] + seq_d = seq[1, :] + seq_c = torch.concatenate( + [ + timestep_embedding(seq_d, 128).T, + seq[2:, :], + ], + 0, + ) + + return seq_o, seq_c, seq.shape[1], seq_indices + + + @staticmethod + def events_with_pos(events: list[Event], sampled_seq: torch.Tensor, seq_indices: dict[int, int]) -> list[Event]: + new_events = [] + for i, event in enumerate(events): + if event.type == EventType.DISTANCE: + try: + index = seq_indices[i] + pos_x = sampled_seq[0, index].item() + pos_y = sampled_seq[1, index].item() + new_events.append(Event(EventType.POS_X, int(round(pos_x)))) + new_events.append(Event(EventType.POS_Y, int(round(pos_y)))) + except KeyError: + print(f"Warning: Key {i} not found in seq_indices. Skipping event.") + else: + new_events.append(event) + return new_events diff --git a/osuT5/inference/path_approximator.py b/osuT5/inference/path_approximator.py new file mode 100644 index 0000000000000000000000000000000000000000..2b1aafffd5859f80ed61a34ebbb428dc56bf949c --- /dev/null +++ b/osuT5/inference/path_approximator.py @@ -0,0 +1,253 @@ +import numpy as np + +BEZIER_TOLERANCE = 0.25 +CATMULL_DETAIL = 50 +CIRCULAR_ARC_TOLERANCE = 0.1 + + +length_squared = lambda x: np.inner(x, x) + + +def approximate_bezier(control_points: np.ndarray) -> np.ndarray: + return approximate_b_spline(control_points) + + +def approximate_b_spline(control_points: np.ndarray, p: int = 0) -> np.ndarray: + output = [] + n = len(control_points) - 1 + + if n < 0: + return output + + to_flatten = [] + free_buffers = [] + + points = control_points.copy() + + if 0 < p < n: + for i in range(n - p): + sub_bezier = np.empty((p + 1, 2)) + sub_bezier[0] = points[i] + + for j in range(p - 1): + sub_bezier[j + 1] = points[i + 1] + + for k in range(1, p - j): + l = np.min((k, n - p - i)) + points[i + k] = (l * points[i + k] + points[i + k + 1]) / (l + 1) + + sub_bezier[p] = points[i + 1] + to_flatten.append(sub_bezier) + + to_flatten.append(points[(n - p) :]) + to_flatten.reverse() + else: + p = n + to_flatten.append(points) + + subdivision_buffer1 = np.empty([p + 1, 2]) + subdivision_buffer2 = np.empty([p * 2 + 1, 2]) + + left_child = subdivision_buffer2 + + while len(to_flatten) > 0: + parent = to_flatten.pop() + + if bezier_is_flat_enough(parent): + bezier_approximate( + parent, + output, + subdivision_buffer1, + subdivision_buffer2, + p + 1, + ) + + free_buffers.append(parent) + continue + + right_child = ( + free_buffers.pop() if len(free_buffers) > 0 else np.empty([p + 1, 2]) + ) + bezier_subdivide(parent, left_child, right_child, subdivision_buffer1, p + 1) + + for i in range(p + 1): + parent[i] = left_child[i] + + to_flatten.append(right_child) + to_flatten.append(parent) + + output.append(control_points[n].copy()) + return np.vstack(output) + + +def approximate_catmull(control_points: np.ndarray) -> list[np.ndarray]: + result = [] + + for i in range(len(control_points) - 1): + v1 = control_points[i - 1] if i > 0 else control_points[i] + v2 = control_points[i] + v3 = control_points[i + 1] if i < len(control_points) - 1 else v2 + v2 - v1 + v4 = control_points[i + 2] if i < len(control_points) - 2 else v3 + v3 - v2 + 
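+        # At the list boundaries v3/v4 fall back to mirrored phantom points
+        # (linear extrapolation), the usual Catmull-Rom end condition.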
+        for c in range(CATMULL_DETAIL):
+            result.append(catmull_find_point(v1, v2, v3, v4, c / CATMULL_DETAIL))
+            result.append(catmull_find_point(v1, v2, v3, v4, (c + 1) / CATMULL_DETAIL))
+
+    return result
+
+
+def approximate_circular_arc(control_points: np.ndarray) -> list[np.ndarray]:
+    a = control_points[0]
+    b = control_points[1]
+    c = control_points[2]
+
+    aSq = length_squared(b - c)
+    bSq = length_squared(a - c)
+    cSq = length_squared(a - b)
+
+    if np.isclose(aSq, 0) or np.isclose(bSq, 0) or np.isclose(cSq, 0):
+        return []
+
+    s = aSq * (bSq + cSq - aSq)
+    t = bSq * (aSq + cSq - bSq)
+    u = cSq * (aSq + bSq - cSq)
+
+    sum = s + t + u
+
+    if np.isclose(sum, 0):
+        return []
+
+    centre = (s * a + t * b + u * c) / sum
+    dA = a - centre
+    dC = c - centre
+
+    r = np.linalg.norm(dA)
+
+    theta_start = np.arctan2(dA[1], dA[0])
+    theta_end = np.arctan2(dC[1], dC[0])
+
+    while theta_end < theta_start:
+        theta_end += 2 * np.pi
+
+    direction = 1
+    theta_range = theta_end - theta_start
+
+    ortho_ato_c = c - a
+    ortho_ato_c = np.array([ortho_ato_c[1], -ortho_ato_c[0]])
+    if np.dot(ortho_ato_c, b - a) < 0:
+        direction = -direction
+        theta_range = 2 * np.pi - theta_range
+
+    amount_points = (
+        2
+        if 2 * r <= CIRCULAR_ARC_TOLERANCE
+        else int(
+            max(
+                2,
+                np.ceil(theta_range / (2 * np.arccos(1 - CIRCULAR_ARC_TOLERANCE / r))),
+            ),
+        )
+    )
+
+    output = []
+
+    for i in range(amount_points):
+        fract = i / (amount_points - 1)
+        theta = theta_start + direction * fract * theta_range
+        o = np.array([np.cos(theta), np.sin(theta)]) * r
+        output.append(centre + o)
+
+    return output
+
+
+def approximate_linear(control_points: np.ndarray) -> list[np.ndarray]:
+    result = []
+
+    for c in control_points:
+        result.append(c.copy())
+
+    return result
+
+
+def bezier_is_flat_enough(control_points: np.ndarray) -> bool:
+    for i in range(1, len(control_points) - 1):
+        p = control_points[i - 1] - 2 * control_points[i] + control_points[i + 1]
+        if length_squared(p) > BEZIER_TOLERANCE * BEZIER_TOLERANCE * 4:
+            return False
+
+    return True
+
+
+def bezier_subdivide(
+    control_points: np.ndarray,
+    left: np.ndarray,
+    right: np.ndarray,
+    subdivision_buffer: np.ndarray,
+    count: int,
+) -> None:
+    midpoints = subdivision_buffer
+
+    for i in range(count):
+        midpoints[i] = control_points[i]
+
+    for i in range(count):
+        left[i] = midpoints[0].copy()
+        right[count - i - 1] = midpoints[count - i - 1]
+
+        for j in range(count - i - 1):
+            midpoints[j] = (midpoints[j] + midpoints[j + 1]) / 2
+
+
+def bezier_approximate(
+    control_points: np.ndarray,
+    output: list[np.ndarray],
+    subdivision_buffer1: np.ndarray,
+    subdivision_buffer2: np.ndarray,
+    count: int,
+) -> None:
+    left = subdivision_buffer2
+    right = subdivision_buffer1
+
+    bezier_subdivide(control_points, left, right, subdivision_buffer1, count)
+
+    for i in range(count - 1):
+        left[count + i] = right[i + 1]
+
+    output.append(control_points[0].copy())
+
+    for i in range(1, count - 1):
+        index = 2 * i
+        p = 0.25 * (left[index - 1] + 2 * left[index] + left[index + 1])
+        output.append(p.copy())
+
+
+def catmull_find_point(
+    vec1: np.ndarray,
+    vec2: np.ndarray,
+    vec3: np.ndarray,
+    vec4: np.ndarray,
+    t: float,
+) -> np.ndarray:
+    t2 = t * t
+    t3 = t * t2
+
+    result = np.array(
+        [
+            0.5
+            * (
+                2 * vec2[0]
+                + (-vec1[0] + vec3[0]) * t
+                + (2 * vec1[0] - 5 * vec2[0] + 4 * vec3[0] - vec4[0]) * t2
+                + (-vec1[0] + 3 * vec2[0] - 3 * vec3[0] + vec4[0]) * t3
+            ),
+            0.5
+            * (
+                2 * vec2[1]
+                + (-vec1[1] + vec3[1]) * t
+                + (2 * vec1[1] - 5 * vec2[1] + 4 * vec3[1] -
vec4[1]) * t2 + + (-vec1[1] + 3 * vec2[1] - 3 * vec3[1] + vec4[1]) * t3 + ), + ], + ) + + return result diff --git a/osuT5/inference/pipeline.py b/osuT5/inference/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d1ad9db140f2aa00d64a3f86d302da15a9b1492f --- /dev/null +++ b/osuT5/inference/pipeline.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +from pathlib import Path + +import torch +import torch.nn.functional as F +from slider import Beatmap +from tqdm import tqdm + +from omegaconf import DictConfig + +from osuT5.dataset import OsuParser +from osuT5.dataset.data_utils import update_event_times +from osuT5.tokenizer import Event, EventType, Tokenizer +from osuT5.model import OsuT + +MILISECONDS_PER_SECOND = 1000 +MILISECONDS_PER_STEP = 10 + +def top_k_sampling(logits, k): + top_k_logits, top_k_indices = torch.topk(logits, k) + top_k_probs = F.softmax(top_k_logits, dim=-1) + sampled_index = torch.multinomial(top_k_probs, 1) + sampled_token = top_k_indices.gather(-1, sampled_index) + return sampled_token + +def preprocess_event(event, frame_time): + if event.type == EventType.TIME_SHIFT: + event = Event(type=event.type, value=int((event.value - frame_time) / MILISECONDS_PER_STEP)) + return event + +class Pipeline(object): + def __init__(self, args: DictConfig, tokenizer: Tokenizer): + """Model inference stage that processes sequences.""" + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.tokenizer = tokenizer + self.tgt_seq_len = args.data.tgt_seq_len + self.frame_seq_len = args.data.src_seq_len - 1 + self.frame_size = args.model.spectrogram.hop_length + self.sample_rate = args.model.spectrogram.sample_rate + self.samples_per_sequence = self.frame_seq_len * self.frame_size + self.sequence_stride = int(self.samples_per_sequence * args.data.sequence_stride) + self.miliseconds_per_sequence = self.samples_per_sequence * MILISECONDS_PER_SECOND / self.sample_rate + self.miliseconds_per_stride = self.sequence_stride * MILISECONDS_PER_SECOND / self.sample_rate + self.beatmap_id = args.beatmap_id + self.difficulty = args.difficulty + self.center_pad_decoder = args.data.center_pad_decoder + self.special_token_len = args.data.special_token_len + self.diff_token_index = args.data.diff_token_index + self.style_token_index = args.data.style_token_index + self.max_pre_token_len = args.data.max_pre_token_len + self.add_pre_tokens = args.data.add_pre_tokens + self.add_gd_context = args.data.add_gd_context + self.bpm = args.bpm + self.offset = args.offset + self.total_duration_ms = args.total_duration_ms + + print(f"Configuration: {args}") + + if self.add_gd_context: + other_beatmap_path = Path(args.other_beatmap_path) + + if not other_beatmap_path.is_file(): + raise FileNotFoundError(f"Beatmap file {other_beatmap_path} not found.") + + other_beatmap = Beatmap.from_path(other_beatmap_path) + self.other_beatmap_id = other_beatmap.beatmap_id + self.other_difficulty = float(other_beatmap.stars()) + parser = OsuParser(tokenizer) + self.other_events = parser.parse(other_beatmap) + self.other_events, self.other_event_times = self._prepare_events(self.other_events) + + def _calculate_time_shifts(self, bpm: float, duration_ms: float, tick_rate: int, offset: float = 0) -> list[float]: + """Calculate EventType.TIME_SHIFT events based on song's BPM and tick rate.""" + events = [] + ms_per_beat = 60000 / bpm # 60000 ms per minute + ms_per_tick = ms_per_beat / tick_rate + num_ticks = int(duration_ms // ms_per_tick) + + for i in range(num_ticks): 
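+            # Tick i lands at offset + i * ms_per_tick; e.g. bpm=120 with tick_rate=4
+            # gives ms_per_tick = 60000 / 120 / 4 = 125 ms between ticks.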
+ events.append(float(int(i * ms_per_tick + offset)) ) + + return events + + def generate_events(self, model, frames, tokens, encoder_outputs, beatmap_idx, total_steps): + temperature = 0.9 + k = 10 # top-k sampling + + for _ in range(total_steps): + out = model.forward( + frames=frames, + decoder_input_ids=tokens, + decoder_attention_mask=tokens.ne(self.tokenizer.pad_id), + encoder_outputs=encoder_outputs, + beatmap_idx=beatmap_idx, + ) + encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions) + logits = out.logits + logits = logits[:, -1, :] / temperature + logits = self._filter(logits, 0.9) + probabilities = F.softmax(logits, dim=-1) + next_tokens = top_k_sampling(probabilities, k) + + tokens = torch.cat([tokens, next_tokens], dim=-1) + + eos_in_sentence = next_tokens == self.tokenizer.eos_id + if eos_in_sentence.all(): + break + + return tokens + + def generate(self, model: OsuT, sequences: torch.Tensor, top_k: int = 50) -> list[Event]: + """ + Generate a list of Event object lists and their timestamps given source sequences. + + Args: + model: Trained model to use for inference. + sequences: A list of batched source sequences. + top_k: Number of top tokens to use for top-k sampling. + + Returns: + events: List of Event object lists. + event_times: Corresponding event times of Event object lists in milliseconds. + """ + events = [] + event_times = [] + temperature = 0.95 + + idx_dict = self.tokenizer.beatmap_idx + beatmap_idx = torch.tensor([idx_dict.get(self.beatmap_id, 6666)], dtype=torch.long, device=self.device) + style_token = self.tokenizer.encode_style(self.beatmap_id) if self.beatmap_id in idx_dict else self.tokenizer.style_unk + diff_token = self.tokenizer.encode_diff(self.difficulty) if self.difficulty != -1 else self.tokenizer.diff_unk + + special_tokens = torch.empty((1, self.special_token_len), dtype=torch.long, device=self.device) + special_tokens[:, self.diff_token_index] = diff_token + special_tokens[:, self.style_token_index] = style_token + + if self.add_gd_context: + other_style_token = self.tokenizer.encode_style(self.other_beatmap_id) if self.other_beatmap_id in idx_dict else self.tokenizer.style_unk + other_special_tokens = torch.empty((1, self.special_token_len), dtype=torch.long, device=self.device) + other_special_tokens[:, self.diff_token_index] = self.tokenizer.encode_diff(self.other_difficulty) + other_special_tokens[:, self.style_token_index] = other_style_token + else: + other_special_tokens = torch.empty((1, 0), dtype=torch.long, device=self.device) + + for sequence_index, frames in enumerate(tqdm(sequences)): + # Get tokens of previous frame + frame_time = sequence_index * self.miliseconds_per_stride + prev_events = self._get_events_time_range( + events, event_times, frame_time - self.miliseconds_per_sequence, frame_time) if self.add_pre_tokens else [] + post_events = self._get_events_time_range( + events, event_times, frame_time, frame_time + self.miliseconds_per_sequence) + + prev_tokens = self._encode(prev_events, frame_time) + post_tokens = self._encode(post_events, frame_time) + post_token_length = post_tokens.shape[1] + + if 0 <= self.max_pre_token_len < prev_tokens.shape[1]: + prev_tokens = prev_tokens[:, -self.max_pre_token_len:] + + # Get prefix tokens + prefix = torch.cat([special_tokens, prev_tokens], dim=-1) + if self.center_pad_decoder: + prefix = F.pad(prefix, (self.tgt_seq_len // 2 - prefix.shape[1], 0), value=self.tokenizer.pad_id) + prefix_length = prefix.shape[1] + + + max_retries = 5 + 
attempt = 0 + result = [] + + while attempt < max_retries and not result: + attempt += 1 + try: + # Reset tokens + tokens = torch.tensor([[self.tokenizer.sos_id]], dtype=torch.long, device=self.device) + tokens = torch.cat([prefix, tokens, post_tokens], dim=-1) + + # Ensure frames are properly reset for each retry + retry_frames = frames.clone().to(self.device).unsqueeze(0) + encoder_outputs = None + + while tokens.shape[-1] < self.tgt_seq_len: + out = model.forward( + frames=retry_frames, + decoder_input_ids=tokens, + decoder_attention_mask=tokens.ne(self.tokenizer.pad_id), + encoder_outputs=encoder_outputs, + #beatmap_idx=beatmap_idx, + ) + encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions) + + logits = out.logits[:, -1, :] + logits = logits / temperature + logits = self._filter(logits, top_p=0.9, top_k=60) + probabilities = F.softmax(logits, dim=-1) + next_tokens = torch.multinomial(probabilities, 1) + + tokens = torch.cat([tokens, next_tokens], dim=-1) + + eos_in_sentence = next_tokens == self.tokenizer.eos_id + if eos_in_sentence.all(): + break + + predicted_tokens = tokens[:, prefix_length + 1 + post_token_length:] + result = self._decode(predicted_tokens[0], frame_time) + + # if no new combo in result, retry; + if len(result) > 10 and not any(event.type == EventType.NEW_COMBO for event in result): + #print("No new combo in result; retrying...") + result = [] + + + except Exception as e: + #print(f"Attempt {attempt} encountered an error: {e}") + result = [] # Ensure result is empty to trigger retry + + events += result + + self._update_event_times(events, event_times, frame_time) + + return events + + def _prepare_events(self, events: list[Event]) -> tuple[list[Event], list[float]]: + """Pre-process raw list of events for inference. 
Calculates event times and removes redundant time shifts."""
+        ct = 0
+        event_times = []
+        for event in events:
+            if event.type == EventType.TIME_SHIFT:
+                ct = event.value
+            event_times.append(ct)
+
+        # Loop through the events in reverse to remove any time shifts that occur before anchor events
+        delete_next_time_shift = False
+        for i in range(len(events) - 1, -1, -1):
+            if events[i].type == EventType.TIME_SHIFT and delete_next_time_shift:
+                delete_next_time_shift = False
+                del events[i]
+                del event_times[i]
+                continue
+            elif events[i].type in [EventType.BEZIER_ANCHOR, EventType.PERFECT_ANCHOR, EventType.CATMULL_ANCHOR,
+                                    EventType.RED_ANCHOR]:
+                delete_next_time_shift = True
+
+        return events, event_times
+
+    def _get_events_time_range(self, events: list[Event], event_times: list[float], start_time: float, end_time: float):
+        # Look from the end of the list
+        s = 0
+        for i in range(len(event_times) - 1, -1, -1):
+            if event_times[i] < start_time:
+                s = i + 1
+                break
+        e = 0
+        for i in range(len(event_times) - 1, -1, -1):
+            if event_times[i] < end_time:
+                e = i + 1
+                break
+        return events[s:e]
+
+    def _update_event_times(self, events: list[Event], event_times: list[float], frame_time: float):
+        update_event_times(events, event_times, frame_time + self.miliseconds_per_sequence)
+
+    def _encode(self, events: list[Event], frame_time: float) -> torch.Tensor:
+        try:
+            tokens = torch.empty((1, len(events)), dtype=torch.long)
+            for i, event in enumerate(events):
+                if event.type == EventType.TIME_SHIFT:
+                    event = Event(type=event.type, value=int((event.value - frame_time) / MILISECONDS_PER_STEP))
+                tokens[0, i] = self.tokenizer.encode(event)
+            return tokens.to(self.device)
+        except Exception:
+            # Encoding failed (e.g. a value outside the tokenizer's range); fall back to no context
+            return torch.empty((1, 0), dtype=torch.long, device=self.device)
+
+    def _decode(self, tokens: torch.Tensor, frame_time: float) -> list[Event]:
+        """Converts a list of tokens into Event objects with absolute time values.
+
+        Args:
+            tokens: List of tokens.
+            frame_time: Start time of current source sequence.
+
+        Returns:
+            events: List of Event objects.
+        """
+        events = []
+        for token in tokens:
+            if token == self.tokenizer.eos_id:
+                break
+
+            try:
+                event = self.tokenizer.decode(token.item())
+            except Exception:
+                continue
+
+            if event.type == EventType.TIME_SHIFT:
+                event.value = frame_time + event.value * MILISECONDS_PER_STEP
+
+            events.append(event)
+
+        return events
+
+    def _filter(self, logits: torch.Tensor, top_p: float = 0.75, top_k: int = 1, filter_value: float = -float("Inf")) -> torch.Tensor:
+        """Filter a distribution of logits using nucleus (top-p) and/or top-k filtering."""
+        logits = top_k_logits(logits, top_k) if top_k > 0 else logits
+
+        if 0.0 < top_p < 1.0:
+            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+            sorted_indices_to_remove = cumulative_probs > top_p
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+            sorted_indices_to_remove[..., 0] = 0
+
+            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+            logits[indices_to_remove] = filter_value
+
+        return logits
+
+
+def top_k_logits(logits, k):
+    """
+    Keep only the top-k tokens with highest probabilities.
+
+    Args:
+        logits: Logits distribution of shape (batch size, vocabulary size).
+        k: Number of top tokens to keep.
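+            Assumed to be at least 1 (the caller guards with top_k > 0);
+            k = 0 would break the values[:, -1] lookup below.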
+ + Returns: + logits with non-top-k elements set to negative infinity. + """ + values, indices = torch.topk(logits, k) + min_values = values[:, -1].unsqueeze(-1).expand_as(logits) + return torch.where(logits < min_values, torch.full_like(logits, float("-Inf")), logits) diff --git a/osuT5/inference/postprocessor.py b/osuT5/inference/postprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..596c65e0f6ac2b732133efdd8a4006a1f8c4ed92 --- /dev/null +++ b/osuT5/inference/postprocessor.py @@ -0,0 +1,322 @@ +from __future__ import annotations + +import dataclasses +import os +import pathlib +import uuid +from string import Template +import zipfile +import numpy as np +from omegaconf import DictConfig +import time as t +from osuT5.inference.slider_path import SliderPath +from osuT5.tokenizer import Event, EventType + +OSU_FILE_EXTENSION = ".osu" +OSU_TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template.osu") +STEPS_PER_MILLISECOND = 0.1 + + +@dataclasses.dataclass +class BeatmapConfig: + # General + audio_filename: str = "" + + # Metadata + title: str = "" + title_unicode: str = "" + artist: str = "" + artist_unicode: str = "" + creator: str = "" + version: str = "" + + # Difficulty + hp_drain_rate: float = 5 + circle_size: float = 4 + overall_difficulty: float = 8 + approach_rate: float = 9 + slider_multiplier: float = 1.8 + + +def calculate_coordinates(last_pos, dist, num_samples, playfield_size): + # Generate a set of angles + angles = np.linspace(0, 2*np.pi, num_samples) + + # Calculate the x and y coordinates for each angle + x_coords = last_pos[0] + dist * np.cos(angles) + y_coords = last_pos[1] + dist * np.sin(angles) + + # Combine the x and y coordinates into a list of tuples + coordinates = list(zip(x_coords, y_coords)) + + # Filter out coordinates that are outside the playfield + coordinates = [(x, y) for x, y in coordinates if 0 <= x <= playfield_size[0] and 0 <= y <= playfield_size[1]] + + if len(coordinates) == 0: + return [playfield_size] if last_pos[0] + last_pos[1] > (playfield_size[0] + playfield_size[1]) / 2 else [(0, 0)] + + return coordinates + + +def position_to_progress(slider_path: SliderPath, pos: np.ndarray) -> np.ndarray: + eps = 1e-4 + lr = 1 + t = 1 + for i in range(100): + grad = np.linalg.norm(slider_path.position_at(t) - pos) - np.linalg.norm( + slider_path.position_at(t - eps) - pos, + ) + t -= lr * grad + + if grad == 0 or t < 0 or t > 1: + break + + return np.clip(t, 0, 1) + + + +def quantize_to_beat(time, bpm, offset): + """Quantize a given time to the nearest beat based on the BPM and offset.""" + # tick rate is 1/4 + #tick_rate = 0.25 + # tick rate is 1/8 + # tick_rate = 0.125 + # tick rate is 1/2 + #tick_rate = 0.5 + tick_rate = 0.5 + beats_per_minute = bpm + beats_per_second = beats_per_minute / 60.0 + milliseconds_per_beat = 1000 / beats_per_second + quantized_time = round((time - offset) / (milliseconds_per_beat * tick_rate)) * (milliseconds_per_beat * tick_rate) + offset + return quantized_time + +def quantize_to_beat_again(time, bpm, offset): + """Quantize a given time to the nearest beat based on the BPM and offset.""" + # tick rate is 1/4 + #tick_rate = 0.25 + # tick rate is 1/8 + # tick_rate = 0.125 + # tick rate is 1/2 + #tick_rate = 0.5 + tick_rate = 0.25 + beats_per_minute = bpm + beats_per_second = beats_per_minute / 60.0 + milliseconds_per_beat = 1000 / beats_per_second + quantized_time = round((time - offset) / (milliseconds_per_beat * tick_rate)) * (milliseconds_per_beat * tick_rate) + offset + return 
quantized_time
+
+def move_to_next_tick(time, bpm):
+    """Move the given time forward to the next 1/4 tick for the given BPM."""
+    tick_rate = 0.25
+    beats_per_minute = bpm
+    beats_per_second = beats_per_minute / 60.0
+    milliseconds_per_beat = 1000 / beats_per_second
+    return time + milliseconds_per_beat * tick_rate
+
+def move_to_prev_tick(time, bpm):
+    """Move the given time back to the previous 1/4 tick for the given BPM."""
+    tick_rate = 0.25
+    beats_per_minute = bpm
+    beats_per_second = beats_per_minute / 60.0
+    milliseconds_per_beat = 1000 / beats_per_second
+    return time - milliseconds_per_beat * tick_rate
+
+def adjust_hit_objects(hit_objects, bpm, offset):
+    """Adjust the timing of hit objects to align with beats based on BPM and offset."""
+    adjusted_hit_objects = []
+    adjusted_times = []
+    for hit_object in hit_objects:
+        if hit_object.type == EventType.TIME_SHIFT:
+            time = quantize_to_beat(hit_object.value, bpm, offset)
+            # If two objects snap to the same tick, push the second one to the next
+            # tick, unless the previous event closes a slider.
+            if (len(adjusted_times) > 0 and int(time) == adjusted_times[-1]
+                    and adjusted_hit_objects[-1].type not in (EventType.LAST_ANCHOR, EventType.SLIDER_END)):
+                time = move_to_next_tick(time, bpm)
+            adjusted_hit_objects.append(Event(EventType.TIME_SHIFT, time))
+            adjusted_times.append(int(time))
+        else:
+            adjusted_hit_objects.append(hit_object)
+
+    return adjusted_hit_objects
+
+
+class Postprocessor(object):
+    def __init__(self, args: DictConfig):
+        """Postprocessing stage that converts lists of Event objects to a beatmap file."""
+        self.curve_type_shorthand = {
+            "B": "Bezier",
+            "P": "PerfectCurve",
+            "C": "Catmull",
+        }
+
+        self.output_path = args.output_path
+        self.audio_path = args.audio_path
+        self.audio_filename = pathlib.Path(args.audio_path).name.split(".")[0]
+        self.beatmap_config = BeatmapConfig(
+            title=str(f"{self.audio_filename} ({args.title})"),
+            artist=str(args.artist),
+            title_unicode=str(args.title),
+            artist_unicode=str(args.artist),
+            audio_filename=pathlib.Path(args.audio_path).name,
+            slider_multiplier=float(args.slider_multiplier),
+            creator=str(args.creator),
+            version=str(args.version),
+        )
+        self.offset = args.offset
+        self.beat_length = 60000 / args.bpm
+        self.slider_multiplier = self.beatmap_config.slider_multiplier
+        self.bpm = args.bpm
+        self.resnap_objects = args.resnap_objects
+
+    def generate(self, generated_positions: list[list[Event]]):
+        """Generate a beatmap file.
+
+        Args:
+            generated_positions: One list of Event objects per beatmap to render.
+
+        Returns:
+            None. An .osu file will be generated.
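# [Editor's note] Worked example of the snapping above, assuming bpm=180 and offset=0:
# one beat is 60000 / 180 = 333.33 ms, so quantize_to_beat's 1/2-beat grid is about
# 166.67 ms, and a raw time of 1050 ms snaps to round(1050 / 166.67) * 166.67 = 1000 ms.
# On a collision with the previous object, move_to_next_tick nudges by a 1/4 beat.
assert round(quantize_to_beat(1050, 180, 0)) == 1000
assert round(move_to_next_tick(1000, 180)) == 1083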
+ """ + processed_events = [] + + for events in generated_positions: + # adjust hit objects to align with 1/4 beats + if self.resnap_objects: + events = adjust_hit_objects(events, self.bpm, self.offset) + + + hit_object_strings = [] + time = 0 + dist = 0 + x = 256 + y = 192 + has_pos = False + new_combo = 0 + ho_info = [] + anchor_info = [] + + timing_point_strings = [ + f"{self.offset},{self.beat_length},4,2,0,100,1,0" + ] + + for event in events: + hit_type = event.type + + if hit_type == EventType.TIME_SHIFT: + time = event.value + continue + elif hit_type == EventType.DISTANCE: + # Find a point which is dist away from the last point but still within the playfield + dist = event.value + coordinates = calculate_coordinates((x, y), dist, 500, (512, 384)) + pos = coordinates[np.random.randint(len(coordinates))] + x, y = pos + continue + elif hit_type == EventType.POS_X: + x = event.value + has_pos = True + continue + elif hit_type == EventType.POS_Y: + y = event.value + has_pos = True + continue + elif hit_type == EventType.NEW_COMBO: + new_combo = 4 + continue + + if hit_type == EventType.CIRCLE: + hit_object_strings.append(f"{int(round(x))},{int(round(y))},{int(round(time))},{1 | new_combo},0") + ho_info = [] + + elif hit_type == EventType.SPINNER: + ho_info = [time, new_combo] + + elif hit_type == EventType.SPINNER_END and len(ho_info) == 2: + hit_object_strings.append( + f"{256},{192},{int(round(ho_info[0]))},{8 | ho_info[1]},0,{int(round(time))}" + ) + ho_info = [] + + elif hit_type == EventType.SLIDER_HEAD: + ho_info = [x, y, time, new_combo] + anchor_info = [] + + elif hit_type == EventType.BEZIER_ANCHOR: + anchor_info.append(('B', x, y)) + + elif hit_type == EventType.PERFECT_ANCHOR: + anchor_info.append(('P', x, y)) + + elif hit_type == EventType.CATMULL_ANCHOR: + anchor_info.append(('C', x, y)) + + elif hit_type == EventType.RED_ANCHOR: + anchor_info.append(('B', x, y)) + anchor_info.append(('B', x, y)) + + elif hit_type == EventType.LAST_ANCHOR: + ho_info.append(time) + anchor_info.append(('B', x, y)) + + elif hit_type == EventType.SLIDER_END and len(ho_info) == 5 and len(anchor_info) > 0: + curve_type = anchor_info[0][0] + span_duration = ho_info[4] - ho_info[2] + total_duration = time - ho_info[2] + + if total_duration == 0 or span_duration == 0: + continue + + slides = max(int(round(total_duration / span_duration)), 1) + control_points = "|".join(f"{int(round(cp[1]))}:{int(round(cp[2]))}" for cp in anchor_info) + slider_path = SliderPath(self.curve_type_shorthand[curve_type], np.array([(ho_info[0], ho_info[1])] + [(cp[1], cp[2]) for cp in anchor_info], dtype=float)) + length = slider_path.get_distance() + + req_length = length * position_to_progress( + slider_path, + np.array((x, y)), + ) if has_pos else length - dist + + if req_length < 1e-4: + continue + + hit_object_strings.append( + f"{int(round(ho_info[0]))},{int(round(ho_info[1]))},{int(round(ho_info[2]))},{2 | ho_info[3]},0,{curve_type}|{control_points},{slides},{req_length}" + ) + + sv = span_duration / req_length / self.beat_length * self.slider_multiplier * -10000 + timing_point_strings.append( + f"{int(round(ho_info[2]))},{sv},4,2,0,100,0,0" + ) + + new_combo = 0 + + # Write .osu file + with open(OSU_TEMPLATE_PATH, "r") as tf: + template = Template(tf.read()) + hit_objects = {"hit_objects": "\n".join(hit_object_strings)} + timing_points = {"timing_points": "\n".join(timing_point_strings)} + beatmap_config = dataclasses.asdict(self.beatmap_config) + result = template.safe_substitute({**beatmap_config, **hit_objects, 
**timing_points}) + processed_events.append(result) + + osz_path = os.path.join(self.output_path, f"{self.audio_filename}_{t.time()}.osz") + with zipfile.ZipFile(osz_path, "w") as z: + for i, event in enumerate(processed_events): + osu_path = os.path.join(self.output_path, f"{i}{OSU_FILE_EXTENSION}") + with open(osu_path, "w") as osu_file: + osu_file.write(event) + z.write(osu_path, os.path.basename(osu_path)) + z.write(self.audio_path, os.path.basename(self.audio_path)) + print(f"Mapset saved {osz_path}") + z.close() \ No newline at end of file diff --git a/osuT5/inference/preprocessor.py b/osuT5/inference/preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..bae7d2951ad267d3ebb879fc44a9b2888d7dc2c8 --- /dev/null +++ b/osuT5/inference/preprocessor.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from pathlib import Path + +import torch +import numpy as np +import numpy.typing as npt +from omegaconf import DictConfig + +from osuT5.dataset.data_utils import load_audio_file + + +class Preprocessor(object): + def __init__(self, args: DictConfig): + """Preprocess audio data into sequences.""" + self.frame_seq_len = args.data.src_seq_len - 1 + self.frame_size = args.data.hop_length + self.sample_rate = args.data.sample_rate + self.samples_per_sequence = self.frame_seq_len * self.frame_size + self.sequence_stride = int(self.samples_per_sequence * args.data.sequence_stride) + + def load(self, path: Path) -> npt.ArrayLike: + """Load an audio file as audio frames. Convert stereo to mono, normalize. + + Args: + path: Path to audio file. + + Returns: + samples: Audio time-series. + """ + return load_audio_file(path, self.sample_rate) + + def segment(self, samples: npt.ArrayLike) -> torch.Tensor: + """Segment audio samples into sequences. Sequences are flattened frames. + + Args: + samples: Audio time-series. + + Returns: + sequences: A list of sequences of shape (batch size, samples per sequence). 
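# [Editor's note] Sequence bookkeeping for the Preprocessor above, with assumed config
# values (src_seq_len=1025, hop_length=128, sample_rate=16000, sequence_stride=0.5;
# not necessarily this project's defaults):
frame_seq_len = 1025 - 1                                 # 1024 spectrogram frames per window
samples_per_sequence = frame_seq_len * 128               # 131072 audio samples
sequence_stride = int(samples_per_sequence * 0.5)        # 65536 samples, i.e. 50% overlap
ms_per_sequence = samples_per_sequence / 16000 * 1000    # 8192 ms of audio per window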
+ """ + samples = np.pad( + samples, + [0, self.sequence_stride - (len(samples) - self.samples_per_sequence) % self.sequence_stride], + ) + sequences = self.window(samples, self.samples_per_sequence, self.sequence_stride) + sequences = torch.from_numpy(sequences).to(torch.float32) + return sequences + + @staticmethod + def window(a, w, o, copy=False): + sh = (a.size - w + 1, w) + st = a.strides * 2 + view = np.lib.stride_tricks.as_strided(a, strides=st, shape=sh)[0::o] + if copy: + return view.copy() + else: + return view diff --git a/osuT5/inference/slider_path.py b/osuT5/inference/slider_path.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d008fa7db404a5039cb3f9ff1f59efe053e90a --- /dev/null +++ b/osuT5/inference/slider_path.py @@ -0,0 +1,230 @@ +import logging + +import numpy as np +from numpy.linalg import norm + +import osuT5.inference.path_approximator as path_approximator + + +def binary_search(array, target): + lower = 0 + upper = len(array) + while lower < upper: # use < instead of <= + x = lower + (upper - lower) // 2 + val = array[x] + if target == val: + return x + elif target > val: + if lower == x: # these two are the actual lines + break # you're looking for + lower = x + elif target < val: + upper = x + return ~upper + + +class SliderPath: + __slots__ = ( + "control_points", + "path_type", + "expected_distance", + "calculated_path", + "cumulative_length", + "is_initialised", + ) + + def __init__( + self, + path_type: str, + control_points: np.array, + expected_distance: float | None = None, + ) -> None: + self.control_points = control_points + self.path_type = path_type + self.expected_distance = expected_distance + + self.calculated_path = None + self.cumulative_length = None + + self.is_initialised = None + + self.ensure_initialised() + + def get_control_points(self) -> np.array: + self.ensure_initialised() + return self.control_points + + def get_distance(self) -> float: + self.ensure_initialised() + return 0 if len(self.cumulative_length) == 0 else self.cumulative_length[-1] + + def get_path_to_progress(self, path, p0, p1) -> None: + self.ensure_initialised() + + d0 = self.progress_to_distance(p0) + d1 = self.progress_to_distance(p1) + + path.clear() + + i = 0 + while i < len(self.calculated_path) and self.cumulative_length[i] < d0: + i += 1 + + path.append(self.interpolate_vertices(i, d0)) + + while i < len(self.calculated_path) and self.cumulative_length[i] < d1: + path.append(self.calculated_path[i]) + i += 1 + + path.append(self.interpolate_vertices(i, d1)) + + def position_at(self, progress) -> np.array: + self.ensure_initialised() + + d = self.progress_to_distance(progress) + return self.interpolate_vertices(self.index_of_distance(d), d) + + def ensure_initialised(self) -> None: + if self.is_initialised: + return + self.is_initialised = True + + self.control_points = [] if self.control_points is None else self.control_points + self.calculated_path = [] + self.cumulative_length = [] + + self.calculate_path() + self.calculate_cumulative_length() + + def calculate_subpath(self, sub_control_points) -> list: + if self.path_type == "Linear": + return path_approximator.approximate_linear(sub_control_points) + elif self.path_type == "PerfectCurve": + if len(self.get_control_points()) != 3 or len(sub_control_points) != 3: + return path_approximator.approximate_bezier(sub_control_points) + + subpath = path_approximator.approximate_circular_arc(sub_control_points) + + if len(subpath) == 0: + return 
path_approximator.approximate_bezier(sub_control_points) + + return subpath + elif self.path_type == "Catmull": + return path_approximator.approximate_catmull(sub_control_points) + else: + return path_approximator.approximate_bezier(sub_control_points) + + def calculate_path(self) -> None: + self.calculated_path.clear() + + start = 0 + end = 0 + + for i in range(len(self.get_control_points())): + end += 1 + + if ( + i == len(self.get_control_points()) - 1 + or ( + self.get_control_points()[i] == self.get_control_points()[i + 1] + ).all() + ): + cp_span = self.get_control_points()[start:end] + + for t in self.calculate_subpath(cp_span): + if ( + len(self.calculated_path) == 0 + or (self.calculated_path[-1] != t).any() + ): + self.calculated_path.append(t) + + start = end + + def calculate_cumulative_length(self) -> None: + length = 0 + + self.cumulative_length.clear() + self.cumulative_length.append(length) + + for i in range(len(self.calculated_path) - 1): + diff = self.calculated_path[i + 1] - self.calculated_path[i] + d = norm(diff) + + if ( + self.expected_distance is not None + and self.expected_distance - length < d + ): + self.calculated_path[i + 1] = ( + self.calculated_path[i] + + diff * (self.expected_distance - length) / d + ) + del self.calculated_path[i + 2 : len(self.calculated_path) - 2 - i] + + length = self.expected_distance + self.cumulative_length.append(length) + break + + length += d + self.cumulative_length.append(length) + + if ( + self.expected_distance is not None + and length < self.expected_distance + and len(self.calculated_path) > 1 + ): + diff = self.calculated_path[-1] - self.calculated_path[-2] + d = norm(diff) + + if d <= 0: + return + + self.calculated_path[-1] += ( + diff * (self.expected_distance - self.cumulative_length[-1]) / d + ) + self.cumulative_length[-1] = self.expected_distance + + def index_of_distance(self, d) -> int: + i = binary_search(self.cumulative_length, d) + if i < 0: + i = ~i + + return i + + def progress_to_distance(self, progress) -> float: + return np.clip(progress, 0, 1) * self.get_distance() + + def interpolate_vertices(self, i, d) -> np.array: + if len(self.calculated_path) == 0: + return np.zeros([2]) + + if i <= 0: + return self.calculated_path[0] + if i >= len(self.calculated_path): + return self.calculated_path[-1] + + p0 = self.calculated_path[i - 1] + p1 = self.calculated_path[i] + + d0 = self.cumulative_length[i - 1] + d1 = self.cumulative_length[i] + + if np.isclose(d0, d1): + return p0 + + w = (d - d0) / (d1 - d0) + return p0 + (p1 - p0) * w + + +if __name__ == "__main__": + path = SliderPath( + "Bezier", + 100 * np.array([[0, 0], [1, 1], [1, -1], [2, 0], [2, 0], [3, -1], [2, -2]]), + ) + p = np.vstack(path.calculated_path) + logging.info(p.shape) + + import matplotlib.pyplot as plt + + plt.axis("equal") + plt.plot(p[:, 0], p[:, 1], color="green") + plt.show() diff --git a/osuT5/inference/template.osu b/osuT5/inference/template.osu new file mode 100644 index 0000000000000000000000000000000000000000..ca7d29867df843cb534b5b7dad68bbfdeac448d0 --- /dev/null +++ b/osuT5/inference/template.osu @@ -0,0 +1,54 @@ +osu file format v14 + +[General] +AudioFilename: $audio_filename +AudioLeadIn: 0 +PreviewTime: -1 +Countdown: 0 +SampleSet: Soft +StackLeniency: 0.7 +Mode: 0 +LetterboxInBreaks: 0 +WidescreenStoryboard: 1 + +[Editor] +DistanceSpacing: 1.0 +BeatDivisor: 4 +GridSize: 8 +TimelineZoom: 1 + +[Metadata] +Title:$title +TitleUnicode:$title_unicode +Artist:$artist +ArtistUnicode:$artist_unicode +Creator:$creator 
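# [Editor's note] binary_search in slider_path.py mirrors C#'s Array.BinarySearch: on a
# miss it returns the bitwise complement of the insertion index, which index_of_distance
# undoes with `i = ~i`. Tiny demo:
cumulative = [0.0, 10.0, 25.0]
i = binary_search(cumulative, 12.0)    # ~2 == -3: 12.0 would be inserted at index 2
idx = ~i if i < 0 else i               # 2, the first point at cumulative distance >= 12.0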
+Version:$version +Source: +Tags: + +[Difficulty] +HPDrainRate:$hp_drain_rate +CircleSize:$circle_size +OverallDifficulty:$overall_difficulty +ApproachRate:$approach_rate +SliderMultiplier:$slider_multiplier +SliderTickRate:1 + +[Events] +//Background and Video events +//Break Periods +//Storyboard Layer 0 (Background) +//Storyboard Layer 1 (Fail) +//Storyboard Layer 2 (Pass) +//Storyboard Layer 3 (Foreground) +//Storyboard Layer 4 (Overlay) +//Storyboard Sound Samples + +[TimingPoints] +$timing_points + +[Colours] + +[HitObjects] +$hit_objects \ No newline at end of file diff --git a/osuT5/model/__init__.py b/osuT5/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f70b9d5ec4845936c6bf4cba8d2e400c24a4104 --- /dev/null +++ b/osuT5/model/__init__.py @@ -0,0 +1 @@ +from .osu_t import OsuT diff --git a/osuT5/model/__pycache__/__init__.cpython-311.pyc b/osuT5/model/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed4140de77506b85e00275b0a64028095698ef74 Binary files /dev/null and b/osuT5/model/__pycache__/__init__.cpython-311.pyc differ diff --git a/osuT5/model/__pycache__/__init__.cpython-39.pyc b/osuT5/model/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b55752dc9d4202128fa986874bbb0a689fc26887 Binary files /dev/null and b/osuT5/model/__pycache__/__init__.cpython-39.pyc differ diff --git a/osuT5/model/__pycache__/osu_t.cpython-311.pyc b/osuT5/model/__pycache__/osu_t.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c301cc71d8575fbf0dc3ba9421c0bc7a32b2b7f Binary files /dev/null and b/osuT5/model/__pycache__/osu_t.cpython-311.pyc differ diff --git a/osuT5/model/__pycache__/osu_t.cpython-39.pyc b/osuT5/model/__pycache__/osu_t.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ea8addc9c33ab1965366212d11b39c3c3be49f1 Binary files /dev/null and b/osuT5/model/__pycache__/osu_t.cpython-39.pyc differ diff --git a/osuT5/model/__pycache__/spectrogram.cpython-311.pyc b/osuT5/model/__pycache__/spectrogram.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c98853af3781f89bc6da1a2060c4b5224f925e41 Binary files /dev/null and b/osuT5/model/__pycache__/spectrogram.cpython-311.pyc differ diff --git a/osuT5/model/__pycache__/spectrogram.cpython-39.pyc b/osuT5/model/__pycache__/spectrogram.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8895a6ca3e5c6f5a11e9893f60efcc8f5e64cfb1 Binary files /dev/null and b/osuT5/model/__pycache__/spectrogram.cpython-39.pyc differ diff --git a/osuT5/model/osu_t.py b/osuT5/model/osu_t.py new file mode 100644 index 0000000000000000000000000000000000000000..75212fbc3cdeea82a7eee9073591d2c9ed7d836e --- /dev/null +++ b/osuT5/model/osu_t.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn +from omegaconf import DictConfig +from transformers import T5Config, T5ForConditionalGeneration, WhisperForConditionalGeneration, WhisperConfig +from transformers.modeling_outputs import Seq2SeqLMOutput + +from osuT5.model.spectrogram import MelSpectrogram +from osuT5.tokenizer import Tokenizer + + +def get_backbone_model(args, tokenizer: Tokenizer): + if args.model.name.startswith("google/t5"): + config = T5Config.from_pretrained(args.model.name) + print(f"old config: num_heads={config.num_heads}, num_layers={config.num_layers}, 
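# [Editor's note] template.osu above is filled in with string.Template: each $name
# placeholder is replaced by safe_substitute, and keys missing from the mapping are
# left in place instead of raising. Minimal sketch:
from string import Template

t = Template("Title:$title\nAudioFilename: $audio_filename")
print(t.safe_substitute({"title": "vale"}))    # $audio_filename survives unfilled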
d_model={config.d_model}, d_ff={config.d_ff}") + + + elif args.model.name.startswith("openai/whisper"): + config = WhisperConfig.from_pretrained(args.model.name) + else: + raise NotImplementedError + + config.vocab_size = tokenizer.vocab_size_out + + if hasattr(args.model, "overwrite"): + for k, v in args.model.overwrite.items(): + assert hasattr(config, k), f"config does not have attribute {k}" + setattr(config, k, v) + + if hasattr(args.model, "add_config"): + for k, v in args.model.add_config.items(): + assert not hasattr(config, k), f"config already has attribute {k}" + setattr(config, k, v) + + if args.model.name.startswith("google/t5"): + model = T5ForConditionalGeneration(config) + elif args.model.name.startswith("openai/whisper"): + config.num_mel_bins = config.d_model + config.pad_token_id = tokenizer.pad_id + config.bos_token_id = tokenizer.sos_id + config.eos_token_id = tokenizer.eos_id + config.max_source_positions = args.data.src_seq_len // 2 + config.max_target_positions = args.data.tgt_seq_len + model = WhisperForConditionalGeneration(config) + else: + raise NotImplementedError + + return model, config.d_model + + +class OsuT(nn.Module): + def __init__(self, args: DictConfig, tokenizer: Tokenizer): + super().__init__() + + self.transformer, d_model = get_backbone_model(args, tokenizer) + self.num_classes = tokenizer.num_classes + self.input_features = args.model.input_features + + + + self.decoder_embedder = nn.Embedding(tokenizer.vocab_size_in, d_model) + self.decoder_embedder.weight.data.normal_(mean=0.0, std=1.0) + + self.dropout = nn.Dropout(p=0.1) + + self.spectrogram = MelSpectrogram( + args.model.spectrogram.sample_rate, args.model.spectrogram.n_fft, + args.model.spectrogram.n_mels, args.model.spectrogram.hop_length) + #self.norm_layer = nn.LayerNorm(args.model.spectrogram.n_mels) + + self.do_style_embed = args.model.do_style_embed + + if self.do_style_embed: + self.style_embedder = LabelEmbedder(self.num_classes, d_model) + self.encoder_embedder = nn.Linear(args.model.spectrogram.n_mels + d_model, d_model) + nn.init.normal_(self.style_embedder.embedding_table.weight, std=0.02) + else: + self.encoder_embedder = nn.Linear(args.model.spectrogram.n_mels, d_model) + + def forward(self, frames: Optional[torch.FloatTensor] = None, decoder_input_ids: Optional[torch.Tensor] = None, + beatmap_idx: Optional[torch.Tensor] = None, encoder_outputs: Optional[torch.FloatTensor] = None, **kwargs) -> Seq2SeqLMOutput: + if beatmap_idx is None: + batch_size = frames.shape[0] if frames is not None else decoder_input_ids.shape[0] + device = frames.device if frames is not None else decoder_input_ids.device + beatmap_idx = torch.full([batch_size], self.num_classes, dtype=torch.long, device=device) + + inputs_embeds = None + if encoder_outputs is None: + frames = self.spectrogram(frames) + #frames = self.norm_layer(frames) + if self.do_style_embed: + style_embeds = self.style_embedder(beatmap_idx) + frames_concat = torch.concatenate((frames, style_embeds.unsqueeze(1).expand((-1, frames.shape[1], -1))), -1) + inputs_embeds = self.encoder_embedder(frames_concat) + else: + inputs_embeds = self.encoder_embedder(frames) + + decoder_inputs_embeds = self.dropout(self.decoder_embedder(decoder_input_ids)) + #inputs_embeds = self.dropout(inputs_embeds) + + if self.input_features: + input_features = torch.swapaxes(inputs_embeds, 1, 2) + output = self.transformer.forward(input_features=input_features, decoder_inputs_embeds=decoder_inputs_embeds, + encoder_outputs=encoder_outputs, **kwargs) + else: + 
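# [Editor's note] The overwrite/add_config hooks above let an OmegaConf file patch any
# field of the pretrained backbone config. The pattern, on a toy config object:
class ToyConfig:
    num_layers = 6

config, overwrite = ToyConfig(), {"num_layers": 2}
for k, v in overwrite.items():
    assert hasattr(config, k), f"config does not have attribute {k}"
    setattr(config, k, v)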
output = self.transformer.forward(inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, + encoder_outputs=encoder_outputs, **kwargs) + + return output + + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. + """ + + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding( + num_classes + 1, + hidden_size, + ) + + def forward(self, labels): + embeddings = self.embedding_table(labels) + return embeddings diff --git a/osuT5/model/spectrogram.py b/osuT5/model/spectrogram.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5df1a6d9d90d1ef95a3da9facd32f92aeb7128 --- /dev/null +++ b/osuT5/model/spectrogram.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import torch +import torch.nn as nn +from nnAudio import features + + +class MelSpectrogram(nn.Module): + def __init__( + self, + sample_rate: int = 16000, + n_ftt: int = 2048, + n_mels: int = 512, + hop_length: int = 128, + ): + """ + Melspectrogram transformation layer, supports on-the-fly processing on GPU. + + Attributes: + sample_rate: The sampling rate for the input audio. + n_ftt: The window size for the STFT. + n_mels: The number of Mel filter banks. + hop_length: The hop (or stride) size. + """ + super().__init__() + self.transform = features.MelSpectrogram( + sr=sample_rate, + n_fft=n_ftt, + n_mels=n_mels, + hop_length=hop_length, + center=True, + fmin=0, + fmax=sample_rate // 2, + pad_mode="constant", + ) + + def forward(self, samples: torch.tensor) -> torch.tensor: + """ + Convert a batch of audio frames into a batch of Mel spectrogram frames. + + For each item in the batch: + 1. pad left and right ends of audio by n_fft // 2. + 2. run STFT with window size of |n_ftt| and stride of |hop_length|. + 3. convert result into mel-scale. + 4. therefore, n_frames = n_samples // hop_length + 1. + + Args: + samples: Audio time-series (batch size, n_samples). + + Returns: + A batch of Mel spectrograms of size (batch size, n_frames, n_mels). + """ + spectrogram = self.transform(samples) + spectrogram = spectrogram.permute(0, 2, 1) + return spectrogram diff --git a/osuT5/model/t5.py b/osuT5/model/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad98aceb52f528dd3af5c2964b4dc67b70eb4b9 --- /dev/null +++ b/osuT5/model/t5.py @@ -0,0 +1,643 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
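# [Editor's note] Quick shape check for the MelSpectrogram wrapper above: with a
# centered STFT the transform yields n_samples // hop_length + 1 frames. Assumed sizes:
import torch
from nnAudio import features

mel = features.MelSpectrogram(sr=16000, n_fft=2048, n_mels=512, hop_length=128,
                              center=True, pad_mode="constant")
x = torch.randn(2, 16000)              # two one-second clips
y = mel(x).permute(0, 2, 1)            # (2, 16000 // 128 + 1 = 126, 512)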
+# +# Modified by nanoT5 authors +# https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/t5_model.py + +import copy +import math +from typing import Optional +from dataclasses import dataclass + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.modeling_utils import ModuleUtilsMixin +from transformers.modeling_outputs import ModelOutput +from transformers.models.t5.configuration_t5 import T5Config +from transformers.models.t5.modeling_t5 import ( + T5LayerNorm, + T5DenseGatedActDense, +) + +from .spectrogram import MelSpectrogram + + +@dataclass +class EncoderOutput(ModelOutput): + hidden_states: torch.FloatTensor = None + attention_mask: torch.FloatTensor = None + + +@dataclass +class Seq2SeqLMOutput(ModelOutput): + loss: torch.FloatTensor = None + logits: torch.FloatTensor = None + encoder_outputs: EncoderOutput = None + + +class T5LayerFF(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + assert config.is_gated_act + self.DenseReluDense = T5DenseGatedActDense(config) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states).type_as(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__(self, config: T5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.n_heads + ) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + relative_position = ( + memory_position - context_position + ) # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
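# [Editor's note] Worked example of the bucketing above (bidirectional attention,
# num_buckets=32, max_distance=128): offsets within 7 of the query get exact buckets,
# larger offsets share logarithmic ones, and the sign selects the half of the table.
import torch

rp = torch.tensor([-20, -5, 0, 5, 20])
T5Attention._relative_position_bucket(rp, bidirectional=True, num_buckets=32,
                                      max_distance=128)
# tensor([10,  5,  0, 21, 26])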
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + batch_size, seq_length = hidden_states.shape[:2] + real_seq_length = seq_length + key_length = ( + real_seq_length if key_value_states is None else key_value_states.shape[1] + ) + + def shape(states): + """projection""" + return states.view( + batch_size, -1, self.n_heads, self.key_value_proj_dim + ).transpose(1, 2) + + def unshape(states): + """reshape""" + return ( + states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + ) + + query_states = self.q(hidden_states) + if key_value_states is None: + key_states, value_states = self.k(hidden_states), self.v(hidden_states) + else: + key_states, value_states = self.k(key_value_states), self.v( + key_value_states + ) + query_states, key_states, value_states = ( + shape(query_states), + shape(key_states), + shape(value_states), + ) + + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype, + ) + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device + ) + + if mask is not None: + # Masking happens here, masked elements in the mask have the value of -inf + position_bias = ( + position_bias + mask + ) # (batch_size, n_heads, seq_length, key_length) + + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + attn_output = unshape( + torch.matmul(attn_weights, value_states) + ) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + return (attn_output, position_bias) + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + ): + normed_hidden_states = self.layer_norm(hidden_states).type_as(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, 
+ position_bias=position_bias, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + ) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Relative position weights + + if self.is_decoder and encoder_hidden_states is not None: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + ) + hidden_states = cross_attention_outputs[0] + + # Keep relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[1:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + outputs = (hidden_states,) + outputs = outputs + attention_outputs + + return outputs # hidden-states, (self-attention position bias), (cross-attention position bias) + + +class T5Stack(nn.Module, ModuleUtilsMixin): + def __init__(self, config, embed_tokens): + super().__init__() + assert embed_tokens is not None + + self.config = config + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [ + T5Block(config, has_relative_attention_bias=bool(i == 0)) + for i in range(config.num_layers) + ] + ) + + self.final_layer_norm = T5LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> EncoderOutput: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + input_shape = inputs_embeds.size() + batch_size = input_shape[0] + seq_length = input_shape[1] + input_shape = (batch_size, seq_length) + + if hasattr(self.config, "is_bf16") and self.config.is_bf16: + inputs_embeds = inputs_embeds.to(torch.bfloat16) + + # Masking + if attention_mask is None: + attention_mask = torch.ones( + batch_size, seq_length, device=inputs_embeds.device + ) + + if ( + self.is_decoder + and encoder_attention_mask is None + and encoder_hidden_states is not None + ): + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, + encoder_seq_length, + device=inputs_embeds.device, + dtype=torch.long, + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
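# [Editor's note] get_extended_attention_mask (inherited from ModuleUtilsMixin) turns
# the (batch, seq_len) 0/1 mask into an additive bias that broadcasts over heads,
# roughly equivalent to:
import torch

mask = torch.tensor([[1, 1, 1, 0]])                     # 1 = attend, 0 = ignore
extended = (1.0 - mask[:, None, None, :].float()) * torch.finfo(torch.float32).min
# shape (batch, 1, 1, seq_len); masked positions become a large negative bias
# that the softmax then drives to zero probability.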
+ extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device + ) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for _, layer_module in enumerate(self.block): + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + ) + hidden_states = layer_outputs[0] + + # We share the position biases between the layers - the first layer store them + position_bias = layer_outputs[1] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[2] + + hidden_states = self.final_layer_norm(hidden_states).type_as(hidden_states) + hidden_states = self.dropout(hidden_states) + + return EncoderOutput( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + + +class T5(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + config.is_encoder_decoder = False + assert not config.tie_word_embeddings + + self.config = config + self.model_dim = config.d_model + + self.spectrogram = MelSpectrogram( + config.sample_rate, config.n_fft, config.n_mels, config.hop_length + ) + self.encoder_embedder = nn.Linear(config.n_mels, config.d_model) + self.decoder_embedder = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + self.encoder = T5Stack(encoder_config, self.encoder_embedder) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.decoder_embedder) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + self.generation_config = None + + self.apply(self._init_weights) + + def generate( + self, + frames: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + max_length=None, + **kwargs, + ) -> torch.LongTensor: + """ + frames: B x L_encoder x mel_bins, float32 + attention_mask: B x L_encoder, int64 + 1 for tokens to attend to, 0 for tokens to ignore + + Generation: + Starts with [SOS], ends with [EOS], padding is [PAD] (see Tokenizer) + """ + B, _ = frames.size() + SOS_TOKEN_ID = self.config.decoder_start_token_id + PAD_TOKEN_ID = self.config.pad_token_id + EOS_TOKEN_ID = self.config.eos_token_id + labels = torch.ones(B, 1, dtype=torch.long, device=frames.device) * SOS_TOKEN_ID + encoder_outputs = None + + for _ in range(max_length): + out = self.forward( + frames=frames, + attention_mask=attention_mask, + decoder_input_ids=labels, + encoder_outputs=encoder_outputs, + ) + encoder_outputs = out.encoder_outputs + top_labels 
= out.logits[:, -1].argmax(-1).unsqueeze(-1) + labels = torch.cat([labels, top_labels], dim=-1) + + if (labels == EOS_TOKEN_ID).sum(-1).clamp(min=0, max=1).sum().item() == B: + break + + labels[:, -1] = EOS_TOKEN_ID + + # Mask out the padding, i.e., all positions after the first 1 with 0 + B, L = labels.size() + mask = torch.arange(L, device=labels.device).unsqueeze(0) <= ( + labels == EOS_TOKEN_ID + ).long().argmax(-1).unsqueeze(-1) + labels = labels.masked_fill(~mask, PAD_TOKEN_ID) + + return labels + + def forward( + self, + frames: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + tokens: Optional[torch.LongTensor] = None, + encoder_outputs=None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> Seq2SeqLMOutput: + """ + frames: B x L_encoder x mel_bins, float32 + attention_mask: B x L_encoder, int64 + 1 for tokens to attend to, 0 for tokens to ignore + tokens: B x L_decoder, int64 + """ + if encoder_outputs is None: + encoder_outputs = self.encoder( + frames, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + + hidden_states = encoder_outputs.hidden_states + + if tokens is not None and decoder_input_ids is None: + decoder_input_ids = self._shift_right(tokens) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + ) + + sequence_output = decoder_outputs[0] + lm_logits = self.lm_head(sequence_output) + + loss = None + if tokens is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), tokens.view(-1)) + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + encoder_outputs=encoder_outputs, + ) + + def _init_weights(self, module): + factor = ( + self.config.initializer_factor + ) # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (T5)): + module.decoder_embedder.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5DenseGatedActDense): + d_ff, d_model = module.wi_0.weight.data.size() + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5)) + elif isinstance(module, T5Attention): + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_( + mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) + ) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_( + mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5) + ) + if hasattr(module, "relative_attention_bias"): + module.relative_attention_bias.weight.data.normal_( + mean=0.0, std=factor * ((d_model) ** -0.5) + ) + + def _shift_right(self, input_ids): + SOS_TOKEN_ID = self.config.decoder_start_token_id + PAD_TOKEN_ID = self.config.pad_token_id + + assert SOS_TOKEN_ID is not None and PAD_TOKEN_ID is not None + shifted_input_ids = 
input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = SOS_TOKEN_ID + + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, PAD_TOKEN_ID) + + return shifted_input_ids diff --git a/osuT5/tokenizer/__init__.py b/osuT5/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..682528aa7dc487d1f5807f7ad59c9a4290f16bee --- /dev/null +++ b/osuT5/tokenizer/__init__.py @@ -0,0 +1,2 @@ +from .event import * +from .tokenizer import Tokenizer diff --git a/osuT5/tokenizer/__pycache__/__init__.cpython-311.pyc b/osuT5/tokenizer/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6da5ece1c6c42698f106de1eb3161a8674f9985e Binary files /dev/null and b/osuT5/tokenizer/__pycache__/__init__.cpython-311.pyc differ diff --git a/osuT5/tokenizer/__pycache__/__init__.cpython-39.pyc b/osuT5/tokenizer/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a963623ba836ee9f7be8f3b60c0e0971b80be395 Binary files /dev/null and b/osuT5/tokenizer/__pycache__/__init__.cpython-39.pyc differ diff --git a/osuT5/tokenizer/__pycache__/event.cpython-311.pyc b/osuT5/tokenizer/__pycache__/event.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7262c70323253508498b55100a5c358e6550d7a Binary files /dev/null and b/osuT5/tokenizer/__pycache__/event.cpython-311.pyc differ diff --git a/osuT5/tokenizer/__pycache__/event.cpython-39.pyc b/osuT5/tokenizer/__pycache__/event.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b40e007d74a5e67d8e75913f7306c4a37861e283 Binary files /dev/null and b/osuT5/tokenizer/__pycache__/event.cpython-39.pyc differ diff --git a/osuT5/tokenizer/__pycache__/tokenizer.cpython-311.pyc b/osuT5/tokenizer/__pycache__/tokenizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32106ac743321617536f8af29e6d0ef2868e5a30 Binary files /dev/null and b/osuT5/tokenizer/__pycache__/tokenizer.cpython-311.pyc differ diff --git a/osuT5/tokenizer/__pycache__/tokenizer.cpython-39.pyc b/osuT5/tokenizer/__pycache__/tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b14d3552d74155d85ee41fbdcc6ebb52971d91d7 Binary files /dev/null and b/osuT5/tokenizer/__pycache__/tokenizer.cpython-39.pyc differ diff --git a/osuT5/tokenizer/event.py b/osuT5/tokenizer/event.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b2c5f25fe62c272076ab8e2e6b6a40afd297f8 --- /dev/null +++ b/osuT5/tokenizer/event.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import dataclasses +from enum import Enum + + +class EventType(Enum): + TIME_SHIFT = "t" + DISTANCE = "dist" + NEW_COMBO = "new_combo" + CIRCLE = "circle" + SPINNER = "spinner" + SPINNER_END = "spinner_end" + SLIDER_HEAD = "slider_head" + BEZIER_ANCHOR = "bezier_anchor" + PERFECT_ANCHOR = "perfect_anchor" + CATMULL_ANCHOR = "catmull_anchor" + RED_ANCHOR = "red_anchor" + LAST_ANCHOR = "last_anchor" + SLIDER_END = "slider_end" + STYLE = "style" + DIFFICULTY = "difficulty" + POS_X = "pos_x" + POS_Y = "pos_y" + + +@dataclasses.dataclass +class EventRange: + type: EventType + min_value: int + max_value: int + + +@dataclasses.dataclass +class Event: + type: EventType + value: int = 0 + + def __repr__(self) -> str: + return f"{self.type.value}{self.value}" + + def 
__str__(self) -> str: + return f"{self.type.value}{self.value}" diff --git a/osuT5/tokenizer/tokenizer.py b/osuT5/tokenizer/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2a3b3fe3d5919d2b6c84ad98c505cc2e79c0ca20 --- /dev/null +++ b/osuT5/tokenizer/tokenizer.py @@ -0,0 +1,237 @@ +import json +import pickle +from pathlib import Path + +import numpy as np +from omegaconf import DictConfig +from tqdm import tqdm + +from .event import Event, EventType, EventRange + +MILISECONDS_PER_SECOND = 1000 +MILISECONDS_PER_STEP = 10 + + +class Tokenizer: + __slots__ = [ + "_offset", + "event_ranges", + "input_event_ranges", + "num_classes", + "num_diff_classes", + "max_difficulty", + "event_range", + "event_start", + "event_end", + "vocab_size_out", + "vocab_size_in", + "beatmap_idx", + ] + + def __init__(self, args: DictConfig = None): + """Fixed vocabulary tokenizer.""" + self._offset = 3 + self.beatmap_idx: dict[int, int] = {} + + if args is not None: + miliseconds_per_sequence = ((args.data.src_seq_len - 1) * args.model.spectrogram.hop_length * + MILISECONDS_PER_SECOND / args.model.spectrogram.sample_rate) + max_time_shift = int(miliseconds_per_sequence / MILISECONDS_PER_STEP) + min_time_shift = -max_time_shift if args.data.add_pre_tokens or args.data.add_pre_tokens_at_step >= 0 else 0 + self.event_ranges = [EventRange(EventType.TIME_SHIFT, min_time_shift, max_time_shift)] + + self.input_event_ranges: list[EventRange] = [] + if args.data.style_token_index >= 0: + self.input_event_ranges.append(EventRange(EventType.STYLE, 0, args.data.num_classes)) + if args.data.diff_token_index >= 0: + self.input_event_ranges.append(EventRange(EventType.DIFFICULTY, 0, args.data.num_diff_classes)) + + self.num_classes = args.data.num_classes + self.num_diff_classes = args.data.num_diff_classes + self.max_difficulty = args.data.max_diff + + self._init_beatmap_idx(args) + else: + self.event_ranges = [EventRange(EventType.TIME_SHIFT, -512, 512)] + self.input_event_ranges = [] + self.num_classes = 0 + self.num_diff_classes = 0 + self.max_difficulty = 0 + + self.event_ranges: list[EventRange] = self.event_ranges + [ + EventRange(EventType.DISTANCE, 0, 640), + EventRange(EventType.NEW_COMBO, 0, 0), + EventRange(EventType.CIRCLE, 0, 0), + EventRange(EventType.SPINNER, 0, 0), + EventRange(EventType.SPINNER_END, 0, 0), + EventRange(EventType.SLIDER_HEAD, 0, 0), + EventRange(EventType.BEZIER_ANCHOR, 0, 0), + EventRange(EventType.PERFECT_ANCHOR, 0, 0), + EventRange(EventType.CATMULL_ANCHOR, 0, 0), + EventRange(EventType.RED_ANCHOR, 0, 0), + EventRange(EventType.LAST_ANCHOR, 0, 0), + EventRange(EventType.SLIDER_END, 0, 0), + ] + + self.event_range: dict[EventType, EventRange] = {er.type: er for er in self.event_ranges} | {er.type: er for er in self.input_event_ranges} + + self.event_start: dict[EventType, int] = {} + self.event_end: dict[EventType, int] = {} + offset = self._offset + for er in self.event_ranges: + self.event_start[er.type] = offset + offset += er.max_value - er.min_value + 1 + self.event_end[er.type] = offset + for er in self.input_event_ranges: + self.event_start[er.type] = offset + offset += er.max_value - er.min_value + 1 + self.event_end[er.type] = offset + + self.vocab_size_out: int = self._offset + sum( + er.max_value - er.min_value + 1 for er in self.event_ranges + ) + self.vocab_size_in: int = self.vocab_size_out + sum( + er.max_value - er.min_value + 1 for er in self.input_event_ranges + ) + + @property + def pad_id(self) -> int: + """[PAD] token for padding.""" + 
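# [Editor's note] Vocabulary layout recap: ids 0-2 are PAD/SOS/EOS, and every
# EventRange then occupies a contiguous block, so encode/decode are pure offset
# arithmetic. With the fallback ranges above (TIME_SHIFT spanning -512..512 from id 3):
tok = Tokenizer()                                        # args=None uses fallback ranges
assert tok.encode(Event(EventType.TIME_SHIFT, 0)) == 3 + (0 - -512)    # == 515
assert tok.decode(515) == Event(EventType.TIME_SHIFT, 0)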
return 0 + + @property + def sos_id(self) -> int: + """[SOS] token for start-of-sequence.""" + return 1 + + @property + def eos_id(self) -> int: + """[EOS] token for end-of-sequence.""" + return 2 + + def decode(self, token_id: int) -> Event: + """Converts token ids into Event objects.""" + offset = self._offset + for er in self.event_ranges: + if offset <= token_id <= offset + er.max_value - er.min_value: + return Event(type=er.type, value=er.min_value + token_id - offset) + offset += er.max_value - er.min_value + 1 + for er in self.input_event_ranges: + if offset <= token_id <= offset + er.max_value - er.min_value: + return Event(type=er.type, value=er.min_value + token_id - offset) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"id {token_id} is not mapped to any event") + + def encode(self, event: Event) -> int: + """Converts Event objects into token ids.""" + if event.type not in self.event_range: + raise ValueError(f"unknown event type: {event.type}") + + er = self.event_range[event.type] + offset = self.event_start[event.type] + + if not er.min_value <= event.value <= er.max_value: + raise ValueError( + f"event value {event.value} in {event} is not within range " + f"[{er.min_value}, {er.max_value}] for event type {event.type}" + ) + + return offset + event.value - er.min_value + + def event_type_range(self, event_type: EventType) -> tuple[int, int]: + """Get the token id range of each Event type.""" + if event_type not in self.event_range: + raise ValueError(f"unknown event type: {event_type}") + + er = self.event_range[event_type] + offset = self.event_start[event_type] + return offset, offset + (er.max_value - er.min_value) + + def encode_diff_event(self, diff: float) -> Event: + """Converts difficulty value into event.""" + return Event(type=EventType.DIFFICULTY, value=np.clip( + int(diff * self.num_diff_classes / self.max_difficulty), 0, self.num_diff_classes - 1)) + + def encode_diff(self, diff: float) -> int: + """Converts difficulty value into token id.""" + return self.encode(self.encode_diff_event(diff)) + + @property + def diff_unk(self) -> int: + """Gets the unknown difficulty value token id.""" + return self.encode(Event(type=EventType.DIFFICULTY, value=self.num_diff_classes)) + + def encode_style_event(self, beatmap_id: int) -> Event: + """Converts beatmap id into style event.""" + style_idx = self.beatmap_idx.get(beatmap_id, self.num_classes) + return Event(type=EventType.STYLE, value=style_idx) + + def encode_style(self, beatmap_id: int) -> int: + """Converts beatmap id into token id.""" + return self.encode(self.encode_style_event(beatmap_id)) + + def encode_style_idx(self, beatmap_idx: int) -> int: + """Converts beatmap idx into token id.""" + return self.encode(Event(type=EventType.STYLE, value=beatmap_idx)) + + @property + def style_unk(self) -> int: + """Gets the unknown style value token id.""" + return self.encode(Event(type=EventType.STYLE, value=self.num_classes)) + + def _init_beatmap_idx(self, args: DictConfig) -> None: + """Initializes and caches the beatmap index.""" + if args is None or "train_dataset_path" not in args.data: + return + + path = Path(args.data.train_dataset_path) + cache_path = path / "beatmap_idx.pickle" + + if cache_path.exists(): + with open(path / "beatmap_idx.pickle", "rb") as f: + self.beatmap_idx = pickle.load(f) + return + + print("Caching beatmap index...") + + for track in tqdm(path.iterdir()): + if not track.is_dir(): + continue + metadata_file = track / "metadata.json" + with open(metadata_file) as f: + 
metadata = json.load(f) + for beatmap_name in metadata["Beatmaps"]: + beatmap_metadata = metadata["Beatmaps"][beatmap_name] + self.beatmap_idx[beatmap_metadata["BeatmapId"]] = beatmap_metadata["Index"] + + with open(cache_path, "wb") as f: + pickle.dump(self.beatmap_idx, f) + + def state_dict(self): + return { + "event_ranges": self.event_ranges, + "input_event_ranges": self.input_event_ranges, + "num_classes": self.num_classes, + "num_diff_classes": self.num_diff_classes, + "max_difficulty": self.max_difficulty, + "event_range": self.event_range, + "event_start": self.event_start, + "event_end": self.event_end, + "vocab_size_out": self.vocab_size_out, + "vocab_size_in": self.vocab_size_in, + "beatmap_idx": self.beatmap_idx, + } + + def load_state_dict(self, state_dict): + self.event_ranges = state_dict["event_ranges"] + self.input_event_ranges = state_dict["input_event_ranges"] + self.num_classes = state_dict["num_classes"] + self.num_diff_classes = state_dict["num_diff_classes"] + self.max_difficulty = state_dict["max_difficulty"] + self.event_range = state_dict["event_range"] + self.event_start = state_dict["event_start"] + self.event_end = state_dict["event_end"] + self.vocab_size_out = state_dict["vocab_size_out"] + self.vocab_size_in = state_dict["vocab_size_in"] + self.beatmap_idx = state_dict["beatmap_idx"] diff --git a/osuT5/utils/__init__.py b/osuT5/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..50336649f8935f6c242d80e1c8d884669efedd47 --- /dev/null +++ b/osuT5/utils/__init__.py @@ -0,0 +1,3 @@ +from .init_utils import * +from .model_utils import * +#from .train_utils import * diff --git a/osuT5/utils/__pycache__/__init__.cpython-311.pyc b/osuT5/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..091c2b54bbc93cea50aea7871abd7763442573af Binary files /dev/null and b/osuT5/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/osuT5/utils/__pycache__/copied_utils.cpython-311.pyc b/osuT5/utils/__pycache__/copied_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71ffb8e6392594b77e6d4bbdde7197d0aa5d1001 Binary files /dev/null and b/osuT5/utils/__pycache__/copied_utils.cpython-311.pyc differ diff --git a/osuT5/utils/__pycache__/init_utils.cpython-311.pyc b/osuT5/utils/__pycache__/init_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b512728752435291b3e2d05cb80e42bf45bb5fe Binary files /dev/null and b/osuT5/utils/__pycache__/init_utils.cpython-311.pyc differ diff --git a/osuT5/utils/__pycache__/log_utils.cpython-311.pyc b/osuT5/utils/__pycache__/log_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09f493b2212b2738b94ee585a988e5fa6cea9a4a Binary files /dev/null and b/osuT5/utils/__pycache__/log_utils.cpython-311.pyc differ diff --git a/osuT5/utils/__pycache__/model_utils.cpython-311.pyc b/osuT5/utils/__pycache__/model_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..073c1d8466d84d575c705439cdd8dfeaba7503cd Binary files /dev/null and b/osuT5/utils/__pycache__/model_utils.cpython-311.pyc differ diff --git a/osuT5/utils/__pycache__/train_utils.cpython-311.pyc b/osuT5/utils/__pycache__/train_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d221f2d17c6d23e91515e35cfbdf79a1f8ff65a Binary files /dev/null and b/osuT5/utils/__pycache__/train_utils.cpython-311.pyc differ diff --git 
a/osuT5/utils/copied_utils.py b/osuT5/utils/copied_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..56db95a535d2c6b92b3ac58899d5197eff3d6822
--- /dev/null
+++ b/osuT5/utils/copied_utils.py
@@ -0,0 +1,609 @@
+from typing import Dict, List
+import numpy as np
+from transformers import BatchEncoding
+from dataclasses import dataclass
+from transformers import AutoTokenizer
+import torch
+import math
+from torch.optim import Optimizer
+from typing import Iterable, Tuple
+from torch import nn
+import random
+import string
+
+
+@dataclass
+class DataCollatorForT5MLM:
+    """
+    [Copied from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py]
+    Data collator used for T5 span-masked language modeling.
+    It makes sure that after masking the inputs have length `data_args.max_seq_length` and the targets also have a fixed length.
+    For more information on how T5 span-masked language modeling works, one can take a look
+    at the `official paper `__
+    or the `official code for preprocessing `__.
+    Args:
+        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
+            The tokenizer used for encoding the data.
+        noise_density (:obj:`float`):
+            The probability with which to (randomly) mask tokens in the input.
+        mean_noise_span_length (:obj:`float`):
+            The average span length of the masked tokens.
+        input_length (:obj:`int`):
+            The expected input length after masking.
+        target_length (:obj:`int`):
+            The expected target length after masking.
+        pad_token_id (:obj:`int`):
+            The pad token id of the model.
+        decoder_start_token_id (:obj:`int`):
+            The decoder start token id of the model.
+    """
+
+    tokenizer: AutoTokenizer
+    noise_density: float
+    mean_noise_span_length: float
+    input_length: int
+    target_length: int
+    pad_token_id: int
+
+    def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
+        # convert list to dict and tensorize input
+        batch = BatchEncoding(
+            {
+                k: np.array([examples[i][k] for i in range(len(examples))])
+                for k in examples[0].keys()
+            }
+        )
+
+        input_ids = batch["input_ids"]
+        batch_size, expanded_input_length = input_ids.shape
+
+        mask_indices = np.asarray(
+            [
+                self.random_spans_noise_mask(expanded_input_length)
+                for _ in range(batch_size)
+            ]
+        )
+        labels_mask = ~mask_indices
+
+        input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
+        labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
+
+        batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
+        batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
+
+        if batch["input_ids"].shape[-1] != self.input_length:
+            raise ValueError(
+                f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but"
+                f" should be {self.input_length}."
+            )
+
+        if batch["labels"].shape[-1] != self.target_length:
+            raise ValueError(
+                f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be"
+                f" {self.target_length}."
+            )
+
+        batch = {k: torch.from_numpy(v) for k, v in batch.items()}
+        return batch
+
+    def create_sentinel_ids(self, mask_indices):
+        """
+        Sentinel ids creation given the indices that should be masked.
+        The start indices of each mask are replaced by the sentinel ids in increasing
+        order. Consecutive mask indices to be deleted are replaced with `-1`.
+ """ + start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices + start_indices[:, 0] = mask_indices[:, 0] + + sentinel_ids = np.where( + start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices + ) + sentinel_ids = np.where( + sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0 + ) + sentinel_ids -= mask_indices - start_indices + + return sentinel_ids + + def filter_input_ids(self, input_ids, sentinel_ids): + """ + Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting. + This will reduce the sequence length from `expanded_inputs_length` to `input_length`. + """ + batch_size = input_ids.shape[0] + + input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids) + # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are + # masked tokens coming after sentinel tokens and should be removed + input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1)) + input_ids = np.concatenate( + [ + input_ids, + np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32), + ], + axis=-1, + ) + return input_ids + + def random_spans_noise_mask(self, length): + """This function is copy of `random_spans_helper `__ . + + Noise mask consisting of random spans of noise tokens. + The number of noise tokens and the number of noise spans and non-noise spans + are determined deterministically as follows: + num_noise_tokens = round(length * noise_density) + num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length) + Spans alternate between non-noise and noise, beginning with non-noise. + Subject to the above restrictions, all masks are equally likely. + + Args: + length: an int32 scalar (length of the incoming token sequence) + noise_density: a float - approximate density of output mask + mean_noise_span_length: a number + + Returns: + a boolean tensor with shape [length] + """ + + orig_length = length + + num_noise_tokens = int(np.round(length * self.noise_density)) + # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. + num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) + num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length)) + + # avoid degeneracy by ensuring positive number of noise spans + num_noise_spans = max(num_noise_spans, 1) + num_nonnoise_tokens = length - num_noise_tokens + + # pick the lengths of the noise spans and the non-noise spans + def _random_segmentation(num_items, num_segments): + """Partition a sequence of items randomly into non-empty segments. 
+            Args:
+                num_items: an integer scalar > 0
+                num_segments: an integer scalar in [1, num_items]
+            Returns:
+                a Tensor with shape [num_segments] containing positive integers that add
+                up to num_items
+            """
+            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
+            np.random.shuffle(mask_indices)
+            first_in_segment = np.pad(mask_indices, [[1, 0]])
+            segment_id = np.cumsum(first_in_segment)
+            # count length of sub segments assuming that list is sorted
+            _, segment_length = np.unique(segment_id, return_counts=True)
+            return segment_length
+
+        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
+        nonnoise_span_lengths = _random_segmentation(
+            num_nonnoise_tokens, num_noise_spans
+        )
+
+        interleaved_span_lengths = np.reshape(
+            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
+            [num_noise_spans * 2],
+        )
+        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
+        span_start_indicator = np.zeros((length,), dtype=np.int8)
+        span_start_indicator[span_starts] = True
+        span_num = np.cumsum(span_start_indicator)
+        is_noise = np.equal(span_num % 2, 1)
+
+        return is_noise[:orig_length]
+
+
+def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
+    """This function is a copy of `random_spans_helper `__.
+
+    [Copied from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py]
+    Training parameters to avoid padding with random_spans_noise_mask.
+    When training a model with random_spans_noise_mask, we would like to set the other
+    training hyperparameters in a way that avoids padding.
+    This function helps us compute these hyperparameters.
+    We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
+    and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
+    This function tells us the required number of tokens in the raw example (for split_tokens())
+    as well as the length of the encoded targets. Note that this function assumes
+    the inputs and targets will have EOS appended and includes that in the reported length.
+
+    Args:
+        inputs_length: an integer - desired length of the tokenized inputs sequence
+        noise_density: a float
+        mean_noise_span_length: a float
+    Returns:
+        tokens_length: length of original text in tokens
+        targets_length: an integer - length in tokens of encoded targets sequence
+    """
+
+    def _tokens_length_to_inputs_length_targets_length(tokens_length):
+        num_noise_tokens = int(round(tokens_length * noise_density))
+        num_nonnoise_tokens = tokens_length - num_noise_tokens
+        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
+        # inputs contain all nonnoise tokens, sentinels for all noise spans
+        # and one EOS token.
+        _input_length = num_nonnoise_tokens + num_noise_spans + 1
+        _output_length = num_noise_tokens + num_noise_spans + 1
+        return _input_length, _output_length
+
+    tokens_length = inputs_length
+
+    while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
+        tokens_length += 1
+
+    inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
+
+    # minor hack to get the targets length to be equal to inputs length
+    # which is more likely to have been set to a nice round number.
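+    # Worked example (editor's illustration): inputs_length=512,
+    # noise_density=0.15, mean_noise_span_length=3.0 settles on
+    # tokens_length=568 (85 noise tokens in 28 spans), so
+    # inputs = 483 + 28 + 1 = 512 and targets = 85 + 28 + 1 = 114.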
+    if noise_density == 0.5 and targets_length > inputs_length:
+        tokens_length -= 1
+        targets_length -= 1
+    return tokens_length, targets_length
+
+
+class AdamWScale(Optimizer):
+    """
+    This AdamW implementation is copied from Huggingface.
+    We modified it to scale the update by the root-mean-square (RMS) of each
+    weight tensor, as in Adafactor.
+
+    Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
+    Regularization](https://arxiv.org/abs/1711.05101).
+
+    Parameters:
+        params (`Iterable[nn.parameter.Parameter]`):
+            Iterable of parameters to optimize or dictionaries defining parameter groups.
+        lr (`float`, *optional*, defaults to 1e-3):
+            The learning rate to use.
+        betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)):
+            Adam's betas parameters (b1, b2).
+        eps (`float`, *optional*, defaults to 1e-6):
+            Adam's epsilon for numerical stability.
+        weight_decay (`float`, *optional*, defaults to 0):
+            Decoupled weight decay to apply.
+        correct_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not to correct bias in Adam (for instance, in the BERT TF repository they use `False`).
+    """
+
+    def __init__(
+        self,
+        params: Iterable[nn.parameter.Parameter],
+        lr: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-6,
+        weight_decay: float = 0.0,
+        correct_bias: bool = True,
+    ):
+        if lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
+        if not 0.0 <= eps:
+            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
+        super().__init__(params, defaults)
+
+    @staticmethod
+    def _rms(tensor):
+        return tensor.norm(2) / (tensor.numel() ** 0.5)
+
+    def step(self, closure=None):
+        """
+        Performs a single optimization step.
+
+        Arguments:
+            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
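+
+        Note (editor's addition): beyond standard AdamW, the step size below is
+        rescaled by max(1e-3, rms(p)) per parameter tensor, so each tensor's
+        effective learning rate tracks its own scale, Adafactor-style.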
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") + + state = self.state[p] + beta1, beta2 = group["betas"] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + if group["correct_bias"]: # No bias correction for Bert + bias_correction1 = 1.0 - beta1 ** state["step"] + bias_correction2 = 1.0 - beta2 ** state["step"] + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + # /Adapt Step from Adafactor + step_size = step_size * max(1e-3, self._rms(p.data)) + # /Adapt Step from Adafactor + + p.data.addcdiv_(exp_avg, denom, value=-step_size) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. 
+ # Add weight decay at the end (fixed version) + if group["weight_decay"] > 0.0: + p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"])) + + return loss + + +def tokenize_function(examples, tokenizer, in_length): + tokenizer_out = tokenizer( + text=examples["text"], + return_attention_mask=False, + ) + + input_ids = tokenizer_out["input_ids"] + + concatenated_ids = np.concatenate(input_ids) + + total_length = concatenated_ids.shape[0] + total_length = (total_length // in_length) * in_length + + concatenated_ids = concatenated_ids[:total_length].reshape(-1, in_length) + result = {"input_ids": concatenated_ids} + + return result + + +from transformers.data.data_collator import * +@dataclass +class DataCollatorForNI: + tokenizer: PreTrainedTokenizerBase + padding: Union[bool, str, PaddingStrategy] = True + max_source_length: Optional[int] = None + max_target_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + return_tensors: str = "pt" + add_task_name: bool = False + add_task_definition: bool = True + num_pos_examples: int = 0 + num_neg_examples: int = 0 + add_explanation: bool = False + tk_instruct: bool = False + text_only: bool = False + + def __call__(self, batch, return_tensors=None): + + if return_tensors is None: + return_tensors = self.return_tensors + + sources = [] + for instance in batch: + if self.tk_instruct: + all_valid_encodings = [ + # instruction only + { + "add_task_name": False, + "add_task_definition": True, + "num_pos_examples": 0, + "num_neg_examples": 0, + "add_explanation": False, + }, + # example only + { + "add_task_name": False, + "add_task_definition": False, + "num_pos_examples": 2, + "num_neg_examples": 0, + "add_explanation": False, + }, + # instruction + pos examples + { + "add_task_name": False, + "add_task_definition": True, + "num_pos_examples": 2, + "num_neg_examples": 0, + "add_explanation": False, + }, + # instruction + pos examples + neg examples + { + "add_task_name": False, + "add_task_definition": True, + "num_pos_examples": 2, + "num_neg_examples": 2, + "add_explanation": False, + }, + # instruction + pos (w. explanation) + { + "add_task_name": False, + "add_task_definition": True, + "num_pos_examples": 2, + "num_neg_examples": 0, + "add_explanation": True, + }, + ] + encoding_schema = random.choice(all_valid_encodings) + add_task_name = encoding_schema["add_task_name"] + add_task_definition = encoding_schema["add_task_definition"] + num_pos_examples = encoding_schema["num_pos_examples"] + num_neg_examples = encoding_schema["num_neg_examples"] + add_explanation = encoding_schema["add_explanation"] + else: + add_task_name = self.add_task_name + add_task_definition = self.add_task_definition + num_pos_examples = self.num_pos_examples + num_neg_examples = self.num_neg_examples + add_explanation = self.add_explanation + + task_input = "" + # add the input first. + task_input += "Now complete the following example -\n" + task_input += f"Input: {instance['Instance']['input'].strip()}" + if not task_input[-1] in string.punctuation: + task_input += "." + task_input += "\n" + task_input += "Output: " + + task_name = "" + if add_task_name: + task_name += instance["Task"] + ". " + + definition = "" + if add_task_definition: + if isinstance(instance["Definition"], list): + definition = ( + "Definition: " + instance["Definition"][0].strip() + ) + else: + definition = "Definition: " + instance["Definition"].strip() + if not definition[-1] in string.punctuation: + definition += "." 
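+                # The assembled prompt (see `source` below) reads, in order:
+                # [task name. ][Definition: ...\n\n][positive examples]
+                # [negative examples]"Now complete the following example -\n
+                # Input: ...\nOutput: "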
+ definition += "\n\n" + + # try to add positive examples. + pos_examples = [] + for idx, pos_example in enumerate( + instance["Positive Examples"][:num_pos_examples] + ): + pos_example_str = f" Positive Example {idx+1} -\n" + pos_example_str += f"Input: {pos_example['input'].strip()}" + if not pos_example_str[-1] in string.punctuation: + pos_example_str += "." + pos_example_str += "\n" + pos_example_str += f" Output: {pos_example['output'].strip()}" + if not pos_example_str[-1] in string.punctuation: + pos_example_str += "." + pos_example_str += "\n" + if add_explanation and "explanation" in pos_example: + pos_example_str += ( + f" Explanation: {pos_example['explanation'].strip()}" + ) + if not pos_example_str[-1] in string.punctuation: + pos_example_str += "." + pos_example_str += "\n" + pos_example_str += "\n" + if ( + len( + self.tokenizer( + definition + + " ".join(pos_examples) + + pos_example_str + + task_input + )["input_ids"] + ) + <= self.max_source_length + ): + pos_examples.append(pos_example_str) + else: + break + + # try to add negative examples. + neg_examples = [] + for idx, neg_example in enumerate( + instance["Negative Examples"][:num_neg_examples] + ): + neg_example_str = f" Negative Example {idx+1} -\n" + neg_example_str += f"Input: {neg_example['input'].strip()}" + if not neg_example_str[-1] in string.punctuation: + neg_example_str += "." + neg_example_str += "\n" + neg_example_str += f" Output: {neg_example['output'].strip()}" + if not neg_example_str[-1] in string.punctuation: + neg_example_str += "." + neg_example_str += "\n" + if add_explanation and "explanation" in neg_example: + neg_example_str += ( + f" Explanation: {neg_example['explanation'].strip()}" + ) + if not neg_example_str[-1] in string.punctuation: + neg_example_str += "." + neg_example_str += "\n" + neg_example_str += "\n" + if ( + len( + self.tokenizer( + definition + + " ".join(pos_examples) + + " ".join(neg_examples) + + neg_example_str + + task_input + )["input_ids"] + ) + <= self.max_source_length + ): + neg_examples.append(neg_example_str) + else: + break + + source = ( + task_name + + definition + + "".join(pos_examples) + + "".join(neg_examples) + + task_input + ) + tokenized_source = self.tokenizer(source)["input_ids"] + if len(tokenized_source) <= self.max_source_length: + sources.append(source) + else: + sources.append( + self.tokenizer.decode( + tokenized_source[: self.max_source_length], + skip_special_tokens=True, + ) + ) + + if self.text_only: + model_inputs = {"inputs": sources} + else: + model_inputs = self.tokenizer( + sources, + max_length=self.max_source_length, + padding=self.padding, + return_tensors=self.return_tensors, + truncation=True, + pad_to_multiple_of=self.pad_to_multiple_of, + ) + + if "output" in batch[0]["Instance"] and batch[0]["Instance"]["output"]: + # Randomly select one reference if multiple are provided. 
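+            # Padded label positions are replaced with label_pad_token_id (-100
+            # by default) below, matching the ignore_index that PyTorch's
+            # CrossEntropyLoss, and hence HF seq2seq losses, skips.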
+ labels = [random.choice(ex["Instance"]["output"]) for ex in batch] + if self.text_only: + model_inputs["labels"] = labels + else: + labels = self.tokenizer( + labels, + max_length=self.max_target_length, + padding=self.padding, + return_tensors=self.return_tensors, + truncation=True, + pad_to_multiple_of=self.pad_to_multiple_of, + ) + label_mask = labels["attention_mask"].bool() + model_inputs["labels"] = labels["input_ids"].masked_fill( + ~label_mask, self.label_pad_token_id + ) + else: + model_inputs["labels"] = None + + return model_inputs diff --git a/osuT5/utils/init_utils.py b/osuT5/utils/init_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9295e6f6c5dff9abe44078fccca604957b42cb8f --- /dev/null +++ b/osuT5/utils/init_utils.py @@ -0,0 +1,41 @@ +import torch +import os + +from accelerate.utils import set_seed +from omegaconf import open_dict, DictConfig + + +def check_args_and_env(args: DictConfig) -> None: + assert args.optim.batch_size % args.optim.grad_acc == 0 + # Train log must happen before eval log + assert args.eval.every_steps % args.logging.every_steps == 0 + + if args.device == "gpu": + assert torch.cuda.is_available(), "We use GPU to train/eval the model" + + +def opti_flags(args: DictConfig) -> None: + # This lines reduce training step by 2.4x + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +def update_args_with_env_info(args: DictConfig) -> None: + with open_dict(args): + slurm_id = os.getenv("SLURM_JOB_ID") + + if slurm_id is not None: + args.slurm_id = slurm_id + else: + args.slurm_id = "none" + + args.working_dir = os.getcwd() + + +def setup_args(args: DictConfig) -> None: + check_args_and_env(args) + update_args_with_env_info(args) + opti_flags(args) + + if args.seed is not None: + set_seed(args.seed) diff --git a/osuT5/utils/log_utils.py b/osuT5/utils/log_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2e336103227055057a570e57043c771a56806235 --- /dev/null +++ b/osuT5/utils/log_utils.py @@ -0,0 +1,45 @@ +from collections import defaultdict + +import numpy as np +import torch + + +class Averager: + def __init__(self): + self.reset() + + # noinspection PyAttributeOutsideInit + def reset(self): + self.total = {} + self.counter = {} + + def update(self, stats): + for key, value in stats.items(): + if key not in self.total: + if isinstance(value, torch.Tensor): + self.total[key] = value.sum() + self.counter[key] = value.numel() + elif isinstance(value, np.ndarray): + self.total[key] = value.sum() + self.counter[key] = value.size + else: + self.total[key] = value + self.counter[key] = 1 + else: + if isinstance(value, torch.Tensor): + self.total[key] = self.total[key] + value.sum() + self.counter[key] = self.counter[key] + value.numel() + elif isinstance(value, np.ndarray): + self.total[key] = self.total[key] + value.sum() + self.counter[key] = self.counter[key] + value.size + else: + self.total[key] = self.total[key] + value + self.counter[key] = self.counter[key] + 1 + + def average(self): + averaged_stats = { + key: (tot / self.counter[key]).item() if isinstance(tot, torch.Tensor) else tot / self.counter[key] for key, tot in self.total.items() + } + self.reset() + + return averaged_stats diff --git a/osuT5/utils/model_utils.py b/osuT5/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..027f84efa6e3270833aa013ff929f4547b5a1bd2 --- /dev/null +++ b/osuT5/utils/model_utils.py @@ -0,0 +1,126 @@ +import multiprocessing +import time 
+from multiprocessing.managers import Namespace + +import torch +import numpy as np +from omegaconf import DictConfig, open_dict +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from torch.optim.lr_scheduler import ( + LRScheduler, + SequentialLR, + LinearLR, + CosineAnnealingLR, +) + +from osuT5.model.osu_t import OsuT +from osuT5.tokenizer import Tokenizer + + +def get_shared_training_state() -> Namespace: + mgr = multiprocessing.Manager() + shared = mgr.Namespace() + shared.current_train_step = 1 + shared.current_epoch = 1 + shared.last_log = time.time() + shared.current_loss = np.Infinity + shared.best_loss = np.Infinity + return shared + + +def get_model(args: DictConfig, tokenizer: Tokenizer) -> OsuT: + model = OsuT(args, tokenizer) + return model + + +def get_tokenizer(args: DictConfig) -> Tokenizer: + return Tokenizer(args) + + +def get_optimizer(model: OsuT, args: DictConfig) -> Optimizer: + no_decay = ["bias", "LayerNorm", "layernorm", "layer_norm", "ln"] + + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in model.named_parameters() + if not any(nd in n for nd in no_decay) + ], + "weight_decay": args.optim.weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if any(nd in n for nd in no_decay) + ], + "weight_decay": 0.0, + }, + ] + + if args.optim.name == 'adamw': + from transformers import AdamW + optimizer = AdamW( + optimizer_grouped_parameters, + lr=args.optim.base_lr, + ) + elif args.optim.name == 'adamwscale': + from .copied_utils import AdamWScale + optimizer = AdamWScale( + optimizer_grouped_parameters, + lr=args.optim.base_lr, + ) + elif args.optim.name == 'adafactor': + from transformers import Adafactor + optimizer = Adafactor( + optimizer_grouped_parameters, + lr=args.optim.base_lr, + relative_step=False, + ) + else: + raise NotImplementedError + + return optimizer + + +def get_scheduler(optimizer: Optimizer, args: DictConfig) -> LRScheduler: + scheduler_p1 = LinearLR( + optimizer, + start_factor=0.5, + end_factor=1, + total_iters=args.optim.warmup_steps, + last_epoch=-1, + ) + + scheduler_p2 = CosineAnnealingLR( + optimizer, + T_max=args.optim.total_steps - args.optim.warmup_steps, + eta_min=args.optim.final_cosine, + ) + + scheduler = SequentialLR( + optimizer, + schedulers=[scheduler_p1, scheduler_p2], + milestones=[args.optim.warmup_steps], + ) + + return scheduler + + + +def worker_init_fn(worker_id: int) -> None: + """ + Give each dataloader a unique slice of the full dataset. 
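+
+    Example (illustrative): with dataset.start=0, dataset.end=1000 and
+    num_workers=4, per_worker = ceil(1000 / 4) = 250, so worker 2 is
+    assigned the slice [500, 750).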
+ """ + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset # the dataset copy in this worker process + overall_start = dataset.start + overall_end = dataset.end + # configure the dataset to only process the split workload + per_worker = int( + np.ceil((overall_end - overall_start) / float(worker_info.num_workers)), + ) + dataset.start = overall_start + worker_id * per_worker + dataset.end = min(dataset.start + per_worker, overall_end) diff --git a/osuT5/utils/train_utils.py b/osuT5/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..64c54070ffe9924c914f90daa56d8e0933c40228 --- /dev/null +++ b/osuT5/utils/train_utils.py @@ -0,0 +1,337 @@ +import glob +import os.path +import time +from multiprocessing.managers import Namespace + +import torch +import wandb +from accelerate import Accelerator +from accelerate.logging import get_logger +from omegaconf import DictConfig +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from torch.utils.data import DataLoader + +from osuT5.tokenizer import Tokenizer, EventType +from osuT5.model import OsuT +from .log_utils import Averager + +logger = get_logger(__name__) + + +def forward(model: OsuT, batch): + outputs = model(**batch) + loss = outputs.loss + + stats = {"loss": loss.detach()} + return loss, stats + + +def forward_eval(model: OsuT, batch): + outputs = model(**batch) + return outputs + + +def add_prefix(prefix: str, stats: dict[str, float]): + return {f"{prefix}/{k}": v for k, v in stats.items()} + + +def maybe_save_checkpoint(accelerator: Accelerator, args: DictConfig, shared: Namespace): + if ( + shared.current_train_step > args.optim.total_steps + or shared.current_train_step % args.checkpoint.every_steps == 0 + ): + if shared.current_loss < shared.best_loss: + shared.best_loss = shared.current_loss + is_best = True + else: + is_best = False + + output_dir = f"checkpoint-{shared.current_train_step}" + accelerator.wait_for_everyone() + # Saving T5 has an issue that safe serialization removes shared tensors and then the model can't be loaded. 
+ accelerator.save_state(output_dir=output_dir, safe_serialization=False) + + wandb_tracker = accelerator.get_tracker("wandb") + if wandb_tracker is not None: + art = wandb.Artifact( + f"osuT5-{wandb.run.id}", + type="model", + metadata={ + "format": "accelerate", + "src_seq_len": args.data.src_seq_len, + "tgt_seq_len": args.data.tgt_seq_len, + "num_classes": args.data.num_classes, + "num_diff_classes": args.data.num_diff_classes, + "max_difficulty": args.data.max_diff, + "class_dropout_prob": args.data.class_dropout_prob, + "diff_dropout_prob": args.data.diff_dropout_prob, + "spectrogram": args.model.spectrogram, + "current_train_step": shared.current_train_step, + "current_epoch": shared.current_epoch, + "current_loss": shared.current_loss, + }, + ) + + for file in os.listdir(output_dir): + art.add_file(os.path.join(output_dir, file)) + + wandb.log_artifact(art, aliases=["best"] if is_best else None) + logger.info(f"Logged checkpoint to wandb: {art.name}") + + +def maybe_eval( + model: OsuT, + accelerator: Accelerator, + dataloader: DataLoader, + tokenizer: Tokenizer, + args: DictConfig, + shared: Namespace, +): + if ( + shared.current_train_step > args.optim.total_steps + or shared.current_train_step % args.eval.every_steps == 0 + ): + model.eval() + + with torch.no_grad(): + eval_model(model, accelerator, dataloader, tokenizer, args, shared) + + shared.last_log = time.time() + model.train() + + +def maybe_logging( + model: OsuT, + accelerator: Accelerator, + optimizer: Optimizer, + averager: Averager, + args: DictConfig, + shared: Namespace, +): + def extra_stats(args, shared, model, optimizer): + stats = {} + + if args.logging.weights_l2: + weights_l2 = ( + sum(p.detach().norm(2).item() ** 2 for p in model.parameters() if p.requires_grad) ** 0.5 + ) + stats["weights_l2"] = weights_l2 + + stats["lr"] = optimizer.param_groups[0]["lr"] + stats["seconds_per_step"] = ( + time.time() - shared.last_log + ) / args.logging.every_steps + + return stats + + if shared.current_train_step % args.logging.every_steps == 0: + stats = extra_stats(args, shared, model, optimizer) + + averager.update(stats) + averaged_stats = averager.average() + averaged_stats["epoch"] = shared.current_epoch + averaged_stats = add_prefix("train", averaged_stats) + accelerator.log(averaged_stats, step=shared.current_train_step) + averaged_stats["step"] = shared.current_train_step + logger.info(averaged_stats) + + shared.last_log = time.time() + + +def maybe_grad_clip_and_grad_calc( + model: OsuT, + accelerator: Accelerator, + args: DictConfig, +): + if args.optim.grad_clip > 0: + grad_l2 = accelerator.clip_grad_norm_( + parameters=model.parameters(), + max_norm=args.optim.grad_clip, + norm_type=2, + ).item() + else: + grad_l2 = None + + if args.logging.grad_l2: + if grad_l2 is None: + grad_l2 = ( + sum( + p.grad.detach().data.norm(2).item() ** 2 for p in model.parameters() + ) + ** 0.5 + ) + + return {"grad_l2": grad_l2} + else: + return {} + + +# noinspection PyUnresolvedReferences,PyTypeChecker +def eval_model( + model: OsuT, + accelerator: Accelerator, + dataloader: DataLoader, + tokenizer: Tokenizer, + args: DictConfig, + shared: Namespace, +): + shared.last_log = time.time() + averager = Averager() + + for batch_id, batch in enumerate(dataloader, start=1): + if batch_id == args.eval.steps * args.optim.grad_acc: + break + + # We can't use the beatmap idx of the test set because these are not known by the model + del batch["beatmap_idx"] + + outputs = forward_eval(model, batch) + + # Reduce loss over all processes + 
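+        # (reduce with reduction="mean" averages the per-process losses so
+        # every rank logs the same scalar.)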
loss = outputs.loss
+        loss = accelerator.reduce(loss, reduction="mean")
+
+        # Gather labels and predictions over all processes and drop duplicates
+        preds = torch.argmax(outputs.logits, dim=-1)
+        labels = batch["labels"]
+        preds, labels = accelerator.gather_for_metrics((preds, labels))
+
+        # Calculate accuracy metrics
+        stats = {"loss": loss.detach(),
+                 "timing_acc": acc_range(preds, labels, tokenizer.event_start[EventType.TIME_SHIFT],
+                                         tokenizer.event_end[EventType.TIME_SHIFT]),
+                 "spacing_acc": acc_range(preds, labels, tokenizer.event_start[EventType.DISTANCE],
+                                          tokenizer.event_end[EventType.DISTANCE]),
+                 "other_acc": acc_range(preds, labels, tokenizer.event_end[EventType.DISTANCE],
+                                        tokenizer.event_end[EventType.DISTANCE] + tokenizer.vocab_size_out)}
+
+        averager.update(stats)
+
+    averager.update({"time": time.time() - shared.last_log})
+    averaged_stats = averager.average()
+    averaged_stats = add_prefix("test", averaged_stats)
+    accelerator.log(averaged_stats, step=shared.current_train_step)
+    logger.info(averaged_stats)
+
+    shared.current_loss = averaged_stats["test/loss"]
+
+
+def acc_range(preds, labels, start_index, end_index):
+    index = (start_index <= labels) & (labels < end_index)
+    range_labels = labels[index]
+    range_preds = preds[index]
+    return (range_preds == range_labels).detach().float().cpu().numpy()
+
+
+def train(
+    model: OsuT,
+    train_dataloader: DataLoader,
+    test_dataloader: DataLoader,
+    accelerator: Accelerator,
+    lr_scheduler: LRScheduler,
+    optimizer: Optimizer,
+    tokenizer: Tokenizer,
+    args: DictConfig,
+    shared: Namespace,
+    profiler=None,
+):
+    model.train()
+
+    train_averager = Averager()
+
+    while shared.current_train_step <= args.optim.total_steps:
+        # In case there is a remainder from the previous epoch, we need to reset the optimizer
+        optimizer.zero_grad(set_to_none=True)
+
+        accelerator.print(f"Epoch {shared.current_epoch}")
+
+        for batch_id, batch in enumerate(train_dataloader, start=1):
+            with accelerator.accumulate(model):
+                if shared.current_train_step > args.optim.total_steps:
+                    break
+
+                loss, stats = forward(model, batch)
+
+                accelerator.backward(loss)
+                train_averager.update(stats)
+
+                if accelerator.sync_gradients:
+                    stats = maybe_grad_clip_and_grad_calc(model, accelerator, args)
+                    train_averager.update(stats)
+
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad(set_to_none=True)
+
+                if profiler is not None:
+                    profiler.step()
+
+                if accelerator.sync_gradients:
+                    maybe_logging(model, accelerator, optimizer, train_averager, args, shared)
+                    maybe_eval(model, accelerator, test_dataloader, tokenizer, args, shared)
+                    maybe_save_checkpoint(accelerator, args, shared)
+
+                    shared.current_train_step += 1
+
+        shared.current_epoch += 1
+
+    if not (args.profile.do_profile and args.profile.early_stop):
+        maybe_eval(model, accelerator, test_dataloader, tokenizer, args, shared)
+        maybe_save_checkpoint(accelerator, args, shared)
+
+    accelerator.end_training()
+
+
+def train_profiling(
+    model: OsuT,
+    train_dataloader: DataLoader,
+    test_dataloader: DataLoader,
+    accelerator: Accelerator,
+    lr_scheduler: LRScheduler,
+    optimizer: Optimizer,
+    tokenizer: Tokenizer,
+    args: DictConfig,
+    shared: Namespace,
+):
+    tensorboard_trace_handler = torch.profiler.tensorboard_trace_handler(
+        "./profiler_logs", worker_name=f"worker_{accelerator.process_index}")
+
+    if args.profile.early_stop:
+        stop_step = (args.profile.wait + args.profile.warmup + args.profile.active) * args.profile.repeat / args.optim.grad_acc
+        args.optim.total_steps = shared.current_train_step + stop_step
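+        # e.g. wait=1, warmup=1, active=3, repeat=2 and grad_acc=1 stop training
+        # (1 + 1 + 3) * 2 = 10 optimizer steps from now; dividing by grad_acc
+        # converts profiler (per-batch) steps into optimizer steps.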
+ + def on_trace_ready(trace): + tensorboard_trace_handler(trace) + wandb_tracker = accelerator.get_tracker("wandb") + if wandb_tracker is not None: + wandb.save(glob.glob(f"./profiler_logs/*.pt.trace.json")[0], base_path="profiler_logs") + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=args.profile.wait, + warmup=args.profile.warmup, + active=args.profile.active, + repeat=args.profile.repeat, + ), + on_trace_ready=on_trace_ready, + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as p: + train( + model, + train_dataloader, + test_dataloader, + accelerator, + lr_scheduler, + optimizer, + tokenizer, + args, + shared, + p + ) diff --git a/osudiffusion/DiT-B-0700000.pt b/osudiffusion/DiT-B-0700000.pt new file mode 100644 index 0000000000000000000000000000000000000000..056aa63416c4fa0a5edbb9cd901d56e582d83514 --- /dev/null +++ b/osudiffusion/DiT-B-0700000.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63c37eee09c5f4979e26ce3dbfe6e54643e2e3b668eded629538134878d822a3 +size 2726131935 diff --git a/osudiffusion/__init__.py b/osudiffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2e9a85b6c333bfae2df7c55212b723b0490256 --- /dev/null +++ b/osudiffusion/__init__.py @@ -0,0 +1,4 @@ +from .positional_embedding import timestep_embedding +from .data_loading import repeat_type +from .diffusion import create_diffusion +from .models import DiT, DiT_models diff --git a/osudiffusion/__pycache__/__init__.cpython-311.pyc b/osudiffusion/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2724ac34c7eb0d203040b84541b3367e09df02fc Binary files /dev/null and b/osudiffusion/__pycache__/__init__.cpython-311.pyc differ diff --git a/osudiffusion/__pycache__/__init__.cpython-39.pyc b/osudiffusion/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c1c4b62c42523d54d71b98e7931b2251e553dba Binary files /dev/null and b/osudiffusion/__pycache__/__init__.cpython-39.pyc differ diff --git a/osudiffusion/__pycache__/data_loading.cpython-311.pyc b/osudiffusion/__pycache__/data_loading.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6893525547cf40fc443dad9295bb0672b824d0ef Binary files /dev/null and b/osudiffusion/__pycache__/data_loading.cpython-311.pyc differ diff --git a/osudiffusion/__pycache__/data_loading.cpython-39.pyc b/osudiffusion/__pycache__/data_loading.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51963ee388fa476ff237cdf38c0038b77ec3499e Binary files /dev/null and b/osudiffusion/__pycache__/data_loading.cpython-39.pyc differ diff --git a/osudiffusion/__pycache__/models.cpython-311.pyc b/osudiffusion/__pycache__/models.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02aa112175c3cacdacba4ad7178bf6acb3a27187 Binary files /dev/null and b/osudiffusion/__pycache__/models.cpython-311.pyc differ diff --git a/osudiffusion/__pycache__/models.cpython-39.pyc b/osudiffusion/__pycache__/models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d479216547d93696aa90eb1dbc4c614c2a03498 Binary files /dev/null and b/osudiffusion/__pycache__/models.cpython-39.pyc differ diff --git a/osudiffusion/__pycache__/positional_embedding.cpython-311.pyc 
b/osudiffusion/__pycache__/positional_embedding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c04e9eb388b30b91b5b1ef4b216ad2266721fba9 Binary files /dev/null and b/osudiffusion/__pycache__/positional_embedding.cpython-311.pyc differ diff --git a/osudiffusion/__pycache__/positional_embedding.cpython-39.pyc b/osudiffusion/__pycache__/positional_embedding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e775c9ab36bcfaeb5bdc2f33ed99d6465f0979e Binary files /dev/null and b/osudiffusion/__pycache__/positional_embedding.cpython-39.pyc differ diff --git a/osudiffusion/beatmap_idx.pickle b/osudiffusion/beatmap_idx.pickle new file mode 100644 index 0000000000000000000000000000000000000000..149d4aee867741362df43b96af97de439f0e0312 --- /dev/null +++ b/osudiffusion/beatmap_idx.pickle @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:329166bedf4a2d2cc5db82e5a3d7f841e52c4e29f462cdb37323c91cdd025a1d +size 421278 diff --git a/osudiffusion/data_loading.py b/osudiffusion/data_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..223223401e6bacc19289f1155e8284c01d4a1a99 --- /dev/null +++ b/osudiffusion/data_loading.py @@ -0,0 +1,568 @@ +import math +import os.path +import pickle +import random +from collections.abc import Callable +from datetime import timedelta +from pathlib import Path, PurePosixPath, PureWindowsPath +from typing import Optional + +import torch +from torch.utils.data import DataLoader, Dataset +from torch.utils.data import IterableDataset +import tqdm + +from .positional_embedding import offset_sequence_embedding +from .positional_embedding import position_sequence_embedding +from .positional_embedding import timestep_embedding +from slider import Position +from slider.beatmap import Beatmap +from slider.beatmap import HitObject +from slider.beatmap import Slider +from slider.beatmap import Spinner +from slider.curve import Catmull +from slider.curve import Linear +from slider.curve import MultiBezier +from slider.curve import Perfect + +playfield_size = torch.tensor((512, 384)) +feature_size = 19 + + +def create_datapoint(time: timedelta, pos: Position, datatype: int) -> torch.Tensor: + features = torch.zeros(19) + features[0] = pos.x + features[1] = pos.y + features[2] = time.total_seconds() * 1000 + features[datatype + 3] = 1 + + return features + + +def repeat_type(repeat: int) -> int: + if repeat < 4: + return repeat - 1 + elif repeat % 2 == 0: + return 3 + else: + return 4 + + +def append_control_points( + datapoints: list[torch.Tensor], + slider: Slider, + datatype: int, + duration: timedelta, +): + control_point_count = len(slider.curve.points) + + for i in range(1, control_point_count - 1): + time = slider.time + i / (control_point_count - 1) * duration + pos = slider.curve.points[i] + datapoints.append(create_datapoint(time, pos, datatype)) + + +def get_data(hitobj: HitObject) -> torch.Tensor: + if isinstance(hitobj, Slider) and len(hitobj.curve.points) < 100: + datapoints = [ + create_datapoint( + hitobj.time, + hitobj.position, + 5 if hitobj.new_combo else 4, + ), + ] + + assert hitobj.repeat >= 1 + duration: timedelta = (hitobj.end_time - hitobj.time) / hitobj.repeat + + if isinstance(hitobj.curve, Linear): + append_control_points(datapoints, hitobj, 9, duration) + elif isinstance(hitobj.curve, Catmull): + append_control_points(datapoints, hitobj, 8, duration) + elif isinstance(hitobj.curve, Perfect): + append_control_points(datapoints, hitobj, 7, duration) 
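+        # One-hot datatype slots used by create_datapoint: 0/1 circle (1 = new
+        # combo), 2/3 spinner start/end, 4/5 slider head (5 = new combo),
+        # 6 bezier anchor, 7 perfect-circle point, 8 catmull point, 9 linear or
+        # repeated (red) anchor, 10 last control point, 11-15 slider end
+        # indexed by repeat_type.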
+ elif isinstance(hitobj.curve, MultiBezier): + control_point_count = len(hitobj.curve.points) + + for i in range(1, control_point_count - 1): + time = hitobj.time + i / (control_point_count - 1) * duration + pos = hitobj.curve.points[i] + + if pos == hitobj.curve.points[i + 1]: + datapoints.append(create_datapoint(time, pos, 9)) + elif pos != hitobj.curve.points[i - 1]: + datapoints.append(create_datapoint(time, pos, 6)) + + datapoints.append( + create_datapoint(hitobj.time + duration, hitobj.curve.points[-1], 10), + ) + + slider_end_pos = hitobj.curve(1) + datapoints.append( + create_datapoint( + hitobj.end_time, + slider_end_pos, + 11 + repeat_type(hitobj.repeat), + ), + ) + + return torch.stack(datapoints, 0) + + if isinstance(hitobj, Spinner): + return torch.stack( + ( + create_datapoint(hitobj.time, hitobj.position, 2), + create_datapoint(hitobj.end_time, hitobj.position, 3), + ), + 0, + ) + + return create_datapoint( + hitobj.time, + hitobj.position, + 1 if hitobj.new_combo else 0, + ).unsqueeze(0) + + +def beatmap_to_sequence(beatmap: Beatmap) -> torch.Tensor: + # Get the hit objects + hit_objects = beatmap.hit_objects(stacking=False) + data_chunks = [get_data(ho) for ho in hit_objects] + + sequence = torch.concatenate(data_chunks, 0) + sequence = torch.swapaxes(sequence, 0, 1) + + return sequence.float() + + +def random_flip(seq: torch.Tensor) -> torch.Tensor: + if random.random() < 0.5: + seq[0] = 512 - seq[0] + if random.random() < 0.5: + seq[1] = 384 - seq[1] + return seq + + +def calc_distances(seq: torch.Tensor) -> torch.Tensor: + offset = torch.roll(seq[:2, :], 1, 1) + offset[0, 0] = 256 + offset[1, 0] = 192 + seq_d = torch.linalg.vector_norm(seq[:2, :] - offset, ord=2, dim=0) + return seq_d + + +def split_and_process_sequence( + seq: torch.Tensor, +) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], int]: + seq_d = calc_distances(seq) + # Augment and normalize positions for diffusion + seq_x = random_flip(seq[:2, :]) / playfield_size.unsqueeze(1) + seq_o = seq[2, :] + seq_c = torch.concatenate( + [ + timestep_embedding(seq_d, 128).T, + seq[3:, :], + ], + 0, + ) + + return (seq_x, seq_o, seq_c), seq.shape[1] + + +def split_and_process_sequence_no_augment( + seq: torch.Tensor, +) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], int]: + seq_d = calc_distances(seq) + # Augment and normalize positions for diffusion + seq_x = seq[:2, :] / playfield_size.to(seq.device).unsqueeze(1) + seq_o = seq[2, :] + seq_c = torch.concatenate( + [ + timestep_embedding(seq_d, 128).T, + seq[3:, :], + ], + 0, + ) + + return (seq_x, seq_o, seq_c), seq.shape[1] + + +def load_and_process_beatmap(beatmap: Beatmap): + seq = beatmap_to_sequence(beatmap) + return split_and_process_sequence(seq) + + +def window_and_relative_time(seq, s, e): + seq_x, seq_o, seq_c = seq + x = seq_x[:, s:e] + # Obscure the absolute time by normalizing to zero and adding a random offset between zero and the max period + # We do this to make sure the offset embedding utilizes the full range of values, which is also the case when sampling the model + o = seq_o[s:e] - seq_o[s] + random.random() * 100000 + c = seq_c[:, s:e] + + return x, o, c + + +class BeatmapDatasetIterable: + __slots__ = ( + "beatmap_files", + "beatmap_idx", + "seq_len", + "stride", + "index", + "current_idx", + "current_seq", + "current_seq_len", + "seq_index", + "seq_func", + "win_func", + ) + + def __init__( + self, + beatmap_files: list[str], + seq_len: int, + stride: int, + seq_func: Callable, + win_func: Callable, + ): + 
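+        # seq_func maps a Beatmap to (processed sequence, length), e.g.
+        # load_and_process_beatmap; win_func slices a window [s, e) from that
+        # sequence, e.g. window_and_relative_time.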
self.beatmap_files = beatmap_files + self.seq_len = seq_len + self.stride = stride + self.index = 0 + self.current_idx = 0 + self.current_seq = None + self.current_seq_len = -1 + self.seq_index = 0 + self.seq_func = seq_func + self.win_func = win_func + + def __iter__(self) -> "BeatmapDatasetIterable": + return self + + def __next__(self) -> tuple[any, int]: + while ( + self.current_seq is None + or self.seq_index + self.seq_len > self.current_seq_len + ): + if self.index >= len(self.beatmap_files): + raise StopIteration + + # Load the beatmap from file + beatmap_path = self.beatmap_files[self.index] + beatmap = Beatmap.from_path(beatmap_path) + + self.current_idx = int(os.path.basename(beatmap_path)[:6]) + self.current_seq, self.current_seq_len = self.seq_func(beatmap) + self.seq_index = random.randint(0, self.stride - 1) + self.index += 1 + + # Return the preprocessed hit objects as a sequence of overlapping windows + window = self.win_func( + self.current_seq, + self.seq_index, + self.seq_index + self.seq_len, + ) + self.seq_index += self.stride + return window, self.current_idx + + +class InterleavingBeatmapDatasetIterable: + __slots__ = ("workers", "cycle_length", "index") + + def __init__( + self, + beatmap_files: list[str], + iterable_factory: Callable, + cycle_length: int, + ): + per_worker = int(math.ceil(len(beatmap_files) / float(cycle_length))) + self.workers = [ + iterable_factory( + beatmap_files[ + i * per_worker: min(len(beatmap_files), (i + 1) * per_worker) + ] + ) + for i in range(cycle_length) + ] + self.cycle_length = cycle_length + self.index = 0 + + def __iter__(self) -> "InterleavingBeatmapDatasetIterable": + return self + + def __next__(self) -> tuple[any, int]: + num = len(self.workers) + for _ in range(num): + try: + self.index = self.index % len(self.workers) + item = self.workers[self.index].__next__() + self.index += 1 + return item + except StopIteration: + self.workers.remove(self.workers[self.index]) + raise StopIteration + + +class BeatmapDataset(IterableDataset): + def __init__( + self, + dataset_path: str, + start: int, + end: int, + iterable_factory: Callable, + cycle_length: int = 1, + shuffle: bool = False, + beatmap_files: Optional[list[str]] = None, + ): + super(BeatmapDataset).__init__() + self.dataset_path = dataset_path + self.start = start + self.end = end + self.iterable_factory = iterable_factory + self.cycle_length = cycle_length + self.shuffle = shuffle + self.beatmap_files = beatmap_files + + def _get_beatmap_files(self) -> list[str]: + if self.beatmap_files is not None: + return self.beatmap_files + + # Get a list of all beatmap files in the dataset path in the track index range between start and end + beatmap_files = [] + track_names = ["Track" + str(i).zfill(5) for i in range(self.start, self.end)] + for track_name in track_names: + for beatmap_file in os.listdir( + os.path.join(self.dataset_path, track_name, "beatmaps"), + ): + beatmap_files.append( + os.path.join( + self.dataset_path, + track_name, + "beatmaps", + beatmap_file, + ), + ) + + return beatmap_files + + def __iter__(self) -> InterleavingBeatmapDatasetIterable | BeatmapDatasetIterable: + beatmap_files = self._get_beatmap_files() + + if self.shuffle: + random.shuffle(beatmap_files) + + if self.cycle_length > 1: + return InterleavingBeatmapDatasetIterable( + beatmap_files, + self.iterable_factory, + self.cycle_length, + ) + + return self.iterable_factory(beatmap_files) + + +# Define a `worker_init_fn` that configures each dataset copy differently +def worker_init_fn(worker_id: 
int) -> None: + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset # the dataset copy in this worker process + overall_start = dataset.start + overall_end = dataset.end + # configure the dataset to only process the split workload + per_worker = int( + math.ceil((overall_end - overall_start) / float(worker_info.num_workers)), + ) + dataset.start = overall_start + worker_id * per_worker + dataset.end = min(dataset.start + per_worker, overall_end) + + +def get_beatmap_idx(name) -> dict[int, int]: + p = Path(__file__).with_name(name) + with p.open("rb") as f: + beatmap_idx = pickle.load(f) + return beatmap_idx + + +def get_beatmap_files(name: str, data_path: str) -> list[PurePosixPath]: + p = Path(name) + with p.open("rb") as f: + relative_beatmap_files = pickle.load(f) + beatmap_files = [PurePosixPath(data_path, *PureWindowsPath(f).parts) for f in relative_beatmap_files] + return beatmap_files + + +class BeatmapDatasetIterableFactory: + __slots__ = ("seq_len", "stride", "seq_func", "win_func") + + def __init__(self, seq_len, stride, seq_func, win_func): + self.seq_len = seq_len + self.stride = stride + self.seq_func = seq_func + self.win_func = win_func + + def __call__(self, *args, **kwargs): + beatmap_files = args[0] + return BeatmapDatasetIterable( + beatmap_files=beatmap_files, + seq_len=self.seq_len, + stride=self.stride, + seq_func=self.seq_func, + win_func=self.win_func, + ) + + +class CachedDataset(Dataset): + __slots__ = "cached_data" + + def __init__(self, cached_data): + self.cached_data = cached_data + + def __getitem__(self, index): + return self.cached_data[index] + + def __len__(self): + return len(self.cached_data) + + +def cache_dataset( + out_path: str, + dataset_path: str, + start: int, + end: int, + iterable_factory: Callable, + cycle_length=1, + beatmap_files: Optional[list[str]] = None, +): + dataset = BeatmapDataset( + dataset_path=dataset_path, + start=start, + end=end, + iterable_factory=iterable_factory, + cycle_length=cycle_length, + shuffle=False, + beatmap_files=beatmap_files, + ) + + print("Caching dataset...") + cached_data = [] + for datum in tqdm.tqdm(dataset): + cached_data.append(datum) + + torch.save(cached_data, out_path) + + +def get_cached_data_loader( + data_path: str, + batch_size: int = 1, + num_workers: int = 0, + shuffle: bool = False, + pin_memory: bool = False, + drop_last: bool = False, +): + cached_data = torch.load(data_path) + dataset = CachedDataset(cached_data) + + dataloader = DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + drop_last=drop_last, + persistent_workers=num_workers > 0, + shuffle=shuffle, + ) + + return dataloader + + +def get_data_loader( + dataset_path: str, + start: int, + end: int, + iterable_factory: Callable, + cycle_length=1, + batch_size: int = 1, + num_workers: int = 0, + shuffle: bool = False, + pin_memory: bool = False, + drop_last: bool = False, + beatmap_files: Optional[list[str]] = None, +) -> DataLoader: + dataset = BeatmapDataset( + dataset_path=dataset_path, + start=start, + end=end, + iterable_factory=iterable_factory, + cycle_length=cycle_length, + shuffle=shuffle, + beatmap_files=beatmap_files, + ) + + dataloader = DataLoader( + dataset, + batch_size=batch_size, + worker_init_fn=worker_init_fn, + num_workers=num_workers, + pin_memory=pin_memory, + drop_last=drop_last, + persistent_workers=num_workers > 0, + ) + + return dataloader + + +def main(args): + dataloader = get_data_loader( + dataset_path=args.data_path, + start=0, + 
end=16291, + iterable_factory=BeatmapDatasetIterableFactory( + 128, + 16, + load_and_process_beatmap, + window_and_relative_time, + ), + cycle_length=1, + batch_size=args.batch_size, + num_workers=args.num_workers, + shuffle=False, + pin_memory=False, + drop_last=True, + ) + + if args.mode == "plotfirst": + import matplotlib.pyplot as plt + + for (x, o, c), y in dataloader: + x = torch.swapaxes(x, 1, 2) # (N, T, C) + c = torch.swapaxes(c, 1, 2) # (N, T, E) + print(x.shape, o.shape, c.shape, y.shape) + batch_pos_emb = position_sequence_embedding(x * playfield_size, 128) + print(batch_pos_emb.shape) + batch_offset_emb = offset_sequence_embedding(o / 10, 128) + print(batch_offset_emb.shape) + print(y) + + for j in range(args.batch_size): + fig, axs = plt.subplots(3, figsize=(5, 20)) + axs[0].imshow(batch_pos_emb[j]) + axs[1].imshow(batch_offset_emb[j]) + axs[2].imshow(c[j]) + print(y[j]) + plt.show() + break + elif args.mode == "benchmark": + for _ in tqdm.tqdm(dataloader, total=7000, smoothing=0.01): + pass + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, required=True) + parser.add_argument("--mode", type=str, required=True) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--num-workers", type=int, default=0) + args = parser.parse_args() + main(args) diff --git a/osudiffusion/diffusion/__init__.py b/osudiffusion/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad5d6d271f2c56b7bdb56cc93d10182d5d947813 --- /dev/null +++ b/osudiffusion/diffusion/__init__.py @@ -0,0 +1,47 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +from . 
import gaussian_diffusion as gd +from .respace import space_timesteps +from .respace import SpacedDiffusion + + +def create_diffusion( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000, + use_l1=False, +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_L1 if use_l1 else gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.L1 if use_l1 else gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + # rescale_timesteps=rescale_timesteps, + ) diff --git a/osudiffusion/diffusion/__pycache__/__init__.cpython-311.pyc b/osudiffusion/diffusion/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a928c251aa6b33a94417e88f9592d858dd738873 Binary files /dev/null and b/osudiffusion/diffusion/__pycache__/__init__.cpython-311.pyc differ diff --git a/osudiffusion/diffusion/__pycache__/__init__.cpython-39.pyc b/osudiffusion/diffusion/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cc10293667799612c9625c22dd17ac26371282d Binary files /dev/null and b/osudiffusion/diffusion/__pycache__/__init__.cpython-39.pyc differ diff --git a/osudiffusion/diffusion/__pycache__/diffusion_utils.cpython-311.pyc b/osudiffusion/diffusion/__pycache__/diffusion_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69088bc08cf52e1dfacb61c6a9cefb4241325b94 Binary files /dev/null and b/osudiffusion/diffusion/__pycache__/diffusion_utils.cpython-311.pyc differ diff --git a/osudiffusion/diffusion/__pycache__/diffusion_utils.cpython-39.pyc b/osudiffusion/diffusion/__pycache__/diffusion_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0ec356362073b495d9c824d3856bb1fc5fe9e0e Binary files /dev/null and b/osudiffusion/diffusion/__pycache__/diffusion_utils.cpython-39.pyc differ diff --git a/osudiffusion/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc b/osudiffusion/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3906f5299aef1b316e15f0e6691a2f83735796c6 Binary files /dev/null and b/osudiffusion/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc differ diff --git a/osudiffusion/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc b/osudiffusion/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df6725d7b7cb319971f060bd34dff3d80e3a0311 Binary files /dev/null and b/osudiffusion/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc differ diff --git a/osudiffusion/diffusion/__pycache__/respace.cpython-311.pyc b/osudiffusion/diffusion/__pycache__/respace.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..14be330a1c6ab4f52e7eef6f18b6693a8406ffc9 Binary files /dev/null and b/osudiffusion/diffusion/__pycache__/respace.cpython-311.pyc differ diff --git a/osudiffusion/diffusion/__pycache__/respace.cpython-39.pyc b/osudiffusion/diffusion/__pycache__/respace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f566b7f8316e80a7c0d2a7c7ec20d3318ab33b6d Binary files /dev/null and b/osudiffusion/diffusion/__pycache__/respace.cpython-39.pyc differ diff --git a/osudiffusion/diffusion/diffusion_utils.py b/osudiffusion/diffusion/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..136e7af68bab43ee7073da7a6556a80e46622b0d --- /dev/null +++ b/osudiffusion/diffusion/diffusion_utils.py @@ -0,0 +1,89 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +import numpy as np +import torch as th + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = ( + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ) + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob( + normalized_x, + ) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). 
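+    Note: the 1.0 / 255.0 bin half-width used below assumes 8-bit data
+    rescaled to [-1, 1]; other quantizations would need a different bin size.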
+ """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/osudiffusion/diffusion/gaussian_diffusion.py b/osudiffusion/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..c388cc6fe95bcdb280f19f229b40801b093efc2e --- /dev/null +++ b/osudiffusion/diffusion/gaussian_diffusion.py @@ -0,0 +1,963 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +import enum +import math + +import numpy as np +import torch as th + +from .diffusion_utils import discretized_gaussian_log_likelihood +from .diffusion_utils import normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + L1 = enum.auto() + RESCALED_L1 = enum.auto() + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace( + beta_start, + beta_end, + warmup_time, + dtype=np.float64, + ) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. 
+ """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace( + beta_start, + beta_end, + num_diffusion_timesteps, + dtype=np.float64, + ) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, + 1, + num_diffusion_timesteps, + dtype=np.float64, + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__(self, *, betas, model_mean_type, model_var_type, loss_type): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. 
+ betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = ( + np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + if len(self.posterior_variance) > 1 + else np.array([]) + ) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor( + self.log_one_minus_alphas_cumprod, + t, + x_start.shape, + ) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
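+        Concretely, this returns
+        sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise.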
+ """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, + t, + x_t.shape, + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, + t, + x.shape, + ) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 2) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output), + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, + x_t=x, + t=t, + ) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = ( + p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + ) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], + x_t=x, + t=t, + ) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. 
+ :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean( + cond_fn, + out, + x, + t, + model_kwargs=model_kwargs, + ) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). 
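+        With eta=0.0 this is the deterministic DDIM update
+        x_{t-1} = sqrt(alpha_bar_prev) * pred_xstart
+            + sqrt(1 - alpha_bar_prev - sigma**2) * eps  (with sigma == 0),
+        while eta=1.0 restores DDPM-like ancestral sampling noise.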
+ """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_next) + + th.sqrt(1 - alpha_bar_next) * eps + ) + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. 
+ from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, + model, + x_start, + x_t, + t, + clip_denoised=True, + model_kwargs=None, + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, + x_t=x_t, + t=t, + ) + out = self.p_mean_variance( + model, + x_t, + t, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + kl = normal_kl( + true_mean, + true_log_variance_clipped, + out["mean"], + out["log_variance"], + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, + means=out["mean"], + log_scales=0.5 * out["log_variance"], + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif ( + self.loss_type == LossType.MSE + or self.loss_type == LossType.RESCALED_MSE + or self.loss_type == LossType.L1 + or self.loss_type == LossType.RESCALED_L1 + ): + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. 
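+                # The detach() below freezes the mean inside the VB term, so
+                # its gradient only trains the variance head; the mean is
+                # trained by the MSE/L1 term further down.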
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if ( + self.loss_type == LossType.RESCALED_MSE + or self.loss_type == LossType.RESCALED_L1 + ): + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, + x_t=x_t, + t=t, + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + if self.loss_type == LossType.L1 or self.loss_type == LossType.RESCALED_L1: + terms["l1"] = mean_flat(th.abs(target - model_output)) + if "vb" in terms: + terms["loss"] = terms["l1"] + terms["vb"] + else: + terms["loss"] = terms["l1"] + else: + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, + logvar1=qt_log_variance, + mean2=0.0, + logvar2=0.0, + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
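+        Note that this evaluates the model once per timestep, so a full pass
+        costs roughly num_timesteps times as much as a single training step.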
+ """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/osudiffusion/diffusion/respace.py b/osudiffusion/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..eec8fbeab050b4f34920be36488a00b19a939fa6 --- /dev/null +++ b/osudiffusion/diffusion/respace.py @@ -0,0 +1,132 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. 
+ """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride", + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}", + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, + model, + *args, + **kwargs, + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, + model, + *args, + **kwargs, + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. 
+ return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/osudiffusion/diffusion/timestep_sampler.py b/osudiffusion/diffusion/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..ff35bc162792f7c6ee05f8ca1932b46d360d1a2e --- /dev/null +++ b/osudiffusion/diffusion/timestep_sampler.py @@ -0,0 +1,151 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +from abc import ABC +from abc import abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. 
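+        Note: this relies on torch.distributed being initialized; in a
+        single-process run, update_with_all_losses can be called directly.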
+        :param local_ts: an integer Tensor of timesteps.
+        :param local_losses: a 1D Tensor of losses.
+        """
+        batch_sizes = [
+            th.tensor([0], dtype=th.int32, device=local_ts.device)
+            for _ in range(dist.get_world_size())
+        ]
+        dist.all_gather(
+            batch_sizes,
+            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+        )
+
+        # Pad all_gather batches to be the maximum batch size.
+        batch_sizes = [x.item() for x in batch_sizes]
+        max_bs = max(batch_sizes)
+
+        timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+        loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+        dist.all_gather(timestep_batches, local_ts)
+        dist.all_gather(loss_batches, local_losses)
+        timesteps = [
+            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+        ]
+        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+        self.update_with_all_losses(timesteps, losses)
+
+    @abstractmethod
+    def update_with_all_losses(self, ts, losses):
+        """
+        Update the reweighting using losses from a model.
+        Sub-classes should override this method to update the reweighting
+        using losses from the model.
+        This method directly updates the reweighting without synchronizing
+        between workers. It is called by update_with_local_losses from all
+        ranks with identical arguments. Thus, it should have deterministic
+        behavior to maintain state across workers.
+        :param ts: a list of int timesteps.
+        :param losses: a list of float losses, one per timestep.
+        """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+        self.diffusion = diffusion
+        self.history_per_term = history_per_term
+        self.uniform_prob = uniform_prob
+        self._loss_history = np.zeros(
+            [diffusion.num_timesteps, history_per_term],
+            dtype=np.float64,
+        )
+        # np.int was removed in NumPy 1.24; use a concrete integer dtype.
+        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+    def weights(self):
+        if not self._warmed_up():
+            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+        weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
+        weights /= np.sum(weights)
+        weights *= 1 - self.uniform_prob
+        weights += self.uniform_prob / len(weights)
+        return weights
+
+    def update_with_all_losses(self, ts, losses):
+        for t, loss in zip(ts, losses):
+            if self._loss_counts[t] == self.history_per_term:
+                # Shift out the oldest loss term.
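+                # The per-timestep history is a fixed-size FIFO: drop the
+                # entry at index 0 and write the newest loss into the last slot.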
+ self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/osudiffusion/models.py b/osudiffusion/models.py new file mode 100644 index 0000000000000000000000000000000000000000..d7ed635e2a564ad36bb0153541c90ec4a224304b --- /dev/null +++ b/osudiffusion/models.py @@ -0,0 +1,431 @@ +from functools import partial + +import numpy as np +import torch +import torch.nn as nn + +from .positional_embedding import offset_sequence_embedding +from .positional_embedding import position_sequence_embedding +from .positional_embedding import timestep_embedding + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding( + num_classes + use_cfg_embedding, + hidden_size, + ) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
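+        Dropped labels are remapped to the extra "null" class at index
+        num_classes, whose embedding acts as the unconditional input for
+        classifier-free guidance.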
+ """ + if force_drop_ids is None: + drop_ids = ( + torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + ) + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +################################################################################# +# Core DiT Model # +################################################################################# + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = (bias, bias) + drop_probs = (drop, drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = ( + norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + ) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = nn.MultiheadAttention( + hidden_size, + num_heads=num_heads, + batch_first=True, + **block_kwargs, + ) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + # noinspection PyTypeChecker + self.mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=approx_gelu, + drop=0, + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True), + ) + + def forward(self, x, c, attn_mask=None): + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation(c).chunk(6, dim=1) + modulated = modulate(self.norm1(x), shift_msa, scale_msa) + x = ( + x + + gate_msa.unsqueeze(1) + * self.attn( + modulated, + modulated, + modulated, + need_weights=False, + attn_mask=attn_mask, + )[0] + ) + x = x + gate_mlp.unsqueeze(1) * self.mlp( + modulate(self.norm2(x), shift_mlp, scale_mlp), + ) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. 
+ """ + + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True), + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class FirstLayer(nn.Module): + """ + Embeds scalar positions into vector representation and concatenates context. + """ + + def __init__( + self, + hidden_size, + context_size, + in_channels, + frequency_embedding_size=128, + ): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear( + in_channels * frequency_embedding_size + + frequency_embedding_size + + context_size, + hidden_size, + bias=True, + ), + ) + self.frequency_embedding_size = frequency_embedding_size + self.playfield_size = nn.Parameter( + torch.tensor((512, 384), dtype=torch.float32), + requires_grad=False, + ) + + def forward(self, x, o, c): + x_freq = position_sequence_embedding( + x * self.playfield_size, + self.frequency_embedding_size, + ) + o_freq = offset_sequence_embedding(o / 10, self.frequency_embedding_size) + xoc = torch.concatenate((x_freq, o_freq, c), -1) + xoc_emb = self.mlp(xoc) + return xoc_emb + + +class DiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + in_channels=2, + context_size=142, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.context_size = context_size + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.num_heads = num_heads + + self.xoc_embedder = FirstLayer(hidden_size, context_size, in_channels) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + + self.blocks = nn.ModuleList( + [ + DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) + for _ in range(depth) + ], + ) + self.final_layer = FinalLayer(hidden_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize position embedding MLP: + nn.init.normal_(self.xoc_embedder.mlp[0].weight, std=0.02) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, o, c, y, attn_mask=None): + """ + Forward pass of DiT. 
+        x: (N, C, T) tensor of sequence inputs
+        t: (N) tensor of diffusion timesteps
+        o: (N, T) tensor of sequence offsets in milliseconds
+        c: (N, E, T) tensor of sequence context
+        y: (N) tensor of class labels
+        """
+        x = torch.swapaxes(x, 1, 2)  # (N, T, C)
+        c = torch.swapaxes(c, 1, 2)  # (N, T, E)
+        x = self.xoc_embedder(x, o, c)  # (N, T, D), where T = seq_len
+        t = self.t_embedder(t)  # (N, D)
+        y = self.y_embedder(y, self.training)  # (N, D)
+        b = t + y  # (N, D)
+        for block in self.blocks:
+            x = block(x, b, attn_mask)  # (N, T, D)
+        x = self.final_layer(x, b)  # (N, T, out_channels)
+        x = torch.swapaxes(x, 1, 2)  # (N, out_channels, T)
+        return x
+
+    def forward_with_cfg(self, x, t, o, c, y, cfg_scale, attn_mask=None):
+        """
+        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+        """
+        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+        half = x[: len(x) // 2]
+        combined = torch.cat([half, half], dim=0)
+        model_out = self.forward(combined, t, o, c, y, attn_mask)
+        # Unlike the original DiT, which applies classifier-free guidance to only
+        # the first three channels for exact reproducibility, guidance is applied
+        # to all in_channels here. To recover the original three-channel behavior,
+        # comment out the first line below and uncomment the second.
+        eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
+        # eps, rest = model_out[:, :3], model_out[:, 3:]
+        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+        eps = torch.cat([half_eps, half_eps], dim=0)
+        return torch.cat([eps, rest], dim=1)
+
+
+#################################################################################
+#                 Sine/Cosine Positional Embedding Functions                    #
+#################################################################################
+# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token and extra_tokens > 0:
+        pos_embed = np.concatenate(
+            [np.zeros([extra_tokens, embed_dim]), pos_embed],
+            axis=0,
+        )
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# DiT Configs # +################################################################################# + + +def DiT_XL(**kwargs: dict) -> DiT: + return DiT(depth=28, hidden_size=1152, num_heads=16, **kwargs) + + +def DiT_L(**kwargs: dict) -> DiT: + return DiT(depth=24, hidden_size=1024, num_heads=16, **kwargs) + + +def DiT_B(**kwargs: dict) -> DiT: + return DiT(depth=12, hidden_size=768, num_heads=12, **kwargs) + + +def DiT_S(**kwargs: dict) -> DiT: + return DiT(depth=12, hidden_size=384, num_heads=6, **kwargs) + + +DiT_models = { + "DiT-XL": DiT_XL, + "DiT-L": DiT_L, + "DiT-B": DiT_B, + "DiT-S": DiT_S, +} diff --git a/osudiffusion/positional_embedding.py b/osudiffusion/positional_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..e0bb6f7fda4b1a81f87361f97347dd44a78d046f --- /dev/null +++ b/osudiffusion/positional_embedding.py @@ -0,0 +1,165 @@ +import math + +import torch + + +def encode_single(d_model, value, max_period=10000.0): + """ + :param d_model: dimension of the model + :param value: the value to encode + :param max_period: the maximum allowed value + :return: length*d_model position matrix + """ + if d_model % 2 != 0: + raise ValueError( + "Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(d_model), + ) + pe = torch.zeros(d_model) + div_term = torch.exp( + torch.arange(0, d_model, 2, dtype=torch.float) + * -(math.log(max_period) / d_model), + ) + pe[0::2] = torch.sin(value * div_term) + pe[1::2] = torch.cos(value * div_term) + + return pe + + +def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) + / half, + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def offset_sequence_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: an (N, T) Tensor of sequences of time offsets + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, T, dim) Tensor of positional embeddings. + """ + N, T = t.shape + flattened = torch.flatten(t) + embedding = timestep_embedding(flattened, dim, max_period) + return torch.reshape(embedding, (N, T, dim)) + + +def position_sequence_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: an (N, T, D) Tensor of sequences of D dimensional positions. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, T, D * dim) Tensor of positional embeddings. 
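+    For (x, y) playfield positions (D=2) with dim=128, each sequence element
+    therefore yields a 256-dimensional feature, matching the
+    in_channels * frequency_embedding_size slice of FirstLayer's input.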
+ """ + N, T, D = t.shape + flattened = torch.flatten(t) + embedding = timestep_embedding(flattened, dim, max_period) + return torch.reshape(embedding, (N, T, D * dim)) + + +def positionalencoding(d_model, values, max_period=10000.0): + """ + :param d_model: dimension of the model + :param values: the values to encode + :param max_period: the maximum allowed value + :return: length*d_model position matrix + """ + if d_model % 2 != 0: + raise ValueError( + "Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(d_model), + ) + pe = torch.zeros(len(values), d_model) + position = values.unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2, dtype=torch.float) + * -(math.log(max_period) / d_model), + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + return pe + + +def positionalencoding1d(d_model, length): + """ + :param d_model: dimension of the model + :param length: length of positions + :return: length*d_model position matrix + """ + if d_model % 2 != 0: + raise ValueError( + "Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(d_model), + ) + pe = torch.zeros(2, d_model) + position = torch.arange(-50, 50, 100).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model), + ) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + + return pe + + +def positionalencoding2d(d_model, height, width): + """ + :param d_model: dimension of the model + :param height: height of the positions + :param width: width of the positions + :return: d_model*height*width position matrix + """ + if d_model % 4 != 0: + raise ValueError( + "Cannot use sin/cos positional encoding with " + "odd dimension (got dim={:d})".format(d_model), + ) + pe = torch.zeros(d_model, height, width) + # Each dimension use half of d_model + d_model = int(d_model / 2) + div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)) + pos_w = torch.arange(0.0, width).unsqueeze(1) + pos_h = torch.arange(0.0, height).unsqueeze(1) + pe[0:d_model:2, :, :] = ( + torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + ) + pe[1:d_model:2, :, :] = ( + torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + ) + pe[d_model::2, :, :] = ( + torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + ) + pe[d_model + 1 :: 2, :, :] = ( + torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + ) + + return pe + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + + pe = positionalencoding(128, torch.tensor([-50, 50])) + plt.imshow(pe) + plt.show() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..20f36b3216760490bd7261fd443b66305e3ebf16 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +accelerate +pydub +nnAudio +PyYAML +transformers +tensorboard +slider==0.8.1 +torch_tb_profiler +hydra-core \ No newline at end of file