Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +36 -35
- README.md +3 -0
- checkpoint/custom_checkpoint_0.pkl +3 -0
- checkpoint/pytorch_model.bin +3 -0
- configs/inference.yaml +65 -0
- configs/model/model.yaml +8 -0
- configs/model/t5_base.yaml +15 -0
- configs/model/t5_small.yaml +10 -0
- configs/model/t5_small_v4.yaml +7 -0
- configs/model/t5_small_v9.yaml +9 -0
- configs/model/whisper_base.yaml +6 -0
- inference.py +117 -0
- osuT5/__init__.py +0 -0
- osuT5/__pycache__/__init__.cpython-311.pyc +0 -0
- osuT5/__pycache__/__init__.cpython-39.pyc +0 -0
- osuT5/dataset/__init__.py +1 -0
- osuT5/dataset/__pycache__/__init__.cpython-311.pyc +0 -0
- osuT5/dataset/__pycache__/__init__.cpython-39.pyc +0 -0
- osuT5/dataset/__pycache__/data_utils.cpython-311.pyc +0 -0
- osuT5/dataset/__pycache__/data_utils.cpython-39.pyc +0 -0
- osuT5/dataset/__pycache__/ors_dataset.cpython-311.pyc +0 -0
- osuT5/dataset/__pycache__/ors_dataset.cpython-39.pyc +0 -0
- osuT5/dataset/__pycache__/osu_parser.cpython-311.pyc +0 -0
- osuT5/dataset/__pycache__/osu_parser.cpython-39.pyc +0 -0
- osuT5/dataset/data_utils.py +100 -0
- osuT5/dataset/osu_parser.py +184 -0
- osuT5/inference/__init__.py +4 -0
- osuT5/inference/__pycache__/__init__.cpython-311.pyc +0 -0
- osuT5/inference/__pycache__/__init__.cpython-39.pyc +0 -0
- osuT5/inference/__pycache__/diffusion_pipeline.cpython-311.pyc +0 -0
- osuT5/inference/__pycache__/path_approximator.cpython-311.pyc +0 -0
- osuT5/inference/__pycache__/path_approximator.cpython-39.pyc +0 -0
- osuT5/inference/__pycache__/pipeline.cpython-311.pyc +0 -0
- osuT5/inference/__pycache__/pipeline.cpython-39.pyc +0 -0
- osuT5/inference/__pycache__/postprocessor.cpython-311.pyc +0 -0
- osuT5/inference/__pycache__/postprocessor.cpython-39.pyc +0 -0
- osuT5/inference/__pycache__/preprocessor.cpython-311.pyc +0 -0
- osuT5/inference/__pycache__/preprocessor.cpython-39.pyc +0 -0
- osuT5/inference/__pycache__/slider_path.cpython-311.pyc +0 -0
- osuT5/inference/__pycache__/slider_path.cpython-39.pyc +0 -0
- osuT5/inference/diffusion_pipeline.py +214 -0
- osuT5/inference/path_approximator.py +253 -0
- osuT5/inference/pipeline.py +338 -0
- osuT5/inference/postprocessor.py +322 -0
- osuT5/inference/preprocessor.py +58 -0
- osuT5/inference/slider_path.py +230 -0
- osuT5/inference/template.osu +54 -0
- osuT5/model/__init__.py +1 -0
- osuT5/model/__pycache__/__init__.cpython-311.pyc +0 -0
- osuT5/model/__pycache__/__init__.cpython-39.pyc +0 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,36 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
osuT5/inference/vale.mp3 filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
---
|
checkpoint/custom_checkpoint_0.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0494fdd396142b4a2919c0ab913502c9335746959156a08e82c2647235e07853
|
| 3 |
+
size 564880
|
checkpoint/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a12b6c312590efbdf5d7acaff6d8537e8ad1728737eebb43d0a43d5a4b3b5a3a
|
| 3 |
+
size 377860126
|
configs/inference.yaml
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
name: 'google/t5-v1_1-small'
|
| 3 |
+
spectrogram:
|
| 4 |
+
sample_rate: 16000
|
| 5 |
+
hop_length: 128
|
| 6 |
+
n_fft: 1024
|
| 7 |
+
n_mels: 388
|
| 8 |
+
do_style_embed: false
|
| 9 |
+
input_features: false
|
| 10 |
+
|
| 11 |
+
model_path: './checkpoint'
|
| 12 |
+
audio_path: '' # Path to input audio
|
| 13 |
+
total_duration_ms: 0 # Total duration of audio in milliseconds, 0 for full audio
|
| 14 |
+
output_path: '' # Path to output directory
|
| 15 |
+
bpm: 120 # Beats per minute of input audio
|
| 16 |
+
offset: 0 # Start of beat, in miliseconds, from the beginning of input audio
|
| 17 |
+
resnap_objects: false # Resnap objects beat timing ticks, requires accurate BPM and offset
|
| 18 |
+
slider_multiplier: 1.7 # Multiplier for slider velocity
|
| 19 |
+
title: '' # Song title
|
| 20 |
+
artist: '' # Song artist
|
| 21 |
+
beatmap_path: '' # Path to .osu file which will be remapped
|
| 22 |
+
other_beatmap_path: '' # Path to .osu file of other beatmap in the mapset to use as reference
|
| 23 |
+
beatmap_id: -1 # Beatmap ID to use as style
|
| 24 |
+
difficulty: -1 # Difficulty star rating to map
|
| 25 |
+
creator: '' # Beatmap creator
|
| 26 |
+
version: '' # Beatmap version
|
| 27 |
+
full_set: true # Generate full mapset
|
| 28 |
+
set_difficulties: 5 # Number of difficulties to generate.
|
| 29 |
+
|
| 30 |
+
# Diffusion settings
|
| 31 |
+
generate_positions: true # Use diffusion to generate object positions
|
| 32 |
+
diff_ckpt: './osudiffusion/DiT-B-0700000.pt' # Path to checkpoint for diffusion model
|
| 33 |
+
diff_refine_ckpt: '' # Path to checkpoint for refining diffusion model
|
| 34 |
+
|
| 35 |
+
diffusion:
|
| 36 |
+
style_id: 1451282 # Style ID to use for diffusion
|
| 37 |
+
num_sampling_steps: 100 # Number of sampling steps
|
| 38 |
+
cfg_scale: 1 # Scale of classifier-free guidance
|
| 39 |
+
num_classes: 52670 # Number of classes stored in the model
|
| 40 |
+
beatmap_idx: 'osudiffusion/beatmap_idx.pickle' # Path to beatmap index
|
| 41 |
+
use_amp: true # Use automatic mixed precision
|
| 42 |
+
refine_iters: 10 # Number of refinement iterations
|
| 43 |
+
seq_len: 128 # Sequence length
|
| 44 |
+
model: 'DiT-B' # Model architecture
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
data: # Data settings
|
| 48 |
+
src_seq_len: 640
|
| 49 |
+
tgt_seq_len: 480
|
| 50 |
+
sample_rate: ${model.spectrogram.sample_rate}
|
| 51 |
+
hop_length: ${model.spectrogram.hop_length}
|
| 52 |
+
sequence_stride: 1 # Fraction of audio sequence length to shift inference window
|
| 53 |
+
center_pad_decoder: false # Center pad decoder input
|
| 54 |
+
add_pre_tokens: true
|
| 55 |
+
special_token_len: 2
|
| 56 |
+
diff_token_index: 0
|
| 57 |
+
style_token_index: -1
|
| 58 |
+
max_pre_token_len: 4
|
| 59 |
+
add_gd_context: false # Prefix the decoder with tokens of another beatmap in the mapset
|
| 60 |
+
|
| 61 |
+
hydra:
|
| 62 |
+
job:
|
| 63 |
+
chdir: False
|
| 64 |
+
run:
|
| 65 |
+
dir: ./logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
configs/model/model.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
input_features: false
|
| 2 |
+
do_style_embed: true
|
| 3 |
+
|
| 4 |
+
spectrogram:
|
| 5 |
+
sample_rate: 16000
|
| 6 |
+
hop_length: 128
|
| 7 |
+
n_fft: 1024
|
| 8 |
+
n_mels: 388
|
configs/model/t5_base.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
name: 'google/t5-v1_1-base'
|
| 6 |
+
overwrite:
|
| 7 |
+
dropout_rate: 0.0
|
| 8 |
+
|
| 9 |
+
spectrogram:
|
| 10 |
+
sample_rate: 16000
|
| 11 |
+
hop_length: 128
|
| 12 |
+
n_fft: 1024
|
| 13 |
+
n_mels: 388
|
| 14 |
+
do_style_embed: false
|
| 15 |
+
input_features: false
|
configs/model/t5_small.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
name: 'google/t5-v1_1-small'
|
| 6 |
+
overwrite:
|
| 7 |
+
dropout_rate: 0.0
|
| 8 |
+
|
| 9 |
+
spectrogram:
|
| 10 |
+
n_mels: 512
|
configs/model/t5_small_v4.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
name: 'google/t5-v1_1-small'
|
| 6 |
+
overwrite:
|
| 7 |
+
dropout_rate: 0.0
|
configs/model/t5_small_v9.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
do_style_embed: false
|
| 6 |
+
|
| 7 |
+
name: 'google/t5-v1_1-small'
|
| 8 |
+
overwrite:
|
| 9 |
+
dropout_rate: 0.0
|
configs/model/whisper_base.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
name: 'openai/whisper-base'
|
| 6 |
+
input_features: true
|
inference.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import hydra
|
| 4 |
+
import torch
|
| 5 |
+
from omegaconf import DictConfig
|
| 6 |
+
from slider import Beatmap
|
| 7 |
+
|
| 8 |
+
from osudiffusion import DiT_models
|
| 9 |
+
from osuT5.inference import Preprocessor, Pipeline, Postprocessor, DiffisionPipeline
|
| 10 |
+
from osuT5.tokenizer import Tokenizer
|
| 11 |
+
from osuT5.utils import get_model
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_args_from_beatmap(args: DictConfig):
|
| 15 |
+
if args.beatmap_path is None or args.beatmap_path == "":
|
| 16 |
+
return
|
| 17 |
+
|
| 18 |
+
beatmap_path = Path(args.beatmap_path)
|
| 19 |
+
|
| 20 |
+
if not beatmap_path.is_file():
|
| 21 |
+
raise FileNotFoundError(f"Beatmap file {beatmap_path} not found.")
|
| 22 |
+
|
| 23 |
+
beatmap = Beatmap.from_path(beatmap_path)
|
| 24 |
+
args.audio_path = beatmap_path.parent / beatmap.audio_filename
|
| 25 |
+
args.output_path = beatmap_path.parent
|
| 26 |
+
args.bpm = beatmap.bpm_max()
|
| 27 |
+
args.offset = min(tp.offset.total_seconds() * 1000 for tp in beatmap.timing_points)
|
| 28 |
+
args.slider_multiplier = beatmap.slider_multiplier
|
| 29 |
+
args.title = beatmap.title
|
| 30 |
+
args.artist = beatmap.artist
|
| 31 |
+
args.beatmap_id = beatmap.beatmap_id if args.beatmap_id == -1 else args.beatmap_id
|
| 32 |
+
args.diffusion.style_id = beatmap.beatmap_id if args.diffusion.style_id == -1 else args.diffusion.style_id
|
| 33 |
+
args.difficulty = float(beatmap.stars()) if args.difficulty == -1 else args.difficulty
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def find_model(ckpt_path, args: DictConfig, device):
|
| 37 |
+
assert Path(ckpt_path).exists(), f"Could not find DiT checkpoint at {ckpt_path}"
|
| 38 |
+
checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
|
| 39 |
+
if "ema" in checkpoint: # supports checkpoints from train.py
|
| 40 |
+
checkpoint = checkpoint["ema"]
|
| 41 |
+
|
| 42 |
+
model = DiT_models[args.diffusion.model](
|
| 43 |
+
num_classes=args.diffusion.num_classes,
|
| 44 |
+
context_size=19 - 3 + 128,
|
| 45 |
+
).to(device)
|
| 46 |
+
model.load_state_dict(checkpoint)
|
| 47 |
+
model.eval() # important!
|
| 48 |
+
return model
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@hydra.main(config_path="configs", config_name="inference", version_base="1.1")
|
| 52 |
+
def main(args: DictConfig):
|
| 53 |
+
get_args_from_beatmap(args)
|
| 54 |
+
|
| 55 |
+
torch.set_grad_enabled(False)
|
| 56 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 57 |
+
ckpt_path = Path(args.model_path)
|
| 58 |
+
model_state = torch.load(ckpt_path / "pytorch_model.bin", map_location=device)
|
| 59 |
+
tokenizer_state = torch.load(ckpt_path / "custom_checkpoint_0.pkl")
|
| 60 |
+
|
| 61 |
+
tokenizer = Tokenizer()
|
| 62 |
+
tokenizer.load_state_dict(tokenizer_state)
|
| 63 |
+
|
| 64 |
+
model = get_model(args, tokenizer)
|
| 65 |
+
model.load_state_dict(model_state)
|
| 66 |
+
model.eval()
|
| 67 |
+
model.to(device)
|
| 68 |
+
|
| 69 |
+
preprocessor = Preprocessor(args)
|
| 70 |
+
audio = preprocessor.load(args.audio_path)
|
| 71 |
+
sequences = preprocessor.segment(audio)
|
| 72 |
+
total_duration_ms = len(audio) / 16000 * 1000
|
| 73 |
+
args.total_duration_ms = total_duration_ms
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
generated_maps = []
|
| 80 |
+
generated_positions = []
|
| 81 |
+
diffs = []
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
if args.full_set:
|
| 85 |
+
for i in range(args.set_difficulties):
|
| 86 |
+
diffs.append(3 + i * (7 - 3) / (args.set_difficulties - 1))
|
| 87 |
+
|
| 88 |
+
print(diffs)
|
| 89 |
+
for diff in diffs:
|
| 90 |
+
print(f"Generating difficulty {diff}")
|
| 91 |
+
args.difficulty = diff
|
| 92 |
+
pipeline = Pipeline(args, tokenizer)
|
| 93 |
+
events = pipeline.generate(model, sequences)
|
| 94 |
+
generated_maps.append(events)
|
| 95 |
+
else:
|
| 96 |
+
pipeline = Pipeline(args, tokenizer)
|
| 97 |
+
events = pipeline.generate(model, sequences)
|
| 98 |
+
generated_maps.append(events)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if args.generate_positions:
|
| 103 |
+
model = find_model(args.diff_ckpt, args, device)
|
| 104 |
+
refine_model = find_model(args.diff_refine_ckpt, args, device) if len(args.diff_refine_ckpt) > 0 else None
|
| 105 |
+
diffusion_pipeline = DiffisionPipeline(args.diffusion)
|
| 106 |
+
for events in generated_maps:
|
| 107 |
+
events = diffusion_pipeline.generate(model, events, refine_model)
|
| 108 |
+
generated_positions.append(events)
|
| 109 |
+
else:
|
| 110 |
+
generated_positions = generated_maps
|
| 111 |
+
|
| 112 |
+
postprocessor = Postprocessor(args)
|
| 113 |
+
postprocessor.generate(generated_positions)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
if __name__ == "__main__":
|
| 117 |
+
main()
|
osuT5/__init__.py
ADDED
|
File without changes
|
osuT5/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (149 Bytes). View file
|
|
|
osuT5/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (131 Bytes). View file
|
|
|
osuT5/dataset/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .osu_parser import OsuParser
|
osuT5/dataset/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (277 Bytes). View file
|
|
|
osuT5/dataset/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (229 Bytes). View file
|
|
|
osuT5/dataset/__pycache__/data_utils.cpython-311.pyc
ADDED
|
Binary file (3.77 kB). View file
|
|
|
osuT5/dataset/__pycache__/data_utils.cpython-39.pyc
ADDED
|
Binary file (2.11 kB). View file
|
|
|
osuT5/dataset/__pycache__/ors_dataset.cpython-311.pyc
ADDED
|
Binary file (30 kB). View file
|
|
|
osuT5/dataset/__pycache__/ors_dataset.cpython-39.pyc
ADDED
|
Binary file (16.1 kB). View file
|
|
|
osuT5/dataset/__pycache__/osu_parser.cpython-311.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
osuT5/dataset/__pycache__/osu_parser.cpython-39.pyc
ADDED
|
Binary file (6.51 kB). View file
|
|
|
osuT5/dataset/data_utils.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from pydub import AudioSegment
|
| 6 |
+
|
| 7 |
+
import numpy.typing as npt
|
| 8 |
+
|
| 9 |
+
from osuT5.tokenizer import Event, EventType
|
| 10 |
+
|
| 11 |
+
MILISECONDS_PER_SECOND = 1000
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_audio_file(file: Path, sample_rate: int) -> npt.NDArray:
|
| 15 |
+
"""Load an audio file as a numpy time-series array
|
| 16 |
+
|
| 17 |
+
The signals are resampled, converted to mono channel, and normalized.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
file: Path to audio file.
|
| 21 |
+
sample_rate: Sample rate to resample the audio.
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
samples: Audio time series.
|
| 25 |
+
"""
|
| 26 |
+
print(file)
|
| 27 |
+
audio = AudioSegment.from_file(file, format="mp3")
|
| 28 |
+
audio = audio.set_frame_rate(sample_rate)
|
| 29 |
+
audio = audio.set_channels(1)
|
| 30 |
+
samples = np.array(audio.get_array_of_samples()).astype(np.float32)
|
| 31 |
+
samples *= 1.0 / np.max(np.abs(samples))
|
| 32 |
+
return samples
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def update_event_times(events: list[Event], event_times: list[float], end_time: Optional[float] = None):
|
| 36 |
+
non_timed_events = [
|
| 37 |
+
EventType.BEZIER_ANCHOR,
|
| 38 |
+
EventType.PERFECT_ANCHOR,
|
| 39 |
+
EventType.CATMULL_ANCHOR,
|
| 40 |
+
EventType.RED_ANCHOR,
|
| 41 |
+
]
|
| 42 |
+
timed_events = [
|
| 43 |
+
EventType.CIRCLE,
|
| 44 |
+
EventType.SPINNER,
|
| 45 |
+
EventType.SPINNER_END,
|
| 46 |
+
EventType.SLIDER_HEAD,
|
| 47 |
+
EventType.LAST_ANCHOR,
|
| 48 |
+
EventType.SLIDER_END,
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
start_index = len(event_times)
|
| 52 |
+
end_index = len(events)
|
| 53 |
+
ct = 0 if len(event_times) == 0 else event_times[-1]
|
| 54 |
+
for i in range(start_index, end_index):
|
| 55 |
+
event = events[i]
|
| 56 |
+
if event.type == EventType.TIME_SHIFT:
|
| 57 |
+
ct = event.value
|
| 58 |
+
event_times.append(ct)
|
| 59 |
+
|
| 60 |
+
# Interpolate time for control point events
|
| 61 |
+
# T-D-Start-D-CP-D-CP-T-D-LCP-T-D-End
|
| 62 |
+
# 1-1-1-----1-1--1-1--7-7--7--9-9-9--
|
| 63 |
+
# 1-1-1-----3-3--5-5--7-7--7--9-9-9--
|
| 64 |
+
ct = end_time if end_time is not None else event_times[-1]
|
| 65 |
+
interpolate = False
|
| 66 |
+
for i in range(end_index - 1, start_index - 1, -1):
|
| 67 |
+
event = events[i]
|
| 68 |
+
|
| 69 |
+
if event.type in timed_events:
|
| 70 |
+
interpolate = False
|
| 71 |
+
|
| 72 |
+
if event.type in non_timed_events:
|
| 73 |
+
interpolate = True
|
| 74 |
+
|
| 75 |
+
if not interpolate:
|
| 76 |
+
ct = event_times[i]
|
| 77 |
+
continue
|
| 78 |
+
|
| 79 |
+
if event.type not in non_timed_events:
|
| 80 |
+
event_times[i] = ct
|
| 81 |
+
continue
|
| 82 |
+
|
| 83 |
+
# Find the time of the first timed event and the number of control points between
|
| 84 |
+
j = i
|
| 85 |
+
count = 0
|
| 86 |
+
t = ct
|
| 87 |
+
while j >= 0:
|
| 88 |
+
event2 = events[j]
|
| 89 |
+
if event2.type == EventType.TIME_SHIFT:
|
| 90 |
+
t = event_times[j]
|
| 91 |
+
break
|
| 92 |
+
if event2.type in non_timed_events:
|
| 93 |
+
count += 1
|
| 94 |
+
j -= 1
|
| 95 |
+
if i < 0:
|
| 96 |
+
t = 0
|
| 97 |
+
|
| 98 |
+
# Interpolate the time
|
| 99 |
+
ct = (ct - t) / (count + 1) * count + t
|
| 100 |
+
event_times[i] = ct
|
osuT5/dataset/osu_parser.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from datetime import timedelta
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import numpy.typing as npt
|
| 7 |
+
from slider import Beatmap, Circle, Slider, Spinner
|
| 8 |
+
from slider.curve import Linear, Catmull, Perfect, MultiBezier
|
| 9 |
+
|
| 10 |
+
from osuT5.tokenizer import Event, EventType, Tokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class OsuParser:
|
| 14 |
+
def __init__(self, tokenizer: Tokenizer) -> None:
|
| 15 |
+
dist_range = tokenizer.event_range[EventType.DISTANCE]
|
| 16 |
+
self.dist_min = dist_range.min_value
|
| 17 |
+
self.dist_max = dist_range.max_value
|
| 18 |
+
|
| 19 |
+
def parse(self, beatmap: Beatmap) -> list[Event]:
|
| 20 |
+
# noinspection PyUnresolvedReferences
|
| 21 |
+
"""Parse an .osu beatmap.
|
| 22 |
+
|
| 23 |
+
Each hit object is parsed into a list of Event objects, in order of its
|
| 24 |
+
appearance in the beatmap. In other words, in ascending order of time.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
beatmap: Beatmap object parsed from an .osu file.
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
events: List of Event object lists.
|
| 31 |
+
|
| 32 |
+
Example::
|
| 33 |
+
>>> beatmap = [
|
| 34 |
+
"64,80,11000,1,0",
|
| 35 |
+
"100,100,16000,2,0,B|200:200|250:200|250:200|300:150,2"
|
| 36 |
+
]
|
| 37 |
+
>>> events = parse(beatmap)
|
| 38 |
+
>>> print(events)
|
| 39 |
+
[
|
| 40 |
+
Event(EventType.TIME_SHIFT, 11000), Event(EventType.DISTANCE, 36), Event(EventType.CIRCLE),
|
| 41 |
+
Event(EventType.TIME_SHIFT, 16000), Event(EventType.DISTANCE, 42), Event(EventType.SLIDER_HEAD),
|
| 42 |
+
Event(EventType.TIME_SHIFT, 16500), Event(EventType.DISTANCE, 141), Event(EventType.BEZIER_ANCHOR),
|
| 43 |
+
Event(EventType.TIME_SHIFT, 17000), Event(EventType.DISTANCE, 50), Event(EventType.BEZIER_ANCHOR),
|
| 44 |
+
Event(EventType.TIME_SHIFT, 17500), Event(EventType.DISTANCE, 10), Event(EventType.BEZIER_ANCHOR),
|
| 45 |
+
Event(EventType.TIME_SHIFT, 18000), Event(EventType.DISTANCE, 64), Event(EventType.LAST _ANCHOR),
|
| 46 |
+
Event(EventType.TIME_SHIFT, 20000), Event(EventType.DISTANCE, 11), Event(EventType.SLIDER_END)
|
| 47 |
+
]
|
| 48 |
+
"""
|
| 49 |
+
hit_objects = beatmap.hit_objects(stacking=False)
|
| 50 |
+
last_pos = np.array((256, 192))
|
| 51 |
+
events = []
|
| 52 |
+
|
| 53 |
+
for hit_object in hit_objects:
|
| 54 |
+
if isinstance(hit_object, Circle):
|
| 55 |
+
last_pos = self._parse_circle(hit_object, events, last_pos)
|
| 56 |
+
elif isinstance(hit_object, Slider):
|
| 57 |
+
last_pos = self._parse_slider(hit_object, events, last_pos)
|
| 58 |
+
elif isinstance(hit_object, Spinner):
|
| 59 |
+
last_pos = self._parse_spinner(hit_object, events)
|
| 60 |
+
|
| 61 |
+
return events
|
| 62 |
+
|
| 63 |
+
def _clip_dist(self, dist: int) -> int:
|
| 64 |
+
"""Clip distance to valid range."""
|
| 65 |
+
return int(np.clip(dist, self.dist_min, self.dist_max))
|
| 66 |
+
|
| 67 |
+
def _parse_circle(self, circle: Circle, events: list[Event], last_pos: npt.NDArray) -> npt.NDArray:
|
| 68 |
+
"""Parse a circle hit object.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
circle: Circle object.
|
| 72 |
+
events: List of events to add to.
|
| 73 |
+
last_pos: Last position of the hit objects.
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
pos: Position of the circle.
|
| 77 |
+
"""
|
| 78 |
+
time = int(circle.time.total_seconds() * 1000)
|
| 79 |
+
pos = np.array(circle.position)
|
| 80 |
+
dist = self._clip_dist(np.linalg.norm(pos - last_pos))
|
| 81 |
+
|
| 82 |
+
events.append(Event(EventType.TIME_SHIFT, time))
|
| 83 |
+
events.append(Event(EventType.DISTANCE, dist))
|
| 84 |
+
if circle.new_combo:
|
| 85 |
+
events.append(Event(EventType.NEW_COMBO))
|
| 86 |
+
events.append(Event(EventType.CIRCLE))
|
| 87 |
+
|
| 88 |
+
return pos
|
| 89 |
+
|
| 90 |
+
def _parse_slider(self, slider: Slider, events: list[Event], last_pos: npt.NDArray) -> npt.NDArray:
|
| 91 |
+
"""Parse a slider hit object.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
slider: Slider object.
|
| 95 |
+
events: List of events to add to.
|
| 96 |
+
last_pos: Last position of the hit objects.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
pos: Last position of the slider.
|
| 100 |
+
"""
|
| 101 |
+
# Ignore sliders which are too big
|
| 102 |
+
if len(slider.curve.points) >= 100:
|
| 103 |
+
return last_pos
|
| 104 |
+
|
| 105 |
+
time = int(slider.time.total_seconds() * 1000)
|
| 106 |
+
pos = np.array(slider.position)
|
| 107 |
+
dist = self._clip_dist(np.linalg.norm(pos - last_pos))
|
| 108 |
+
last_pos = pos
|
| 109 |
+
|
| 110 |
+
events.append(Event(EventType.TIME_SHIFT, time))
|
| 111 |
+
events.append(Event(EventType.DISTANCE, dist))
|
| 112 |
+
if slider.new_combo:
|
| 113 |
+
events.append(Event(EventType.NEW_COMBO))
|
| 114 |
+
events.append(Event(EventType.SLIDER_HEAD))
|
| 115 |
+
|
| 116 |
+
duration: timedelta = (slider.end_time - slider.time) / slider.repeat
|
| 117 |
+
control_point_count = len(slider.curve.points)
|
| 118 |
+
|
| 119 |
+
def append_control_points(event_type: EventType, last_pos: npt.NDArray = last_pos) -> npt.NDArray:
|
| 120 |
+
for i in range(1, control_point_count - 1):
|
| 121 |
+
last_pos = add_anchor_time_dist(i, last_pos)
|
| 122 |
+
events.append(Event(event_type))
|
| 123 |
+
|
| 124 |
+
return last_pos
|
| 125 |
+
|
| 126 |
+
def add_anchor_time_dist(i: int, last_pos: npt.NDArray) -> npt.NDArray:
|
| 127 |
+
time = int((slider.time + i / (control_point_count - 1) * duration).total_seconds() * 1000)
|
| 128 |
+
pos = np.array(slider.curve.points[i])
|
| 129 |
+
dist = self._clip_dist(np.linalg.norm(pos - last_pos))
|
| 130 |
+
last_pos = pos
|
| 131 |
+
|
| 132 |
+
events.append(Event(EventType.TIME_SHIFT, time))
|
| 133 |
+
events.append(Event(EventType.DISTANCE, dist))
|
| 134 |
+
|
| 135 |
+
return last_pos
|
| 136 |
+
|
| 137 |
+
if isinstance(slider.curve, Linear):
|
| 138 |
+
last_pos = append_control_points(EventType.RED_ANCHOR, last_pos)
|
| 139 |
+
elif isinstance(slider.curve, Catmull):
|
| 140 |
+
last_pos = append_control_points(EventType.CATMULL_ANCHOR, last_pos)
|
| 141 |
+
elif isinstance(slider.curve, Perfect):
|
| 142 |
+
last_pos = append_control_points(EventType.PERFECT_ANCHOR, last_pos)
|
| 143 |
+
elif isinstance(slider.curve, MultiBezier):
|
| 144 |
+
for i in range(1, control_point_count - 1):
|
| 145 |
+
if slider.curve.points[i] == slider.curve.points[i + 1]:
|
| 146 |
+
last_pos = add_anchor_time_dist(i, last_pos)
|
| 147 |
+
events.append(Event(EventType.RED_ANCHOR))
|
| 148 |
+
elif slider.curve.points[i] != slider.curve.points[i - 1]:
|
| 149 |
+
last_pos = add_anchor_time_dist(i, last_pos)
|
| 150 |
+
events.append(Event(EventType.BEZIER_ANCHOR))
|
| 151 |
+
|
| 152 |
+
last_pos = add_anchor_time_dist(control_point_count - 1, last_pos)
|
| 153 |
+
events.append(Event(EventType.LAST_ANCHOR))
|
| 154 |
+
|
| 155 |
+
time = int(slider.end_time.total_seconds() * 1000)
|
| 156 |
+
pos = np.array(slider.curve(1))
|
| 157 |
+
dist = self._clip_dist(np.linalg.norm(pos - last_pos))
|
| 158 |
+
last_pos = pos
|
| 159 |
+
|
| 160 |
+
events.append(Event(EventType.TIME_SHIFT, time))
|
| 161 |
+
events.append(Event(EventType.DISTANCE, dist))
|
| 162 |
+
events.append(Event(EventType.SLIDER_END))
|
| 163 |
+
|
| 164 |
+
return last_pos
|
| 165 |
+
|
| 166 |
+
def _parse_spinner(self, spinner: Spinner, events: list[Event]) -> npt.NDArray:
    """Append the events describing a spinner hit object.

    Args:
        spinner: Spinner object.
        events: List of events to add to.

    Returns:
        pos: Last position of the spinner (the playfield centre).
    """
    # A spinner is encoded as two (time shift, marker) pairs: start and end.
    for timestamp, marker in (
        (spinner.time, EventType.SPINNER),
        (spinner.end_time, EventType.SPINNER_END),
    ):
        millis = int(timestamp.total_seconds() * 1000)
        events.append(Event(EventType.TIME_SHIFT, millis))
        events.append(Event(marker))

    # Spinners always sit at the centre of the 512x384 playfield.
    return np.array((256, 192))
|
osuT5/inference/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .pipeline import *
|
| 2 |
+
from .preprocessor import *
|
| 3 |
+
from .postprocessor import *
|
| 4 |
+
from .diffusion_pipeline import *
|
osuT5/inference/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (301 Bytes). View file
|
|
|
osuT5/inference/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (245 Bytes). View file
|
|
|
osuT5/inference/__pycache__/diffusion_pipeline.cpython-311.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
osuT5/inference/__pycache__/path_approximator.cpython-311.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
osuT5/inference/__pycache__/path_approximator.cpython-39.pyc
ADDED
|
Binary file (5.02 kB). View file
|
|
|
osuT5/inference/__pycache__/pipeline.cpython-311.pyc
ADDED
|
Binary file (23.4 kB). View file
|
|
|
osuT5/inference/__pycache__/pipeline.cpython-39.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
osuT5/inference/__pycache__/postprocessor.cpython-311.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
osuT5/inference/__pycache__/postprocessor.cpython-39.pyc
ADDED
|
Binary file (8.08 kB). View file
|
|
|
osuT5/inference/__pycache__/preprocessor.cpython-311.pyc
ADDED
|
Binary file (3.37 kB). View file
|
|
|
osuT5/inference/__pycache__/preprocessor.cpython-39.pyc
ADDED
|
Binary file (2.23 kB). View file
|
|
|
osuT5/inference/__pycache__/slider_path.cpython-311.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
osuT5/inference/__pycache__/slider_path.cpython-39.pyc
ADDED
|
Binary file (5.16 kB). View file
|
|
|
osuT5/inference/diffusion_pipeline.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from omegaconf import DictConfig
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
from osudiffusion import timestep_embedding
|
| 9 |
+
from osudiffusion import repeat_type
|
| 10 |
+
from osudiffusion import create_diffusion
|
| 11 |
+
from osudiffusion import DiT
|
| 12 |
+
from osuT5.dataset.data_utils import update_event_times
|
| 13 |
+
from osuT5.tokenizer import Event, EventType
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_beatmap_idx(path) -> dict[int, int]:
|
| 17 |
+
p = Path(path)
|
| 18 |
+
with p.open("rb") as f:
|
| 19 |
+
beatmap_idx = pickle.load(f)
|
| 20 |
+
return beatmap_idx
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DiffisionPipeline(object):
    """Diffusion-based inference stage that turns DISTANCE events into POS_X/POS_Y events.

    NOTE(review): the class name keeps the original (misspelled) "Diffision"
    spelling because external callers import it under this name.
    """

    def __init__(self, args: DictConfig):
        """Model inference stage that generates positions for distance events."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_sampling_steps = args.num_sampling_steps
        # Classifier-free guidance scale forwarded to the model.
        self.cfg_scale = args.cfg_scale
        # Attention window half-width used for the banded attention mask.
        self.seq_len = args.seq_len
        self.num_classes = args.num_classes
        # Mapping of beatmap id -> class index, loaded from a pickle file.
        self.beatmap_idx = get_beatmap_idx(args.beatmap_idx)
        self.style_id = args.style_id
        self.refine_iters = args.refine_iters
        self.use_amp = args.use_amp

        if self.style_id in self.beatmap_idx:
            self.class_label = self.beatmap_idx[self.style_id]
        else:
            # Unknown style falls back to the "null" class used for CFG.
            print(f"Beatmap ID {self.style_id} not found in dataset, using default style.")
            self.class_label = self.num_classes

    def generate(self, model: DiT, events: list[Event], refine_model: DiT = None) -> list[Event]:
        """Generate position events for distance events in the Event list.

        Args:
            model: Trained model to use for inference.
            events: List of Event objects with distance events.
            refine_model: Optional model to refine the generated positions.

        Returns:
            events: List of Event objects with position events.
        """

        seq_o, seq_c, seq_len, seq_indices = self.events_to_sequence(events)

        seq_o = seq_o - seq_o[0]  # Normalize to relative time
        print(f"seq len {seq_len}")

        diffusion = create_diffusion(
            str(self.num_sampling_steps),
            noise_schedule="squaredcos_cap_v2",
        )

        # Create banded matrix attention mask for increased sequence length
        attn_mask = torch.full((seq_len, seq_len), True, dtype=torch.bool, device=self.device)
        for i in range(seq_len):
            # False = attend; each position sees +/- self.seq_len neighbours.
            attn_mask[max(0, i - self.seq_len): min(seq_len, i + self.seq_len), i] = False

        class_labels = [self.class_label]

        # Create sampling noise:
        n = len(class_labels)
        z = torch.randn(n, 2, seq_len, device=self.device)
        o = seq_o.repeat(n, 1).to(self.device)
        c = seq_c.repeat(n, 1, 1).to(self.device)
        y = torch.tensor(class_labels, device=self.device)

        # Setup classifier-free guidance: duplicate the batch, second half
        # conditioned on the null class (self.num_classes).
        z = torch.cat([z, z], 0)
        o = torch.cat([o, o], 0)
        c = torch.cat([c, c], 0)
        y_null = torch.tensor([self.num_classes] * n, device=self.device)
        y = torch.cat([y, y_null], 0)
        model_kwargs = dict(o=o, c=c, y=y, cfg_scale=self.cfg_scale, attn_mask=attn_mask)

        def to_positions(samples):
            # Samples are in normalized [0, 1]-ish space; scale back to the
            # 512x384 osu! playfield.
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
            samples *= torch.tensor((512, 384), device=self.device).repeat(n, 1).unsqueeze(2)
            return samples.cpu()

        # Sample images:
        samples = diffusion.p_sample_loop(
            model.forward_with_cfg,
            z.shape,
            z,
            clip_denoised=True,
            model_kwargs=model_kwargs,
            progress=True,
            device=self.device,
        )

        if refine_model is not None:
            # Refine result with refine model
            # NOTE(review): this loop calls `model.forward_with_cfg`, not
            # `refine_model` — confirm whether refine_model was intended here.
            for _ in tqdm(range(self.refine_iters)):
                # t=0: repeatedly denoise at the final timestep.
                t = torch.tensor([0] * samples.shape[0], device=self.device)
                with torch.no_grad():
                    out = diffusion.p_sample(
                        model.forward_with_cfg,
                        samples,
                        t,
                        clip_denoised=True,
                        model_kwargs=model_kwargs,
                    )
                samples = out["sample"]

        positions = to_positions(samples)
        return self.events_with_pos(events, positions.squeeze(0), seq_indices)

    @staticmethod
    def events_to_sequence(events: list[Event]) -> tuple[torch.Tensor, torch.Tensor, int, dict[int, int]]:
        """Vectorize an Event list into the (o, c) tensors expected by osu-diffusion.

        Returns:
            seq_o: 1-D tensor of event times.
            seq_c: per-event feature tensor (distance embedding + one-hot type).
            seq_len: number of vectorized events.
            seq_indices: maps each original event index to its column in the sequence.
        """
        # Calculate the time of every event and interpolate time for control point events
        event_times = []
        update_event_times(events, event_times)

        # Calculate the number of repeats for each slider end event
        # Convert to vectorized form for osu-diffusion
        # Types that get a +1 index offset when preceded by NEW_COMBO.
        nc_types = [EventType.CIRCLE, EventType.SLIDER_HEAD]
        event_index = {
            EventType.CIRCLE: 0,
            EventType.SPINNER: 2,
            EventType.SPINNER_END: 3,
            EventType.SLIDER_HEAD: 4,
            EventType.BEZIER_ANCHOR: 6,
            EventType.PERFECT_ANCHOR: 7,
            EventType.CATMULL_ANCHOR: 8,
            EventType.RED_ANCHOR: 9,
            EventType.LAST_ANCHOR: 10,
            EventType.SLIDER_END: 11,
        }

        seq_indices = {}
        indices = []
        data_chunks = []
        distance = 0
        new_combo = False
        head_time = 0
        last_anchor_time = 0
        for i, event in enumerate(events):
            indices.append(i)
            if event.type == EventType.DISTANCE:
                distance = event.value
            elif event.type == EventType.NEW_COMBO:
                new_combo = True
            elif event.type in event_index:
                time = event_times[i]
                index = event_index[event.type]

                # Handle NC index offset
                if event.type in nc_types and new_combo:
                    index += 1
                    new_combo = False

                # Add slider end repeats index offset
                if event.type == EventType.SLIDER_END:
                    span_duration = last_anchor_time - head_time
                    total_duration = time - head_time
                    # Guard against zero/negative span to avoid division by zero.
                    repeats = max(int(round(total_duration / span_duration)), 1) if span_duration > 0 else 1
                    index += repeat_type(repeats)
                elif event.type == EventType.SLIDER_HEAD:
                    head_time = time
                elif event.type == EventType.LAST_ANCHOR:
                    last_anchor_time = time

                # Feature layout: [0]=time, [1]=distance, [2:]=one-hot type slot.
                features = torch.zeros(18)
                features[0] = time
                features[1] = distance
                features[index + 2] = 1
                data_chunks.append(features)

                # All accumulated raw-event indices map to this sequence column.
                for j in indices:
                    seq_indices[j] = len(data_chunks) - 1
                indices = []

        seq = torch.stack(data_chunks, 0)
        seq = torch.swapaxes(seq, 0, 1)
        seq_o = seq[0, :]
        seq_d = seq[1, :]
        # Expand scalar distance into a 128-dim sinusoidal embedding.
        seq_c = torch.concatenate(
            [
                timestep_embedding(seq_d, 128).T,
                seq[2:, :],
            ],
            0,
        )

        return seq_o, seq_c, seq.shape[1], seq_indices

    @staticmethod
    def events_with_pos(events: list[Event], sampled_seq: torch.Tensor, seq_indices: dict[int, int]) -> list[Event]:
        """Replace every DISTANCE event with sampled POS_X/POS_Y events.

        All other events are passed through unchanged.
        """
        new_events = []
        for i, event in enumerate(events):
            if event.type == EventType.DISTANCE:
                try:
                    index = seq_indices[i]
                    pos_x = sampled_seq[0, index].item()
                    pos_y = sampled_seq[1, index].item()
                    new_events.append(Event(EventType.POS_X, int(round(pos_x))))
                    new_events.append(Event(EventType.POS_Y, int(round(pos_y))))
                except KeyError:
                    # A DISTANCE event that never produced a sequence column is dropped.
                    print(f"Warning: Key {i} not found in seq_indices. Skipping event.")
            else:
                new_events.append(event)
        return new_events
|
osuT5/inference/path_approximator.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
# Flattening tolerances for the path approximators below.
BEZIER_TOLERANCE = 0.25
CATMULL_DETAIL = 50
CIRCULAR_ARC_TOLERANCE = 0.1


def length_squared(x) -> float:
    """Return the squared Euclidean length of vector *x*.

    Named function instead of a lambda assignment (PEP 8 E731); avoids the
    sqrt of a full norm when only comparisons are needed.
    """
    return np.inner(x, x)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def approximate_bezier(control_points: np.ndarray) -> np.ndarray:
    """Flatten a Bezier curve into a piecewise-linear point array.

    Delegates to ``approximate_b_spline`` with the default p=0, which treats
    the whole control-point list as a single Bezier segment.
    """
    return approximate_b_spline(control_points)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def approximate_b_spline(control_points: np.ndarray, p: int = 0) -> np.ndarray:
    """Flatten a clamped B-spline (or single Bezier for p=0) into line points.

    Appears to be a port of osu!framework's PathApproximator — TODO confirm.

    Args:
        control_points: (N, 2) array of control points.
        p: B-spline degree; values outside (0, N-1) treat the whole input as
           one Bezier of degree N-1.

    Returns:
        (M, 2) array of points on the flattened path.
        NOTE(review): when the input is empty this returns a plain empty
        list, not an ndarray — callers should only rely on len().
    """
    output = []
    n = len(control_points) - 1

    if n < 0:
        return output

    # Stack of Bezier segments still to be flattened; buffers are recycled
    # via free_buffers to limit allocations.
    to_flatten = []
    free_buffers = []

    points = control_points.copy()

    if 0 < p < n:
        # B-spline case: use knot insertion (Boehm's algorithm) to extract
        # per-segment Bezier control points.
        for i in range(n - p):
            sub_bezier = np.empty((p + 1, 2))
            sub_bezier[0] = points[i]

            for j in range(p - 1):
                sub_bezier[j + 1] = points[i + 1]

                for k in range(1, p - j):
                    l = np.min((k, n - p - i))
                    points[i + k] = (l * points[i + k] + points[i + k + 1]) / (l + 1)

            sub_bezier[p] = points[i + 1]
            to_flatten.append(sub_bezier)

        to_flatten.append(points[(n - p) :])
        # Reverse so segments are popped in path order.
        to_flatten.reverse()
    else:
        # Degenerate degree: treat the input as one Bezier of degree n.
        p = n
        to_flatten.append(points)

    # Scratch buffers shared by subdivision and approximation.
    subdivision_buffer1 = np.empty([p + 1, 2])
    subdivision_buffer2 = np.empty([p * 2 + 1, 2])

    left_child = subdivision_buffer2

    # Adaptive subdivision: split segments in half until flat enough.
    while len(to_flatten) > 0:
        parent = to_flatten.pop()

        if bezier_is_flat_enough(parent):
            bezier_approximate(
                parent,
                output,
                subdivision_buffer1,
                subdivision_buffer2,
                p + 1,
            )

            free_buffers.append(parent)
            continue

        right_child = (
            free_buffers.pop() if len(free_buffers) > 0 else np.empty([p + 1, 2])
        )
        bezier_subdivide(parent, left_child, right_child, subdivision_buffer1, p + 1)

        # Reuse the parent buffer to hold the left half.
        for i in range(p + 1):
            parent[i] = left_child[i]

        to_flatten.append(right_child)
        to_flatten.append(parent)

    # Always include the exact final control point.
    output.append(control_points[n].copy())
    return np.vstack(output)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def approximate_catmull(control_points: np.ndarray) -> list[np.ndarray]:
    """Sample a Catmull-Rom spline as a list of points.

    Each segment between consecutive control points is sampled at
    CATMULL_DETAIL subdivisions; interior sample points are emitted twice
    (once as a step end, once as the next step start).
    """
    points = []
    count = len(control_points)

    for seg in range(count - 1):
        # Neighbouring control points, with phantom points mirrored past the
        # ends of the path.
        p1 = control_points[seg] if seg == 0 else control_points[seg - 1]
        p2 = control_points[seg]
        p3 = control_points[seg + 1] if seg < count - 1 else p2 + p2 - p1
        p4 = control_points[seg + 2] if seg < count - 2 else p3 + p3 - p2

        for step in range(CATMULL_DETAIL):
            points.append(catmull_find_point(p1, p2, p3, p4, step / CATMULL_DETAIL))
            points.append(catmull_find_point(p1, p2, p3, p4, (step + 1) / CATMULL_DETAIL))

    return points
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def approximate_circular_arc(control_points: np.ndarray) -> list[np.ndarray]:
    """Approximate the circular arc through three points as a point list.

    Args:
        control_points: Three 2D points: start, pass-through, end.

    Returns:
        List of points on the arc from start to end, or an empty list when
        the points are coincident/collinear and no unique circle exists.
    """
    a = control_points[0]
    b = control_points[1]
    c = control_points[2]

    # Squared side lengths of triangle abc (each opposite its named vertex).
    aSq = length_squared(b - c)
    bSq = length_squared(a - c)
    cSq = length_squared(a - b)

    if np.isclose(aSq, 0) or np.isclose(bSq, 0) or np.isclose(cSq, 0):
        return []

    # Barycentric weights of the circumcentre.
    s = aSq * (bSq + cSq - aSq)
    t = bSq * (aSq + cSq - bSq)
    u = cSq * (aSq + bSq - cSq)

    weight_sum = s + t + u  # renamed from `sum` to stop shadowing the builtin

    if np.isclose(weight_sum, 0):
        return []

    centre = (s * a + t * b + u * c) / weight_sum
    dA = a - centre
    dC = c - centre

    r = np.linalg.norm(dA)

    theta_start = np.arctan2(dA[1], dA[0])
    theta_end = np.arctan2(dC[1], dC[0])

    while theta_end < theta_start:
        theta_end += 2 * np.pi

    direction = 1
    # Fixed: original had a duplicated assignment `theta_range = theta_range = ...`.
    theta_range = theta_end - theta_start

    # Orientation check: if b lies on the clockwise side of chord a->c,
    # sweep the arc the other way around.
    ortho_ato_c = c - a
    ortho_ato_c = np.array([ortho_ato_c[1], -ortho_ato_c[0]])
    if np.dot(ortho_ato_c, b - a) < 0:
        direction = -direction
        theta_range = 2 * np.pi - theta_range

    # Sample count needed so sagitta error stays within CIRCULAR_ARC_TOLERANCE;
    # tiny circles are approximated by their endpoints (also keeps arccos in domain).
    amount_points = (
        2
        if 2 * r <= CIRCULAR_ARC_TOLERANCE
        else int(
            max(
                2,
                np.ceil(theta_range / (2 * np.arccos(1 - CIRCULAR_ARC_TOLERANCE / r))),
            ),
        )
    )

    output = []

    for i in range(amount_points):
        fract = i / (amount_points - 1)
        theta = theta_start + direction * fract * theta_range
        o = np.array([np.cos(theta), np.sin(theta)]) * r
        output.append(centre + o)

    return output
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def approximate_linear(control_points: np.ndarray) -> list[np.ndarray]:
    """Return the control points themselves as a path of independent copies."""
    return [point.copy() for point in control_points]
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def bezier_is_flat_enough(control_points: np.ndarray) -> bool:
    """True when every interior second difference is within the flatness bound."""
    threshold = BEZIER_TOLERANCE * BEZIER_TOLERANCE * 4
    return all(
        length_squared(
            control_points[idx - 1] - 2 * control_points[idx] + control_points[idx + 1]
        ) <= threshold
        for idx in range(1, len(control_points) - 1)
    )
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def bezier_subdivide(
    control_points: np.ndarray,
    left: np.ndarray,
    right: np.ndarray,
    subdivision_buffer: np.ndarray,
    count: int,
) -> None:
    """Split a Bezier curve at t=0.5 using de Casteljau's algorithm.

    The two halves are written into *left* and *right* (first *count* rows);
    *subdivision_buffer* is scratch space of at least *count* rows.
    """
    scratch = subdivision_buffer
    scratch[:count] = control_points[:count]

    for level in range(count):
        # Outermost scratch points become the next control point of each half.
        left[level] = scratch[0].copy()
        right[count - level - 1] = scratch[count - level - 1]

        # Collapse the polygon by averaging adjacent points.
        for j in range(count - level - 1):
            scratch[j] = (scratch[j] + scratch[j + 1]) / 2
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def bezier_approximate(
    control_points: np.ndarray,
    output: list[np.ndarray],
    subdivision_buffer1: np.ndarray,
    subdivision_buffer2: np.ndarray,
    count: int,
) -> None:
    """Emit a flat-enough Bezier segment into *output* as line points."""
    halves = subdivision_buffer2   # will hold both halves laid end-to-end
    upper = subdivision_buffer1    # right half; also reused as subdivide scratch

    bezier_subdivide(control_points, halves, upper, subdivision_buffer1, count)

    # Pack the right half directly after the left one in `halves`.
    for offset in range(count - 1):
        halves[count + offset] = upper[offset + 1]

    output.append(control_points[0].copy())

    # Interior points: weighted average smooths the join between halves.
    for i in range(1, count - 1):
        index = 2 * i
        midpoint = 0.25 * (halves[index - 1] + 2 * halves[index] + halves[index + 1])
        output.append(midpoint.copy())
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def catmull_find_point(
    vec1: np.ndarray,
    vec2: np.ndarray,
    vec3: np.ndarray,
    vec4: np.ndarray,
    t: float,
) -> np.ndarray:
    """Evaluate a Catmull-Rom spline segment at parameter *t* in [0, 1].

    vec2 and vec3 are the segment endpoints; vec1 and vec4 are their
    neighbours, which shape the tangents at the endpoints.
    """
    t2 = t * t
    t3 = t * t2

    def component(axis: int) -> float:
        # Standard Catmull-Rom basis expanded in powers of t.
        return 0.5 * (
            2 * vec2[axis]
            + (-vec1[axis] + vec3[axis]) * t
            + (2 * vec1[axis] - 5 * vec2[axis] + 4 * vec3[axis] - vec4[axis]) * t2
            + (-vec1[axis] + 3 * vec2[axis] - 3 * vec3[axis] + vec4[axis]) * t3
        )

    return np.array([component(0), component(1)])
|
osuT5/inference/pipeline.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from slider import Beatmap
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from omegaconf import DictConfig
|
| 11 |
+
|
| 12 |
+
from osuT5.dataset import OsuParser
|
| 13 |
+
from osuT5.dataset.data_utils import update_event_times
|
| 14 |
+
from osuT5.tokenizer import Event, EventType, Tokenizer
|
| 15 |
+
from osuT5.model import OsuT
|
| 16 |
+
|
| 17 |
+
MILISECONDS_PER_SECOND = 1000
|
| 18 |
+
MILISECONDS_PER_STEP = 10
|
| 19 |
+
|
| 20 |
+
def top_k_sampling(logits, k):
    """Sample one token id per row, drawn from the k highest-scoring logits."""
    values, indices = torch.topk(logits, k)
    # Renormalise over the kept candidates only, then draw one of them.
    choice = torch.multinomial(F.softmax(values, dim=-1), 1)
    return indices.gather(-1, choice)
|
| 26 |
+
|
| 27 |
+
def preprocess_event(event, frame_time):
    """Rebase a TIME_SHIFT event onto the current frame, in step units.

    Events of any other type are returned unchanged.
    """
    if event.type != EventType.TIME_SHIFT:
        return event
    steps = int((event.value - frame_time) / MILISECONDS_PER_STEP)
    return Event(type=event.type, value=steps)
|
| 31 |
+
|
| 32 |
+
class Pipeline(object):
|
| 33 |
+
def __init__(self, args: DictConfig, tokenizer: Tokenizer):
    """Model inference stage that processes sequences."""
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.tokenizer = tokenizer
    # Decoder (target) sequence length budget.
    self.tgt_seq_len = args.data.tgt_seq_len
    # Number of spectrogram frames per model input window.
    self.frame_seq_len = args.data.src_seq_len - 1
    self.frame_size = args.model.spectrogram.hop_length
    self.sample_rate = args.model.spectrogram.sample_rate
    # Audio samples covered by one input window.
    self.samples_per_sequence = self.frame_seq_len * self.frame_size
    # Hop between consecutive windows, as a fraction of the window length.
    self.sequence_stride = int(self.samples_per_sequence * args.data.sequence_stride)
    self.miliseconds_per_sequence = self.samples_per_sequence * MILISECONDS_PER_SECOND / self.sample_rate
    self.miliseconds_per_stride = self.sequence_stride * MILISECONDS_PER_SECOND / self.sample_rate
    self.beatmap_id = args.beatmap_id
    self.difficulty = args.difficulty
    self.center_pad_decoder = args.data.center_pad_decoder
    # Layout of the special (style/difficulty) token prefix.
    self.special_token_len = args.data.special_token_len
    self.diff_token_index = args.data.diff_token_index
    self.style_token_index = args.data.style_token_index
    self.max_pre_token_len = args.data.max_pre_token_len
    self.add_pre_tokens = args.data.add_pre_tokens
    # When True, a reference ("guest difficulty") beatmap provides extra context.
    self.add_gd_context = args.data.add_gd_context
    self.bpm = args.bpm
    self.offset = args.offset
    self.total_duration_ms = args.total_duration_ms

    print(f"Configuration: {args}")

    if self.add_gd_context:
        other_beatmap_path = Path(args.other_beatmap_path)

        if not other_beatmap_path.is_file():
            raise FileNotFoundError(f"Beatmap file {other_beatmap_path} not found.")

        # Parse the reference beatmap once and cache its events for reuse
        # during generation.
        other_beatmap = Beatmap.from_path(other_beatmap_path)
        self.other_beatmap_id = other_beatmap.beatmap_id
        self.other_difficulty = float(other_beatmap.stars())
        parser = OsuParser(tokenizer)
        self.other_events = parser.parse(other_beatmap)
        self.other_events, self.other_event_times = self._prepare_events(self.other_events)
|
| 72 |
+
|
| 73 |
+
def _calculate_time_shifts(self, bpm: float, duration_ms: float, tick_rate: int, offset: float = 0) -> list[float]:
|
| 74 |
+
"""Calculate EventType.TIME_SHIFT events based on song's BPM and tick rate."""
|
| 75 |
+
events = []
|
| 76 |
+
ms_per_beat = 60000 / bpm # 60000 ms per minute
|
| 77 |
+
ms_per_tick = ms_per_beat / tick_rate
|
| 78 |
+
num_ticks = int(duration_ms // ms_per_tick)
|
| 79 |
+
|
| 80 |
+
for i in range(num_ticks):
|
| 81 |
+
events.append(float(int(i * ms_per_tick + offset)) )
|
| 82 |
+
|
| 83 |
+
return events
|
| 84 |
+
|
| 85 |
+
def generate_events(self, model, frames, tokens, encoder_outputs, beatmap_idx, total_steps):
    """Autoregressively extend *tokens* for up to *total_steps* decode steps.

    Stops early when every sequence in the batch has emitted EOS. Returns the
    extended token tensor; *encoder_outputs* is reused across steps so the
    encoder only runs once.
    """
    temperature = 0.9
    k = 10  # top-k sampling

    for _ in range(total_steps):
        out = model.forward(
            frames=frames,
            decoder_input_ids=tokens,
            decoder_attention_mask=tokens.ne(self.tokenizer.pad_id),
            encoder_outputs=encoder_outputs,
            beatmap_idx=beatmap_idx,
        )
        # Cache encoder state so subsequent steps skip the encoder pass.
        encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions)
        logits = out.logits
        # Only the logits of the final position matter for the next token.
        logits = logits[:, -1, :] / temperature
        logits = self._filter(logits, 0.9)
        probabilities = F.softmax(logits, dim=-1)
        # NOTE(review): softmax is applied here AND inside top_k_sampling,
        # so the distribution is softmaxed twice — confirm this is intended.
        next_tokens = top_k_sampling(probabilities, k)

        tokens = torch.cat([tokens, next_tokens], dim=-1)

        eos_in_sentence = next_tokens == self.tokenizer.eos_id
        if eos_in_sentence.all():
            break

    return tokens
|
| 111 |
+
|
| 112 |
+
def generate(self, model: OsuT, sequences: torch.Tensor, top_k: int = 50) -> list[Event]:
    """
    Generate a list of Event object lists and their timestamps given source sequences.

    Args:
        model: Trained model to use for inference.
        sequences: A list of batched source sequences.
        top_k: Number of top tokens to use for top-k sampling.
            NOTE(review): this parameter is currently unused; the inner loop
            hard-codes top_k=60 — confirm intended.

    Returns:
        events: List of Event object lists.
        event_times: Corresponding event times of Event object lists in milliseconds.
    """
    events = []
    event_times = []
    temperature = 0.95

    # Resolve style/difficulty conditioning tokens; unknown ids fall back to
    # the tokenizer's "unknown" tokens (6666 is a placeholder class index).
    idx_dict = self.tokenizer.beatmap_idx
    beatmap_idx = torch.tensor([idx_dict.get(self.beatmap_id, 6666)], dtype=torch.long, device=self.device)
    style_token = self.tokenizer.encode_style(self.beatmap_id) if self.beatmap_id in idx_dict else self.tokenizer.style_unk
    diff_token = self.tokenizer.encode_diff(self.difficulty) if self.difficulty != -1 else self.tokenizer.diff_unk

    special_tokens = torch.empty((1, self.special_token_len), dtype=torch.long, device=self.device)
    special_tokens[:, self.diff_token_index] = diff_token
    special_tokens[:, self.style_token_index] = style_token

    if self.add_gd_context:
        # Conditioning tokens for the reference ("guest difficulty") beatmap.
        other_style_token = self.tokenizer.encode_style(self.other_beatmap_id) if self.other_beatmap_id in idx_dict else self.tokenizer.style_unk
        other_special_tokens = torch.empty((1, self.special_token_len), dtype=torch.long, device=self.device)
        other_special_tokens[:, self.diff_token_index] = self.tokenizer.encode_diff(self.other_difficulty)
        other_special_tokens[:, self.style_token_index] = other_style_token
    else:
        other_special_tokens = torch.empty((1, 0), dtype=torch.long, device=self.device)

    for sequence_index, frames in enumerate(tqdm(sequences)):
        # Get tokens of previous frame
        frame_time = sequence_index * self.miliseconds_per_stride
        prev_events = self._get_events_time_range(
            events, event_times, frame_time - self.miliseconds_per_sequence, frame_time) if self.add_pre_tokens else []
        post_events = self._get_events_time_range(
            events, event_times, frame_time, frame_time + self.miliseconds_per_sequence)

        prev_tokens = self._encode(prev_events, frame_time)
        post_tokens = self._encode(post_events, frame_time)
        post_token_length = post_tokens.shape[1]

        # Truncate pre-context from the left when it exceeds the budget.
        if 0 <= self.max_pre_token_len < prev_tokens.shape[1]:
            prev_tokens = prev_tokens[:, -self.max_pre_token_len:]

        # Get prefix tokens
        prefix = torch.cat([special_tokens, prev_tokens], dim=-1)
        if self.center_pad_decoder:
            # Left-pad so generation always starts at the window midpoint.
            prefix = F.pad(prefix, (self.tgt_seq_len // 2 - prefix.shape[1], 0), value=self.tokenizer.pad_id)
        prefix_length = prefix.shape[1]


        # Retry generation up to max_retries times when decoding fails or
        # produces an implausible result (see the NEW_COMBO check below).
        max_retries = 5
        attempt = 0
        result = []

        while attempt < max_retries and not result:
            attempt += 1
            try:
                # Reset tokens
                tokens = torch.tensor([[self.tokenizer.sos_id]], dtype=torch.long, device=self.device)
                tokens = torch.cat([prefix, tokens, post_tokens], dim=-1)

                # Ensure frames are properly reset for each retry
                retry_frames = frames.clone().to(self.device).unsqueeze(0)
                encoder_outputs = None

                while tokens.shape[-1] < self.tgt_seq_len:
                    out = model.forward(
                        frames=retry_frames,
                        decoder_input_ids=tokens,
                        decoder_attention_mask=tokens.ne(self.tokenizer.pad_id),
                        encoder_outputs=encoder_outputs,
                        #beatmap_idx=beatmap_idx,
                    )
                    # Cache encoder state so only the decoder reruns per step.
                    encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions)

                    logits = out.logits[:, -1, :]
                    logits = logits / temperature
                    # Nucleus + top-k filtering before sampling.
                    logits = self._filter(logits, top_p=0.9, top_k=60)
                    probabilities = F.softmax(logits, dim=-1)
                    next_tokens = torch.multinomial(probabilities, 1)

                    tokens = torch.cat([tokens, next_tokens], dim=-1)

                    eos_in_sentence = next_tokens == self.tokenizer.eos_id
                    if eos_in_sentence.all():
                        break

                # Strip prefix, SOS and post-context; keep only new tokens.
                predicted_tokens = tokens[:, prefix_length + 1 + post_token_length:]
                result = self._decode(predicted_tokens[0], frame_time)

                # if no new combo in result, retry;
                if len(result) > 10 and not any(event.type == EventType.NEW_COMBO for event in result):
                    #print("No new combo in result; retrying...")
                    result = []


            except Exception as e:
                # NOTE(review): broad catch deliberately converts any decode
                # failure into a retry; consider logging `e` for diagnosis.
                #print(f"Attempt {attempt} encountered an error: {e}")
                result = []  # Ensure result is empty to trigger retry

        events += result

        self._update_event_times(events, event_times, frame_time)

    return events
|
| 223 |
+
|
| 224 |
+
def _prepare_events(self, events: list[Event]) -> tuple[list[Event], list[float]]:
    """Pre-process raw events for inference.

    Builds a parallel list of absolute timestamps (each event inherits the
    value of the most recent TIME_SHIFT) and strips TIME_SHIFT events that
    immediately precede slider anchor events, mutating ``events`` in place.
    """
    # Every event gets the timestamp of the latest TIME_SHIFT seen so far.
    current_time = 0
    event_times = []
    for event in events:
        if event.type == EventType.TIME_SHIFT:
            current_time = event.value
        event_times.append(current_time)

    # Walk backwards: a TIME_SHIFT directly before an anchor is redundant,
    # so drop it (and its timestamp) from both lists.
    pending_removal = False
    for idx in reversed(range(len(events))):
        if events[idx].type == EventType.TIME_SHIFT and pending_removal:
            pending_removal = False
            del events[idx]
            del event_times[idx]
            continue
        elif events[idx].type in (EventType.BEZIER_ANCHOR, EventType.PERFECT_ANCHOR,
                                  EventType.CATMULL_ANCHOR, EventType.RED_ANCHOR):
            pending_removal = True

    return events, event_times
|
| 249 |
+
|
| 250 |
+
def _get_events_time_range(self, events: list[Event], event_times: list[float], start_time: float, end_time: float):
|
| 251 |
+
# Look from the end of the list
|
| 252 |
+
s = 0
|
| 253 |
+
for i in range(len(event_times) - 1, -1, -1):
|
| 254 |
+
if event_times[i] < start_time:
|
| 255 |
+
s = i + 1
|
| 256 |
+
break
|
| 257 |
+
e = 0
|
| 258 |
+
for i in range(len(event_times) - 1, -1, -1):
|
| 259 |
+
if event_times[i] < end_time:
|
| 260 |
+
e = i + 1
|
| 261 |
+
break
|
| 262 |
+
return events[s:e]
|
| 263 |
+
|
| 264 |
+
def _update_event_times(self, events: list[Event], event_times: list[float], frame_time: float):
    """Refresh the parallel timestamp list after new events were appended.

    Delegates to the shared ``update_event_times`` helper with the end time of
    the current window (frame_time + one sequence length). The exact semantics
    live in that helper, which is defined outside this file — presumably it
    extends ``event_times`` in place for the newly appended events; verify there.
    """
    update_event_times(events, event_times, frame_time + self.miliseconds_per_sequence)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _encode(self, events: list[Event], frame_time: float) -> torch.Tensor:
    """Tokenize events into a (1, len(events)) long tensor on ``self.device``.

    TIME_SHIFT values are converted from absolute milliseconds to integer
    steps relative to ``frame_time`` before tokenization; other events are
    encoded as-is. On any failure an empty (1, 0) tensor is returned so the
    caller proceeds without context (deliberate best-effort behaviour).
    """
    try:
        tokens = torch.empty((1, len(events)), dtype=torch.long)
        for i, event in enumerate(events):
            if event.type == EventType.TIME_SHIFT:
                # Relative step index; can be negative for events before frame_time.
                event = Event(type=event.type, value=int((event.value - frame_time) / MILISECONDS_PER_STEP))
            tokens[0, i] = self.tokenizer.encode(event)
        return tokens.to(self.device)
    except Exception as e:
        # NOTE(review): broad catch discards *all* events when a single one
        # fails to encode — consider logging `e` and the offending event.
        return torch.empty((1, 0), dtype=torch.long, device=self.device)
|
| 281 |
+
def _decode(self, tokens: torch.Tensor, frame_time: float) -> list[Event]:
    """Converts a list of tokens into Event objects and converts to absolute time values.

    Args:
        tokens: 1-D tensor of token ids.
        frame_time: Start time (ms) of the current source sequence.

    Returns:
        events: List of Event objects with TIME_SHIFT values in absolute milliseconds.
    """
    events = []
    for token in tokens:
        # EOS terminates the decoded sequence.
        if token == self.tokenizer.eos_id:
            break

        try:
            event = self.tokenizer.decode(token.item())
        except Exception:
            # Fix: was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit. Unknown/invalid token ids are simply skipped.
            continue

        if event.type == EventType.TIME_SHIFT:
            # Convert relative step count back to absolute milliseconds.
            event.value = frame_time + event.value * MILISECONDS_PER_STEP

        events.append(event)

    return events
|
| 307 |
+
|
| 308 |
+
def _filter(self, logits: torch.Tensor, top_p: float = 0.75, top_k: int = 1, filter_value: float = -float("Inf")) -> torch.Tensor:
    """Filter a distribution of logits using nucleus (top-p) and/or top-k filtering."""
    # Stage 1: keep only the k highest-scoring tokens (skipped for k == 0).
    if top_k > 0:
        logits = top_k_logits(logits, top_k)

    # Stage 2: nucleus filtering — keep the smallest prefix of the sorted
    # distribution whose cumulative probability exceeds top_p.
    if 0.0 < top_p < 1.0:
        desc_logits, desc_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(desc_logits, dim=-1), dim=-1)

        drop_sorted = cum_probs > top_p
        # Shift right so the first token that crosses the threshold is kept.
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = 0

        # Map the sorted-order mask back to vocabulary order (batch-safe).
        drop_mask = drop_sorted.scatter(1, desc_indices, drop_sorted)
        logits[drop_mask] = filter_value

    return logits
|
| 325 |
+
def top_k_logits(logits: torch.Tensor, k: int) -> torch.Tensor:
    """
    Keep only the top-k tokens with highest probabilities.

    Args:
        logits: Logits distribution of shape (batch size, vocabulary size).
        k: Number of top tokens to keep. Non-positive k is a no-op; k larger
           than the vocabulary is clamped (previously this raised in topk).

    Returns:
        A new tensor with non-top-k elements set to negative infinity; the
        input tensor is not modified.
    """
    if k <= 0:
        # Nothing to filter; previously values[:, -1] would have raised.
        return logits
    # Clamp: torch.topk raises RuntimeError when k exceeds the last dim size.
    k = min(k, logits.size(-1))
    values, indices = torch.topk(logits, k)
    # Threshold is the k-th largest logit per row, broadcast over the vocab.
    min_values = values[:, -1].unsqueeze(-1).expand_as(logits)
    return torch.where(logits < min_values, torch.full_like(logits, float("-Inf")), logits)
|
osuT5/inference/postprocessor.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
import os
|
| 5 |
+
import pathlib
|
| 6 |
+
import uuid
|
| 7 |
+
from string import Template
|
| 8 |
+
import zipfile
|
| 9 |
+
import numpy as np
|
| 10 |
+
from omegaconf import DictConfig
|
| 11 |
+
import time as t
|
| 12 |
+
from osuT5.inference.slider_path import SliderPath
|
| 13 |
+
from osuT5.tokenizer import Event, EventType
|
| 14 |
+
|
| 15 |
+
OSU_FILE_EXTENSION = ".osu"
|
| 16 |
+
OSU_TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template.osu")
|
| 17 |
+
STEPS_PER_MILLISECOND = 0.1
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclasses.dataclass
class BeatmapConfig:
    """Values substituted into the .osu template's placeholders (template.osu)."""

    # General
    audio_filename: str = ""  # name of the audio file bundled in the .osz

    # Metadata
    title: str = ""
    title_unicode: str = ""
    artist: str = ""
    artist_unicode: str = ""
    creator: str = ""  # mapper name shown in the editor
    version: str = ""  # difficulty name

    # Difficulty settings (standard osu! ranges are 0-10)
    hp_drain_rate: float = 5
    circle_size: float = 4
    overall_difficulty: float = 8
    approach_rate: float = 9
    slider_multiplier: float = 1.8  # base slider velocity (hundreds of osu!pixels per beat)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def calculate_coordinates(last_pos, dist, num_samples, playfield_size):
    """Sample candidate positions exactly ``dist`` away from ``last_pos``.

    ``num_samples`` angles are taken evenly over the full circle; points that
    fall outside the playfield rectangle are discarded. If no candidate
    survives, the nearest corner — (0, 0) or the far corner — is returned as
    the sole fallback.
    """
    theta = np.linspace(0, 2 * np.pi, num_samples)

    # Points on the circle of radius `dist` centred at last_pos.
    xs = last_pos[0] + dist * np.cos(theta)
    ys = last_pos[1] + dist * np.sin(theta)

    # Keep only candidates inside the playfield rectangle.
    in_bounds = [
        (px, py)
        for px, py in zip(xs, ys)
        if 0 <= px <= playfield_size[0] and 0 <= py <= playfield_size[1]
    ]
    if in_bounds:
        return in_bounds

    # Fallback: pick the corner on the same side of the field as last_pos.
    if last_pos[0] + last_pos[1] > (playfield_size[0] + playfield_size[1]) / 2:
        return [playfield_size]
    return [(0, 0)]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def position_to_progress(slider_path: SliderPath, pos: np.ndarray) -> np.ndarray:
    """Find the progress t in [0, 1] along ``slider_path`` closest to ``pos``.

    Gradient descent on the distance to ``pos``, starting from the tail
    (t = 1), using a backward finite difference with step ``eps``. Note the
    difference is deliberately not divided by eps, which makes the effective
    step tiny — presumably intentional damping; confirm if tuning.
    Stops after 100 iterations, on a zero gradient, or when t leaves [0, 1].
    """
    eps = 1e-4
    lr = 1
    # NOTE: this local `t` shadows the module-level `import time as t` alias.
    t = 1
    for i in range(100):
        grad = np.linalg.norm(slider_path.position_at(t) - pos) - np.linalg.norm(
            slider_path.position_at(t - eps) - pos,
        )
        t -= lr * grad

        if grad == 0 or t < 0 or t > 1:
            break

    return np.clip(t, 0, 1)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def quantize_to_beat(time, bpm, offset, tick_rate=0.5):
    """Quantize a time (ms) to the nearest beat-subdivision tick.

    Args:
        time: Timestamp in milliseconds.
        bpm: Beats per minute of the (single) timing section.
        offset: Timing-section start in milliseconds; the grid is anchored here.
        tick_rate: Beat fraction of the snapping grid. The previous hard-coded
            1/2-beat value is now the default, so existing callers are unchanged.

    Returns:
        The snapped timestamp in milliseconds (float).
    """
    beats_per_second = bpm / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    tick = milliseconds_per_beat * tick_rate
    # round() uses banker's rounding for exact .5 midpoints; acceptable here.
    return round((time - offset) / tick) * tick + offset
|
| 92 |
+
|
| 93 |
+
def quantize_to_beat_again(time, bpm, offset, tick_rate=0.25):
    """Quantize a time (ms) to the nearest beat-subdivision tick (1/4 beat by default).

    NOTE(review): duplicate of ``quantize_to_beat`` except for the default
    tick rate — kept for caller compatibility; now that both accept a
    ``tick_rate`` parameter the two could be merged.

    Args:
        time: Timestamp in milliseconds.
        bpm: Beats per minute of the (single) timing section.
        offset: Timing-section start in milliseconds; the grid is anchored here.
        tick_rate: Beat fraction of the snapping grid (default 1/4 beat).

    Returns:
        The snapped timestamp in milliseconds (float).
    """
    beats_per_second = bpm / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    tick = milliseconds_per_beat * tick_rate
    return round((time - offset) / tick) * tick + offset
|
| 107 |
+
|
| 108 |
+
def move_to_next_tick(time, bpm):
    """Advance a time (ms) by one 1/4-beat tick for the given BPM."""
    tick_rate = 0.25
    beats_per_second = bpm / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    return time + milliseconds_per_beat * tick_rate
|
| 116 |
+
|
| 117 |
+
def move_to_prev_tick(time, bpm):
    """Move a time (ms) back by one 1/4-beat tick for the given BPM.

    (Docstring fix: previously said "next tick", copy-pasted from
    ``move_to_next_tick``; this function subtracts a tick.)
    """
    tick_rate = 0.25
    beats_per_minute = bpm
    beats_per_second = beats_per_minute / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    quantized_time = time - milliseconds_per_beat * tick_rate
    return quantized_time
|
| 125 |
+
|
| 126 |
+
def adjust_hit_objects(hit_objects, bpm, offset):
    """Resnap TIME_SHIFT events onto the beat grid (BPM + offset).

    Each TIME_SHIFT is quantized with ``quantize_to_beat``. If the quantized
    time collides with the previous adjusted time, the event is pushed to the
    next tick — unless the previous event ends a slider body (LAST_ANCHOR or
    SLIDER_END), where a shared timestamp is legitimate.

    Bug fix: the previous guard compared against
    ``(EventType.LAST_ANCHOR or EventType.SLIDER_END)``, which evaluates to
    just ``EventType.LAST_ANCHOR``, so SLIDER_END was never exempted.
    Also removed the unused ``to_be_adjusted`` list and merged the two
    identical append branches.
    """
    adjusted_hit_objects = []
    adjusted_times = []
    for hit_object in hit_objects:
        if hit_object.type != EventType.TIME_SHIFT:
            adjusted_hit_objects.append(hit_object)
            continue

        time = quantize_to_beat(hit_object.value, bpm, offset)
        # Resolve a collision with the previous snapped time by moving one
        # tick forward, except right after a slider tail.
        if (
            adjusted_times
            and int(time) == adjusted_times[-1]
            and adjusted_hit_objects[-1].type not in (EventType.LAST_ANCHOR, EventType.SLIDER_END)
        ):
            time = move_to_next_tick(time, bpm)

        adjusted_hit_objects.append(Event(EventType.TIME_SHIFT, time))
        adjusted_times.append(int(time))

    return adjusted_hit_objects
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class Postprocessor(object):
    def __init__(self, args: DictConfig):
        """Postprocessing stage that converts a list of Event objects to a beatmap file."""
        # .osu curve-type letters -> SliderPath path_type names.
        self.curve_type_shorthand = {
            "B": "Bezier",
            "P": "PerfectCurve",
            "C": "Catmull",
        }

        self.output_path = args.output_path
        self.audio_path = args.audio_path
        # Audio file name without its extension; used for map title and .osz name.
        self.audio_filename = pathlib.Path(args.audio_path).name.split(".")[0]
        self.beatmap_config = BeatmapConfig(
            title=str(f"{self.audio_filename} ({args.title})"),
            artist=str(args.artist),
            title_unicode=str(args.title),
            artist_unicode=str(args.artist),
            audio_filename=pathlib.Path(args.audio_path).name,
            slider_multiplier=float(args.slider_multiplier),
            creator=str(args.creator),
            version=str(args.version),
        )
        self.offset = args.offset  # timing-section start in ms
        self.beat_length = 60000 / args.bpm  # ms per beat
        self.slider_multiplier = self.beatmap_config.slider_multiplier
        self.bpm = args.bpm
        self.resnap_objects = args.resnap_objects  # whether to snap times to the grid

    def generate(self, generated_positions: list[Event]):
        """Generate a beatmap file.

        Args:
            generated_positions: One list of Event objects per difficulty to render.

        Returns:
            None. An .osz archive containing one .osu file per difficulty
            (plus the audio file) is written to ``self.output_path``.
        """
        processed_events = []

        for events in generated_positions:
            # adjust hit objects to align with the beat grid
            if self.resnap_objects:
                events = adjust_hit_objects(events, self.bpm, self.offset)

            hit_object_strings = []
            # Decoder state carried across events while assembling one object.
            time = 0
            dist = 0
            x = 256  # playfield centre as the initial cursor position
            y = 192
            has_pos = False  # True once an explicit POS_X/POS_Y was seen
            new_combo = 0  # 4 == new-combo bit in the .osu type bitfield
            ho_info = []  # pending hit-object fields (slider head / spinner)
            anchor_info = []  # pending slider control points as (curve, x, y)

            # Uninherited timing point establishing BPM at the offset.
            timing_point_strings = [
                f"{self.offset},{self.beat_length},4,2,0,100,1,0"
            ]

            for event in events:
                hit_type = event.type

                # --- state-only events: update cursor/flags and continue ---
                if hit_type == EventType.TIME_SHIFT:
                    time = event.value
                    continue
                elif hit_type == EventType.DISTANCE:
                    # Find a point which is dist away from the last point but
                    # still within the playfield; pick one of the candidates
                    # at random.
                    dist = event.value
                    coordinates = calculate_coordinates((x, y), dist, 500, (512, 384))
                    pos = coordinates[np.random.randint(len(coordinates))]
                    x, y = pos
                    continue
                elif hit_type == EventType.POS_X:
                    x = event.value
                    has_pos = True
                    continue
                elif hit_type == EventType.POS_Y:
                    y = event.value
                    has_pos = True
                    continue
                elif hit_type == EventType.NEW_COMBO:
                    new_combo = 4
                    continue

                # --- object events: emit .osu hit-object lines ---
                if hit_type == EventType.CIRCLE:
                    # type 1 == circle, OR'd with the new-combo bit.
                    hit_object_strings.append(f"{int(round(x))},{int(round(y))},{int(round(time))},{1 | new_combo},0")
                    ho_info = []

                elif hit_type == EventType.SPINNER:
                    ho_info = [time, new_combo]

                elif hit_type == EventType.SPINNER_END and len(ho_info) == 2:
                    # Spinners are centred; type 8 == spinner; last field is end time.
                    hit_object_strings.append(
                        f"{256},{192},{int(round(ho_info[0]))},{8 | ho_info[1]},0,{int(round(time))}"
                    )
                    ho_info = []

                elif hit_type == EventType.SLIDER_HEAD:
                    ho_info = [x, y, time, new_combo]
                    anchor_info = []

                elif hit_type == EventType.BEZIER_ANCHOR:
                    anchor_info.append(('B', x, y))

                elif hit_type == EventType.PERFECT_ANCHOR:
                    anchor_info.append(('P', x, y))

                elif hit_type == EventType.CATMULL_ANCHOR:
                    anchor_info.append(('C', x, y))

                elif hit_type == EventType.RED_ANCHOR:
                    # A red (sharp) anchor is encoded as a doubled Bezier point.
                    anchor_info.append(('B', x, y))
                    anchor_info.append(('B', x, y))

                elif hit_type == EventType.LAST_ANCHOR:
                    # Records the end time of the first span (ho_info[4]).
                    ho_info.append(time)
                    anchor_info.append(('B', x, y))

                elif hit_type == EventType.SLIDER_END and len(ho_info) == 5 and len(anchor_info) > 0:
                    curve_type = anchor_info[0][0]
                    span_duration = ho_info[4] - ho_info[2]  # one traversal, ms
                    total_duration = time - ho_info[2]  # all repeats, ms

                    if total_duration == 0 or span_duration == 0:
                        continue

                    slides = max(int(round(total_duration / span_duration)), 1)
                    control_points = "|".join(f"{int(round(cp[1]))}:{int(round(cp[2]))}" for cp in anchor_info)
                    slider_path = SliderPath(self.curve_type_shorthand[curve_type], np.array([(ho_info[0], ho_info[1])] + [(cp[1], cp[2]) for cp in anchor_info], dtype=float))
                    length = slider_path.get_distance()

                    # Trim the path length to where the slider visually ends:
                    # either the explicit end position, or full length minus
                    # the last DISTANCE value.
                    req_length = length * position_to_progress(
                        slider_path,
                        np.array((x, y)),
                    ) if has_pos else length - dist

                    if req_length < 1e-4:
                        continue

                    # type 2 == slider.
                    hit_object_strings.append(
                        f"{int(round(ho_info[0]))},{int(round(ho_info[1]))},{int(round(ho_info[2]))},{2 | ho_info[3]},0,{curve_type}|{control_points},{slides},{req_length}"
                    )

                    # Inherited timing point whose (negative) beat length sets
                    # the slider velocity so the span fits span_duration.
                    sv = span_duration / req_length / self.beat_length * self.slider_multiplier * -10000
                    timing_point_strings.append(
                        f"{int(round(ho_info[2]))},{sv},4,2,0,100,0,0"
                    )

                # Combo flag applies to the next emitted object only.
                new_combo = 0

            # Write .osu file by filling the template placeholders.
            with open(OSU_TEMPLATE_PATH, "r") as tf:
                template = Template(tf.read())
                hit_objects = {"hit_objects": "\n".join(hit_object_strings)}
                timing_points = {"timing_points": "\n".join(timing_point_strings)}
                beatmap_config = dataclasses.asdict(self.beatmap_config)
                result = template.safe_substitute({**beatmap_config, **hit_objects, **timing_points})
                processed_events.append(result)

        # Bundle every rendered difficulty plus the audio into one .osz
        # (timestamped so repeated runs do not overwrite each other).
        osz_path = os.path.join(self.output_path, f"{self.audio_filename}_{t.time()}.osz")
        with zipfile.ZipFile(osz_path, "w") as z:
            for i, event in enumerate(processed_events):
                osu_path = os.path.join(self.output_path, f"{i}{OSU_FILE_EXTENSION}")
                with open(osu_path, "w") as osu_file:
                    osu_file.write(event)
                z.write(osu_path, os.path.basename(osu_path))
            z.write(self.audio_path, os.path.basename(self.audio_path))
            print(f"Mapset saved {osz_path}")
            # NOTE(review): redundant — the `with` block already closes the archive.
            z.close()
|
osuT5/inference/preprocessor.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import numpy.typing as npt
|
| 8 |
+
from omegaconf import DictConfig
|
| 9 |
+
|
| 10 |
+
from osuT5.dataset.data_utils import load_audio_file
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Preprocessor(object):
    def __init__(self, args: DictConfig):
        """Preprocess audio data into sequences."""
        # One frame is reserved elsewhere (hence src_seq_len - 1) — presumably
        # for a special token in the encoder input; confirm against the model.
        self.frame_seq_len = args.data.src_seq_len - 1
        self.frame_size = args.data.hop_length  # audio samples per frame
        self.sample_rate = args.data.sample_rate
        self.samples_per_sequence = self.frame_seq_len * self.frame_size
        # Stride between consecutive windows, as a fraction of a sequence.
        self.sequence_stride = int(self.samples_per_sequence * args.data.sequence_stride)

    def load(self, path: Path) -> npt.ArrayLike:
        """Load an audio file as audio frames. Convert stereo to mono, normalize.

        Args:
            path: Path to audio file.

        Returns:
            samples: Audio time-series.
        """
        return load_audio_file(path, self.sample_rate)

    def segment(self, samples: npt.ArrayLike) -> torch.Tensor:
        """Segment audio samples into sequences. Sequences are flattened frames.

        Args:
            samples: Audio time-series.

        Returns:
            sequences: A list of sequences of shape (batch size, samples per sequence).
        """
        # Zero-pad so that (len - samples_per_sequence) is a whole number of
        # strides, i.e. the final window is completely covered.
        samples = np.pad(
            samples,
            [0, self.sequence_stride - (len(samples) - self.samples_per_sequence) % self.sequence_stride],
        )
        sequences = self.window(samples, self.samples_per_sequence, self.sequence_stride)
        sequences = torch.from_numpy(sequences).to(torch.float32)
        return sequences

    @staticmethod
    def window(a, w, o, copy=False):
        """Sliding windows of width ``w`` with stride ``o`` over 1-D array ``a``.

        Implemented with ``as_strided``: the result is a zero-copy view into
        ``a`` unless ``copy=True``. Callers must not mutate the view, and the
        backing array must stay alive while the view is used.
        """
        sh = (a.size - w + 1, w)
        st = a.strides * 2
        # All windows at stride 1, then subsample every o-th window.
        view = np.lib.stride_tricks.as_strided(a, strides=st, shape=sh)[0::o]
        if copy:
            return view.copy()
        else:
            return view
|
osuT5/inference/slider_path.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from numpy.linalg import norm
|
| 5 |
+
|
| 6 |
+
import osuT5.inference.path_approximator as path_approximator
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def binary_search(array, target):
    """C#-style binary search over a sorted sequence.

    Returns the index of a matching element (not necessarily the first when
    duplicates exist). When absent, returns the bitwise complement of the
    insertion point — always negative, recoverable with ``~result``.
    """
    low, high = 0, len(array)
    while low < high:
        mid = (low + high) // 2
        probe = array[mid]
        if probe == target:
            return mid
        if probe < target:
            # A one-element window can no longer shrink from the left.
            if low == mid:
                break
            low = mid
        else:
            high = mid
    return ~high
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SliderPath:
    """Flattened point path of an osu! slider.

    Approximates the control-point curve as a polyline (``calculated_path``)
    with running arc lengths (``cumulative_length``), optionally trimmed or
    extended to ``expected_distance``. Ported from osu!'s SliderPath.
    """

    __slots__ = (
        "control_points",
        "path_type",
        "expected_distance",
        "calculated_path",
        "cumulative_length",
        "is_initialised",
    )

    def __init__(
        self,
        path_type: str,
        control_points: np.array,
        expected_distance: float | None = None,
    ) -> None:
        # path_type: "Linear", "PerfectCurve", "Catmull" or anything else (Bezier).
        self.control_points = control_points
        self.path_type = path_type
        self.expected_distance = expected_distance

        self.calculated_path = None
        self.cumulative_length = None

        self.is_initialised = None

        # Eagerly computed; ensure_initialised() is a no-op afterwards.
        self.ensure_initialised()

    def get_control_points(self) -> np.array:
        self.ensure_initialised()
        return self.control_points

    def get_distance(self) -> float:
        """Total arc length of the (possibly trimmed) path."""
        self.ensure_initialised()
        return 0 if len(self.cumulative_length) == 0 else self.cumulative_length[-1]

    def get_path_to_progress(self, path, p0, p1) -> None:
        """Fill ``path`` (cleared first) with the vertices between progress p0 and p1."""
        self.ensure_initialised()

        d0 = self.progress_to_distance(p0)
        d1 = self.progress_to_distance(p1)

        path.clear()

        # Skip vertices before d0, then interpolate the exact start point.
        i = 0
        while i < len(self.calculated_path) and self.cumulative_length[i] < d0:
            i += 1

        path.append(self.interpolate_vertices(i, d0))

        # Copy whole vertices up to d1, then interpolate the exact end point.
        while i < len(self.calculated_path) and self.cumulative_length[i] < d1:
            path.append(self.calculated_path[i])
            i += 1

        path.append(self.interpolate_vertices(i, d1))

    def position_at(self, progress) -> np.array:
        """Position on the path at ``progress`` in [0, 1] (clamped)."""
        self.ensure_initialised()

        d = self.progress_to_distance(progress)
        return self.interpolate_vertices(self.index_of_distance(d), d)

    def ensure_initialised(self) -> None:
        if self.is_initialised:
            return
        self.is_initialised = True

        self.control_points = [] if self.control_points is None else self.control_points
        self.calculated_path = []
        self.cumulative_length = []

        self.calculate_path()
        self.calculate_cumulative_length()

    def calculate_subpath(self, sub_control_points) -> list:
        """Approximate one control-point segment as a list of vertices."""
        if self.path_type == "Linear":
            return path_approximator.approximate_linear(sub_control_points)
        elif self.path_type == "PerfectCurve":
            # Perfect curves need exactly 3 points; otherwise fall back to Bezier.
            if len(self.get_control_points()) != 3 or len(sub_control_points) != 3:
                return path_approximator.approximate_bezier(sub_control_points)

            subpath = path_approximator.approximate_circular_arc(sub_control_points)

            # Degenerate arc (e.g. collinear points): fall back to Bezier.
            if len(subpath) == 0:
                return path_approximator.approximate_bezier(sub_control_points)

            return subpath
        elif self.path_type == "Catmull":
            return path_approximator.approximate_catmull(sub_control_points)
        else:
            return path_approximator.approximate_bezier(sub_control_points)

    def calculate_path(self) -> None:
        """Flatten the control points into ``calculated_path`` vertices."""
        self.calculated_path.clear()

        start = 0
        end = 0

        for i in range(len(self.get_control_points())):
            end += 1

            # A segment ends at the last control point or at a doubled point
            # (red anchor). Short-circuit protects the i + 1 access.
            if (
                i == len(self.get_control_points()) - 1
                or (
                    self.get_control_points()[i] == self.get_control_points()[i + 1]
                ).all()
            ):
                cp_span = self.get_control_points()[start:end]

                for t in self.calculate_subpath(cp_span):
                    # Skip consecutive duplicate vertices.
                    if (
                        len(self.calculated_path) == 0
                        or (self.calculated_path[-1] != t).any()
                    ):
                        self.calculated_path.append(t)

                start = end

    def calculate_cumulative_length(self) -> None:
        """Compute running arc lengths; trim/extend to expected_distance if set."""
        length = 0

        self.cumulative_length.clear()
        self.cumulative_length.append(length)

        for i in range(len(self.calculated_path) - 1):
            diff = self.calculated_path[i + 1] - self.calculated_path[i]
            d = norm(diff)

            # Path longer than expected: shorten the current segment so the
            # total hits expected_distance, drop the rest, and stop.
            if (
                self.expected_distance is not None
                and self.expected_distance - length < d
            ):
                self.calculated_path[i + 1] = (
                    self.calculated_path[i]
                    + diff * (self.expected_distance - length) / d
                )
                del self.calculated_path[i + 2 : len(self.calculated_path) - 2 - i]

                length = self.expected_distance
                self.cumulative_length.append(length)
                break

            length += d
            self.cumulative_length.append(length)

        # Path shorter than expected: stretch the final segment outward.
        if (
            self.expected_distance is not None
            and length < self.expected_distance
            and len(self.calculated_path) > 1
        ):
            diff = self.calculated_path[-1] - self.calculated_path[-2]
            d = norm(diff)

            if d <= 0:
                return

            self.calculated_path[-1] += (
                diff * (self.expected_distance - self.cumulative_length[-1]) / d
            )
            self.cumulative_length[-1] = self.expected_distance

    def index_of_distance(self, d) -> int:
        """Index of the first vertex at or beyond arc length ``d``."""
        i = binary_search(self.cumulative_length, d)
        if i < 0:
            # Negative means "not found"; ~i recovers the insertion point.
            i = ~i

        return i

    def progress_to_distance(self, progress) -> float:
        """Map progress in [0, 1] (clamped) to an arc length."""
        return np.clip(progress, 0, 1) * self.get_distance()

    def interpolate_vertices(self, i, d) -> np.array:
        """Linearly interpolate the position at arc length ``d`` around vertex ``i``."""
        if len(self.calculated_path) == 0:
            return np.zeros([2])

        # Clamp to the path's endpoints.
        if i <= 0:
            return self.calculated_path[0]
        if i >= len(self.calculated_path):
            return self.calculated_path[-1]

        p0 = self.calculated_path[i - 1]
        p1 = self.calculated_path[i]

        d0 = self.cumulative_length[i - 1]
        d1 = self.cumulative_length[i]

        # Zero-length segment: avoid dividing by ~0.
        if np.isclose(d0, d1):
            return p0

        w = (d - d0) / (d1 - d0)
        return p0 + (p1 - p0) * w
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
if __name__ == "__main__":
|
| 219 |
+
path = SliderPath(
|
| 220 |
+
"Bezier",
|
| 221 |
+
100 * np.array([[0, 0], [1, 1], [1, -1], [2, 0], [2, 0], [3, -1], [2, -2]]),
|
| 222 |
+
)
|
| 223 |
+
p = np.vstack(path.calculated_path)
|
| 224 |
+
logging.info(p.shape)
|
| 225 |
+
|
| 226 |
+
import matplotlib.pyplot as plt
|
| 227 |
+
|
| 228 |
+
plt.axis("equal")
|
| 229 |
+
plt.plot(p[:, 0], p[:, 1], color="green")
|
| 230 |
+
plt.show()
|
osuT5/inference/template.osu
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
osu file format v14
|
| 2 |
+
|
| 3 |
+
[General]
|
| 4 |
+
AudioFilename: $audio_filename
|
| 5 |
+
AudioLeadIn: 0
|
| 6 |
+
PreviewTime: -1
|
| 7 |
+
Countdown: 0
|
| 8 |
+
SampleSet: Soft
|
| 9 |
+
StackLeniency: 0.7
|
| 10 |
+
Mode: 0
|
| 11 |
+
LetterboxInBreaks: 0
|
| 12 |
+
WidescreenStoryboard: 1
|
| 13 |
+
|
| 14 |
+
[Editor]
|
| 15 |
+
DistanceSpacing: 1.0
|
| 16 |
+
BeatDivisor: 4
|
| 17 |
+
GridSize: 8
|
| 18 |
+
TimelineZoom: 1
|
| 19 |
+
|
| 20 |
+
[Metadata]
|
| 21 |
+
Title:$title
|
| 22 |
+
TitleUnicode:$title_unicode
|
| 23 |
+
Artist:$artist
|
| 24 |
+
ArtistUnicode:$artist_unicode
|
| 25 |
+
Creator:$creator
|
| 26 |
+
Version:$version
|
| 27 |
+
Source:
|
| 28 |
+
Tags:
|
| 29 |
+
|
| 30 |
+
[Difficulty]
|
| 31 |
+
HPDrainRate:$hp_drain_rate
|
| 32 |
+
CircleSize:$circle_size
|
| 33 |
+
OverallDifficulty:$overall_difficulty
|
| 34 |
+
ApproachRate:$approach_rate
|
| 35 |
+
SliderMultiplier:$slider_multiplier
|
| 36 |
+
SliderTickRate:1
|
| 37 |
+
|
| 38 |
+
[Events]
|
| 39 |
+
//Background and Video events
|
| 40 |
+
//Break Periods
|
| 41 |
+
//Storyboard Layer 0 (Background)
|
| 42 |
+
//Storyboard Layer 1 (Fail)
|
| 43 |
+
//Storyboard Layer 2 (Pass)
|
| 44 |
+
//Storyboard Layer 3 (Foreground)
|
| 45 |
+
//Storyboard Layer 4 (Overlay)
|
| 46 |
+
//Storyboard Sound Samples
|
| 47 |
+
|
| 48 |
+
[TimingPoints]
|
| 49 |
+
$timing_points
|
| 50 |
+
|
| 51 |
+
[Colours]
|
| 52 |
+
|
| 53 |
+
[HitObjects]
|
| 54 |
+
$hit_objects
|
osuT5/model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .osu_t import OsuT
|
osuT5/model/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (203 Bytes). View file
|
|
|
osuT5/model/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (171 Bytes). View file
|
|
|