Tiger14n commited on
Commit
7ef7abb
·
verified ·
1 Parent(s): 52b0e2d

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +36 -35
  2. README.md +3 -0
  3. checkpoint/custom_checkpoint_0.pkl +3 -0
  4. checkpoint/pytorch_model.bin +3 -0
  5. configs/inference.yaml +65 -0
  6. configs/model/model.yaml +8 -0
  7. configs/model/t5_base.yaml +15 -0
  8. configs/model/t5_small.yaml +10 -0
  9. configs/model/t5_small_v4.yaml +7 -0
  10. configs/model/t5_small_v9.yaml +9 -0
  11. configs/model/whisper_base.yaml +6 -0
  12. inference.py +117 -0
  13. osuT5/__init__.py +0 -0
  14. osuT5/__pycache__/__init__.cpython-311.pyc +0 -0
  15. osuT5/__pycache__/__init__.cpython-39.pyc +0 -0
  16. osuT5/dataset/__init__.py +1 -0
  17. osuT5/dataset/__pycache__/__init__.cpython-311.pyc +0 -0
  18. osuT5/dataset/__pycache__/__init__.cpython-39.pyc +0 -0
  19. osuT5/dataset/__pycache__/data_utils.cpython-311.pyc +0 -0
  20. osuT5/dataset/__pycache__/data_utils.cpython-39.pyc +0 -0
  21. osuT5/dataset/__pycache__/ors_dataset.cpython-311.pyc +0 -0
  22. osuT5/dataset/__pycache__/ors_dataset.cpython-39.pyc +0 -0
  23. osuT5/dataset/__pycache__/osu_parser.cpython-311.pyc +0 -0
  24. osuT5/dataset/__pycache__/osu_parser.cpython-39.pyc +0 -0
  25. osuT5/dataset/data_utils.py +100 -0
  26. osuT5/dataset/osu_parser.py +184 -0
  27. osuT5/inference/__init__.py +4 -0
  28. osuT5/inference/__pycache__/__init__.cpython-311.pyc +0 -0
  29. osuT5/inference/__pycache__/__init__.cpython-39.pyc +0 -0
  30. osuT5/inference/__pycache__/diffusion_pipeline.cpython-311.pyc +0 -0
  31. osuT5/inference/__pycache__/path_approximator.cpython-311.pyc +0 -0
  32. osuT5/inference/__pycache__/path_approximator.cpython-39.pyc +0 -0
  33. osuT5/inference/__pycache__/pipeline.cpython-311.pyc +0 -0
  34. osuT5/inference/__pycache__/pipeline.cpython-39.pyc +0 -0
  35. osuT5/inference/__pycache__/postprocessor.cpython-311.pyc +0 -0
  36. osuT5/inference/__pycache__/postprocessor.cpython-39.pyc +0 -0
  37. osuT5/inference/__pycache__/preprocessor.cpython-311.pyc +0 -0
  38. osuT5/inference/__pycache__/preprocessor.cpython-39.pyc +0 -0
  39. osuT5/inference/__pycache__/slider_path.cpython-311.pyc +0 -0
  40. osuT5/inference/__pycache__/slider_path.cpython-39.pyc +0 -0
  41. osuT5/inference/diffusion_pipeline.py +214 -0
  42. osuT5/inference/path_approximator.py +253 -0
  43. osuT5/inference/pipeline.py +338 -0
  44. osuT5/inference/postprocessor.py +322 -0
  45. osuT5/inference/preprocessor.py +58 -0
  46. osuT5/inference/slider_path.py +230 -0
  47. osuT5/inference/template.osu +54 -0
  48. osuT5/model/__init__.py +1 -0
  49. osuT5/model/__pycache__/__init__.cpython-311.pyc +0 -0
  50. osuT5/model/__pycache__/__init__.cpython-39.pyc +0 -0
.gitattributes CHANGED
@@ -1,35 +1,36 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ osuT5/inference/vale.mp3 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
checkpoint/custom_checkpoint_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0494fdd396142b4a2919c0ab913502c9335746959156a08e82c2647235e07853
3
+ size 564880
checkpoint/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a12b6c312590efbdf5d7acaff6d8537e8ad1728737eebb43d0a43d5a4b3b5a3a
3
+ size 377860126
configs/inference.yaml ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ name: 'google/t5-v1_1-small'
3
+ spectrogram:
4
+ sample_rate: 16000
5
+ hop_length: 128
6
+ n_fft: 1024
7
+ n_mels: 388
8
+ do_style_embed: false
9
+ input_features: false
10
+
11
+ model_path: './checkpoint'
12
+ audio_path: '' # Path to input audio
13
+ total_duration_ms: 0 # Total duration of audio in milliseconds, 0 for full audio
14
+ output_path: '' # Path to output directory
15
+ bpm: 120 # Beats per minute of input audio
16
+ offset: 0 # Start of beat, in milliseconds, from the beginning of input audio
17
+ resnap_objects: false # Resnap objects to beat timing ticks, requires accurate BPM and offset
18
+ slider_multiplier: 1.7 # Multiplier for slider velocity
19
+ title: '' # Song title
20
+ artist: '' # Song artist
21
+ beatmap_path: '' # Path to .osu file which will be remapped
22
+ other_beatmap_path: '' # Path to .osu file of other beatmap in the mapset to use as reference
23
+ beatmap_id: -1 # Beatmap ID to use as style
24
+ difficulty: -1 # Difficulty star rating to map
25
+ creator: '' # Beatmap creator
26
+ version: '' # Beatmap version
27
+ full_set: true # Generate full mapset
28
+ set_difficulties: 5 # Number of difficulties to generate.
29
+
30
+ # Diffusion settings
31
+ generate_positions: true # Use diffusion to generate object positions
32
+ diff_ckpt: './osudiffusion/DiT-B-0700000.pt' # Path to checkpoint for diffusion model
33
+ diff_refine_ckpt: '' # Path to checkpoint for refining diffusion model
34
+
35
+ diffusion:
36
+ style_id: 1451282 # Style ID to use for diffusion
37
+ num_sampling_steps: 100 # Number of sampling steps
38
+ cfg_scale: 1 # Scale of classifier-free guidance
39
+ num_classes: 52670 # Number of classes stored in the model
40
+ beatmap_idx: 'osudiffusion/beatmap_idx.pickle' # Path to beatmap index
41
+ use_amp: true # Use automatic mixed precision
42
+ refine_iters: 10 # Number of refinement iterations
43
+ seq_len: 128 # Sequence length
44
+ model: 'DiT-B' # Model architecture
45
+
46
+
47
+ data: # Data settings
48
+ src_seq_len: 640
49
+ tgt_seq_len: 480
50
+ sample_rate: ${model.spectrogram.sample_rate}
51
+ hop_length: ${model.spectrogram.hop_length}
52
+ sequence_stride: 1 # Fraction of audio sequence length to shift inference window
53
+ center_pad_decoder: false # Center pad decoder input
54
+ add_pre_tokens: true
55
+ special_token_len: 2
56
+ diff_token_index: 0
57
+ style_token_index: -1
58
+ max_pre_token_len: 4
59
+ add_gd_context: false # Prefix the decoder with tokens of another beatmap in the mapset
60
+
61
+ hydra:
62
+ job:
63
+ chdir: False
64
+ run:
65
+ dir: ./logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
configs/model/model.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ input_features: false
2
+ do_style_embed: true
3
+
4
+ spectrogram:
5
+ sample_rate: 16000
6
+ hop_length: 128
7
+ n_fft: 1024
8
+ n_mels: 388
configs/model/t5_base.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model
3
+ - _self_
4
+
5
+ name: 'google/t5-v1_1-base'
6
+ overwrite:
7
+ dropout_rate: 0.0
8
+
9
+ spectrogram:
10
+ sample_rate: 16000
11
+ hop_length: 128
12
+ n_fft: 1024
13
+ n_mels: 388
14
+ do_style_embed: false
15
+ input_features: false
configs/model/t5_small.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model
3
+ - _self_
4
+
5
+ name: 'google/t5-v1_1-small'
6
+ overwrite:
7
+ dropout_rate: 0.0
8
+
9
+ spectrogram:
10
+ n_mels: 512
configs/model/t5_small_v4.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model
3
+ - _self_
4
+
5
+ name: 'google/t5-v1_1-small'
6
+ overwrite:
7
+ dropout_rate: 0.0
configs/model/t5_small_v9.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model
3
+ - _self_
4
+
5
+ do_style_embed: false
6
+
7
+ name: 'google/t5-v1_1-small'
8
+ overwrite:
9
+ dropout_rate: 0.0
configs/model/whisper_base.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model
3
+ - _self_
4
+
5
+ name: 'openai/whisper-base'
6
+ input_features: true
inference.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import hydra
4
+ import torch
5
+ from omegaconf import DictConfig
6
+ from slider import Beatmap
7
+
8
+ from osudiffusion import DiT_models
9
+ from osuT5.inference import Preprocessor, Pipeline, Postprocessor, DiffisionPipeline
10
+ from osuT5.tokenizer import Tokenizer
11
+ from osuT5.utils import get_model
12
+
13
+
14
def get_args_from_beatmap(args: DictConfig):
    """Fill inference args in place from a reference .osu beatmap, if configured.

    Args:
        args: Hydra/OmegaConf config. `args.beatmap_path` selects the reference
            beatmap; fields still at their sentinel value (-1) are inherited.

    Raises:
        FileNotFoundError: If `args.beatmap_path` is set but does not exist.
    """
    if not args.beatmap_path:
        return

    beatmap_path = Path(args.beatmap_path)
    if not beatmap_path.is_file():
        raise FileNotFoundError(f"Beatmap file {beatmap_path} not found.")

    beatmap = Beatmap.from_path(beatmap_path)

    # Audio and output live next to the reference beatmap.
    args.audio_path = beatmap_path.parent / beatmap.audio_filename
    args.output_path = beatmap_path.parent

    # Timing metadata: use the earliest timing point as the beat offset.
    args.bpm = beatmap.bpm_max()
    args.offset = min(tp.offset.total_seconds() * 1000 for tp in beatmap.timing_points)
    args.slider_multiplier = beatmap.slider_multiplier

    # Song metadata.
    args.title = beatmap.title
    args.artist = beatmap.artist

    # -1 is the "unset" sentinel: inherit the value from the reference beatmap.
    if args.beatmap_id == -1:
        args.beatmap_id = beatmap.beatmap_id
    if args.diffusion.style_id == -1:
        args.diffusion.style_id = beatmap.beatmap_id
    if args.difficulty == -1:
        args.difficulty = float(beatmap.stars())
34
+
35
+
36
def find_model(ckpt_path, args: DictConfig, device):
    """Load a DiT diffusion model from a checkpoint file.

    Args:
        ckpt_path: Path to the DiT checkpoint (.pt file).
        args: Config; reads `args.diffusion.model` (architecture key) and
            `args.diffusion.num_classes`.
        device: Device to place the model on.

    Returns:
        The loaded model, in eval mode, on `device`.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    # Fix: raise instead of `assert` — asserts are stripped under `python -O`,
    # which would turn a missing checkpoint into a confusing torch.load error.
    if not Path(ckpt_path).exists():
        raise FileNotFoundError(f"Could not find DiT checkpoint at {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:  # supports checkpoints from train.py
        checkpoint = checkpoint["ema"]

    model = DiT_models[args.diffusion.model](
        num_classes=args.diffusion.num_classes,
        # 19 - 3 + 128 = 144; presumably 128-dim distance embedding plus the
        # one-hot event-type features built in events_to_sequence — TODO confirm.
        context_size=19 - 3 + 128,
    ).to(device)
    model.load_state_dict(checkpoint)
    model.eval()  # important! disables dropout/etc. for deterministic inference
    return model
49
+
50
+
51
@hydra.main(config_path="configs", config_name="inference", version_base="1.1")
def main(args: DictConfig):
    """Run the full osuT5 inference pipeline: audio -> events -> positions -> .osu.

    Loads the seq2seq model and tokenizer, segments the input audio, generates
    event sequences (one per difficulty when `args.full_set` is set), optionally
    assigns positions with the diffusion model, and writes the beatmap(s).
    """
    get_args_from_beatmap(args)

    torch.set_grad_enabled(False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt_path = Path(args.model_path)
    model_state = torch.load(ckpt_path / "pytorch_model.bin", map_location=device)
    tokenizer_state = torch.load(ckpt_path / "custom_checkpoint_0.pkl")

    tokenizer = Tokenizer()
    tokenizer.load_state_dict(tokenizer_state)

    model = get_model(args, tokenizer)
    model.load_state_dict(model_state)
    model.eval()
    model.to(device)

    preprocessor = Preprocessor(args)
    audio = preprocessor.load(args.audio_path)
    sequences = preprocessor.segment(audio)
    # Fix: derive the duration from the configured sample rate instead of a
    # hard-coded 16000, so it stays correct if the spectrogram config changes.
    args.total_duration_ms = len(audio) / args.model.spectrogram.sample_rate * 1000

    generated_maps = []
    generated_positions = []

    if args.full_set:
        # Spread star-rating targets evenly over [3, 7].
        # Fix: guard against ZeroDivisionError when set_difficulties == 1;
        # fall back to the midpoint difficulty in that case.
        if args.set_difficulties > 1:
            diffs = [
                3 + i * (7 - 3) / (args.set_difficulties - 1)
                for i in range(args.set_difficulties)
            ]
        else:
            diffs = [5.0]

        print(diffs)
        for diff in diffs:
            print(f"Generating difficulty {diff}")
            args.difficulty = diff
            pipeline = Pipeline(args, tokenizer)
            generated_maps.append(pipeline.generate(model, sequences))
    else:
        pipeline = Pipeline(args, tokenizer)
        generated_maps.append(pipeline.generate(model, sequences))

    if args.generate_positions:
        # Replace the seq2seq model reference with the diffusion model.
        model = find_model(args.diff_ckpt, args, device)
        refine_model = find_model(args.diff_refine_ckpt, args, device) if len(args.diff_refine_ckpt) > 0 else None
        diffusion_pipeline = DiffisionPipeline(args.diffusion)
        for events in generated_maps:
            generated_positions.append(diffusion_pipeline.generate(model, events, refine_model))
    else:
        generated_positions = generated_maps

    postprocessor = Postprocessor(args)
    postprocessor.generate(generated_positions)


if __name__ == "__main__":
    main()
osuT5/__init__.py ADDED
File without changes
osuT5/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (149 Bytes). View file
 
osuT5/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (131 Bytes). View file
 
osuT5/dataset/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .osu_parser import OsuParser
osuT5/dataset/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (277 Bytes). View file
 
osuT5/dataset/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (229 Bytes). View file
 
osuT5/dataset/__pycache__/data_utils.cpython-311.pyc ADDED
Binary file (3.77 kB). View file
 
osuT5/dataset/__pycache__/data_utils.cpython-39.pyc ADDED
Binary file (2.11 kB). View file
 
osuT5/dataset/__pycache__/ors_dataset.cpython-311.pyc ADDED
Binary file (30 kB). View file
 
osuT5/dataset/__pycache__/ors_dataset.cpython-39.pyc ADDED
Binary file (16.1 kB). View file
 
osuT5/dataset/__pycache__/osu_parser.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
osuT5/dataset/__pycache__/osu_parser.cpython-39.pyc ADDED
Binary file (6.51 kB). View file
 
osuT5/dataset/data_utils.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ from pydub import AudioSegment
6
+
7
+ import numpy.typing as npt
8
+
9
+ from osuT5.tokenizer import Event, EventType
10
+
11
+ MILISECONDS_PER_SECOND = 1000
12
+
13
+
14
def load_audio_file(file: Path, sample_rate: int) -> npt.NDArray:
    """Load an audio file as a numpy time-series array

    The signals are resampled, converted to mono channel, and normalized.

    Args:
        file: Path to audio file.
        sample_rate: Sample rate to resample the audio.

    Returns:
        samples: Audio time series as float32, peak-normalized to [-1, 1].
    """
    # NOTE(review): the decode format is hard-coded to mp3 even though any
    # path is accepted — confirm all inputs are mp3 or derive the format
    # from the file suffix.
    audio = AudioSegment.from_file(file, format="mp3")
    audio = audio.set_frame_rate(sample_rate)
    audio = audio.set_channels(1)
    samples = np.array(audio.get_array_of_samples()).astype(np.float32)
    # Fix: guard against division by zero for an all-silent input
    # (previously `samples *= 1.0 / np.max(np.abs(samples))` produced
    # a ZeroDivisionWarning / NaNs). Also removed a leftover debug print.
    peak = np.max(np.abs(samples))
    if peak > 0:
        samples *= 1.0 / peak
    return samples
33
+
34
+
35
def update_event_times(events: list[Event], event_times: list[float], end_time: Optional[float] = None):
    """Extend `event_times` in place so it holds one timestamp per event.

    Forward pass: every new event inherits the time of the most recent
    TIME_SHIFT event. Backward pass: slider anchor events (which carry no
    explicit time of their own) get times linearly interpolated between the
    surrounding timed events.

    Args:
        events: Full event list; only events at index >= len(event_times)
            are processed, so the function can be called incrementally.
        event_times: Timestamp list, updated in place.
        end_time: Optional time to interpolate towards for trailing anchors;
            defaults to the last known event time.
    """
    non_timed_events = [
        EventType.BEZIER_ANCHOR,
        EventType.PERFECT_ANCHOR,
        EventType.CATMULL_ANCHOR,
        EventType.RED_ANCHOR,
    ]
    timed_events = [
        EventType.CIRCLE,
        EventType.SPINNER,
        EventType.SPINNER_END,
        EventType.SLIDER_HEAD,
        EventType.LAST_ANCHOR,
        EventType.SLIDER_END,
    ]

    start_index = len(event_times)
    end_index = len(events)

    # Forward fill: carry the latest TIME_SHIFT value onto every event.
    ct = 0 if len(event_times) == 0 else event_times[-1]
    for i in range(start_index, end_index):
        event = events[i]
        if event.type == EventType.TIME_SHIFT:
            ct = event.value
        event_times.append(ct)

    # Fix: with no events at all, `event_times[-1]` below would raise
    # IndexError; there is nothing to interpolate, so return early.
    if len(event_times) == 0:
        return

    # Interpolate time for control point events
    # T-D-Start-D-CP-D-CP-T-D-LCP-T-D-End
    # 1-1-1-----1-1--1-1--7-7--7--9-9-9--
    # 1-1-1-----3-3--5-5--7-7--7--9-9-9--
    ct = end_time if end_time is not None else event_times[-1]
    interpolate = False
    for i in range(end_index - 1, start_index - 1, -1):
        event = events[i]

        if event.type in timed_events:
            interpolate = False

        if event.type in non_timed_events:
            interpolate = True

        if not interpolate:
            ct = event_times[i]
            continue

        if event.type not in non_timed_events:
            event_times[i] = ct
            continue

        # Find the time of the first timed event before i and the number of
        # control points in between (including this one).
        j = i
        count = 0
        t = ct
        while j >= 0:
            event2 = events[j]
            if event2.type == EventType.TIME_SHIFT:
                t = event_times[j]
                break
            if event2.type in non_timed_events:
                count += 1
            j -= 1
        # Fix: the original tested `i < 0`, which can never be true inside
        # this loop (i >= start_index >= 0). `j < 0` means the scan ran off
        # the start without finding a TIME_SHIFT, so fall back to t = 0.
        if j < 0:
            t = 0

        # Interpolate the time linearly between t and ct.
        ct = (ct - t) / (count + 1) * count + t
        event_times[i] = ct
osuT5/dataset/osu_parser.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from datetime import timedelta
4
+
5
+ import numpy as np
6
+ import numpy.typing as npt
7
+ from slider import Beatmap, Circle, Slider, Spinner
8
+ from slider.curve import Linear, Catmull, Perfect, MultiBezier
9
+
10
+ from osuT5.tokenizer import Event, EventType, Tokenizer
11
+
12
+
13
class OsuParser:
    """Converts `slider.Beatmap` hit objects into flat lists of osuT5 `Event` tokens."""

    def __init__(self, tokenizer: Tokenizer) -> None:
        # Distances are emitted as DISTANCE tokens, so cache the tokenizer's
        # valid range and clip every computed distance into it.
        dist_range = tokenizer.event_range[EventType.DISTANCE]
        self.dist_min = dist_range.min_value
        self.dist_max = dist_range.max_value

    def parse(self, beatmap: Beatmap) -> list[Event]:
        # noinspection PyUnresolvedReferences
        """Parse an .osu beatmap.

        Each hit object is parsed into a list of Event objects, in order of its
        appearance in the beatmap. In other words, in ascending order of time.

        Args:
            beatmap: Beatmap object parsed from an .osu file.

        Returns:
            events: List of Event object lists.

        Example::
            >>> beatmap = [
                "64,80,11000,1,0",
                "100,100,16000,2,0,B|200:200|250:200|250:200|300:150,2"
            ]
            >>> events = parse(beatmap)
            >>> print(events)
            [
                Event(EventType.TIME_SHIFT, 11000), Event(EventType.DISTANCE, 36), Event(EventType.CIRCLE),
                Event(EventType.TIME_SHIFT, 16000), Event(EventType.DISTANCE, 42), Event(EventType.SLIDER_HEAD),
                Event(EventType.TIME_SHIFT, 16500), Event(EventType.DISTANCE, 141), Event(EventType.BEZIER_ANCHOR),
                Event(EventType.TIME_SHIFT, 17000), Event(EventType.DISTANCE, 50), Event(EventType.BEZIER_ANCHOR),
                Event(EventType.TIME_SHIFT, 17500), Event(EventType.DISTANCE, 10), Event(EventType.BEZIER_ANCHOR),
                Event(EventType.TIME_SHIFT, 18000), Event(EventType.DISTANCE, 64), Event(EventType.LAST_ANCHOR),
                Event(EventType.TIME_SHIFT, 20000), Event(EventType.DISTANCE, 11), Event(EventType.SLIDER_END)
            ]
        """
        hit_objects = beatmap.hit_objects(stacking=False)
        # Start distances from the playfield centre (osu! playfield is 512x384).
        last_pos = np.array((256, 192))
        events = []

        for hit_object in hit_objects:
            if isinstance(hit_object, Circle):
                last_pos = self._parse_circle(hit_object, events, last_pos)
            elif isinstance(hit_object, Slider):
                last_pos = self._parse_slider(hit_object, events, last_pos)
            elif isinstance(hit_object, Spinner):
                last_pos = self._parse_spinner(hit_object, events)
            # Other hit-object kinds (if any) are silently skipped.

        return events

    def _clip_dist(self, dist: float) -> int:
        """Clip distance to valid range."""
        return int(np.clip(dist, self.dist_min, self.dist_max))

    def _parse_circle(self, circle: Circle, events: list[Event], last_pos: npt.NDArray) -> npt.NDArray:
        """Parse a circle hit object.

        Args:
            circle: Circle object.
            events: List of events to add to.
            last_pos: Last position of the hit objects.

        Returns:
            pos: Position of the circle.
        """
        time = int(circle.time.total_seconds() * 1000)
        pos = np.array(circle.position)
        dist = self._clip_dist(np.linalg.norm(pos - last_pos))

        events.append(Event(EventType.TIME_SHIFT, time))
        events.append(Event(EventType.DISTANCE, dist))
        if circle.new_combo:
            events.append(Event(EventType.NEW_COMBO))
        events.append(Event(EventType.CIRCLE))

        return pos

    def _parse_slider(self, slider: Slider, events: list[Event], last_pos: npt.NDArray) -> npt.NDArray:
        """Parse a slider hit object.

        Emits SLIDER_HEAD, typed anchor events for each control point, a
        LAST_ANCHOR for the final control point, and a SLIDER_END at the time
        the slider (including repeats) actually finishes.

        Args:
            slider: Slider object.
            events: List of events to add to.
            last_pos: Last position of the hit objects.

        Returns:
            pos: Last position of the slider.
        """
        # Ignore sliders which are too big
        if len(slider.curve.points) >= 100:
            return last_pos

        time = int(slider.time.total_seconds() * 1000)
        pos = np.array(slider.position)
        dist = self._clip_dist(np.linalg.norm(pos - last_pos))
        last_pos = pos

        events.append(Event(EventType.TIME_SHIFT, time))
        events.append(Event(EventType.DISTANCE, dist))
        if slider.new_combo:
            events.append(Event(EventType.NEW_COMBO))
        events.append(Event(EventType.SLIDER_HEAD))

        # Duration of a single pass; end_time spans all repeats.
        duration: timedelta = (slider.end_time - slider.time) / slider.repeat
        control_point_count = len(slider.curve.points)

        def append_control_points(event_type: EventType, last_pos: npt.NDArray = last_pos) -> npt.NDArray:
            # Emit every interior control point with the given anchor type.
            for i in range(1, control_point_count - 1):
                last_pos = add_anchor_time_dist(i, last_pos)
                events.append(Event(event_type))

            return last_pos

        def add_anchor_time_dist(i: int, last_pos: npt.NDArray) -> npt.NDArray:
            # Anchor times are spread evenly across the single-pass duration.
            time = int((slider.time + i / (control_point_count - 1) * duration).total_seconds() * 1000)
            pos = np.array(slider.curve.points[i])
            dist = self._clip_dist(np.linalg.norm(pos - last_pos))
            last_pos = pos

            events.append(Event(EventType.TIME_SHIFT, time))
            events.append(Event(EventType.DISTANCE, dist))

            return last_pos

        if isinstance(slider.curve, Linear):
            last_pos = append_control_points(EventType.RED_ANCHOR, last_pos)
        elif isinstance(slider.curve, Catmull):
            last_pos = append_control_points(EventType.CATMULL_ANCHOR, last_pos)
        elif isinstance(slider.curve, Perfect):
            last_pos = append_control_points(EventType.PERFECT_ANCHOR, last_pos)
        elif isinstance(slider.curve, MultiBezier):
            # Duplicated consecutive points mark red (sharp) anchors in .osu
            # bezier notation; the duplicate itself is skipped by the elif.
            for i in range(1, control_point_count - 1):
                if slider.curve.points[i] == slider.curve.points[i + 1]:
                    last_pos = add_anchor_time_dist(i, last_pos)
                    events.append(Event(EventType.RED_ANCHOR))
                elif slider.curve.points[i] != slider.curve.points[i - 1]:
                    last_pos = add_anchor_time_dist(i, last_pos)
                    events.append(Event(EventType.BEZIER_ANCHOR))

        last_pos = add_anchor_time_dist(control_point_count - 1, last_pos)
        events.append(Event(EventType.LAST_ANCHOR))

        # Slider end: time includes repeats; position is the curve endpoint.
        time = int(slider.end_time.total_seconds() * 1000)
        pos = np.array(slider.curve(1))
        dist = self._clip_dist(np.linalg.norm(pos - last_pos))
        last_pos = pos

        events.append(Event(EventType.TIME_SHIFT, time))
        events.append(Event(EventType.DISTANCE, dist))
        events.append(Event(EventType.SLIDER_END))

        return last_pos

    def _parse_spinner(self, spinner: Spinner, events: list[Event]) -> npt.NDArray:
        """Parse a spinner hit object.

        Args:
            spinner: Spinner object.
            events: List of events to add to.

        Returns:
            pos: Last position of the spinner.
        """
        time = int(spinner.time.total_seconds() * 1000)
        events.append(Event(EventType.TIME_SHIFT, time))
        events.append(Event(EventType.SPINNER))

        time = int(spinner.end_time.total_seconds() * 1000)
        events.append(Event(EventType.TIME_SHIFT, time))
        events.append(Event(EventType.SPINNER_END))

        # Spinners have no position; reset to the playfield centre.
        return np.array((256, 192))
osuT5/inference/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .pipeline import *
2
+ from .preprocessor import *
3
+ from .postprocessor import *
4
+ from .diffusion_pipeline import *
osuT5/inference/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (301 Bytes). View file
 
osuT5/inference/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (245 Bytes). View file
 
osuT5/inference/__pycache__/diffusion_pipeline.cpython-311.pyc ADDED
Binary file (11.5 kB). View file
 
osuT5/inference/__pycache__/path_approximator.cpython-311.pyc ADDED
Binary file (11.5 kB). View file
 
osuT5/inference/__pycache__/path_approximator.cpython-39.pyc ADDED
Binary file (5.02 kB). View file
 
osuT5/inference/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (23.4 kB). View file
 
osuT5/inference/__pycache__/pipeline.cpython-39.pyc ADDED
Binary file (11.8 kB). View file
 
osuT5/inference/__pycache__/postprocessor.cpython-311.pyc ADDED
Binary file (17 kB). View file
 
osuT5/inference/__pycache__/postprocessor.cpython-39.pyc ADDED
Binary file (8.08 kB). View file
 
osuT5/inference/__pycache__/preprocessor.cpython-311.pyc ADDED
Binary file (3.37 kB). View file
 
osuT5/inference/__pycache__/preprocessor.cpython-39.pyc ADDED
Binary file (2.23 kB). View file
 
osuT5/inference/__pycache__/slider_path.cpython-311.pyc ADDED
Binary file (10.9 kB). View file
 
osuT5/inference/__pycache__/slider_path.cpython-39.pyc ADDED
Binary file (5.16 kB). View file
 
osuT5/inference/diffusion_pipeline.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from omegaconf import DictConfig
6
+ from tqdm import tqdm
7
+
8
+ from osudiffusion import timestep_embedding
9
+ from osudiffusion import repeat_type
10
+ from osudiffusion import create_diffusion
11
+ from osudiffusion import DiT
12
+ from osuT5.dataset.data_utils import update_event_times
13
+ from osuT5.tokenizer import Event, EventType
14
+
15
+
16
def get_beatmap_idx(path) -> dict[int, int]:
    """Load the pickled beatmap-ID -> class-index mapping from `path`."""
    with Path(path).open("rb") as handle:
        return pickle.load(handle)
21
+
22
+
23
class DiffisionPipeline(object):
    # NOTE: class name keeps the original (misspelled) identifier because it
    # is imported elsewhere (osuT5.inference, inference.py).
    def __init__(self, args: DictConfig):
        """Model inference stage that generates positions for distance events.

        Args:
            args: The `diffusion` sub-config; reads num_sampling_steps,
                cfg_scale, seq_len, num_classes, beatmap_idx, style_id,
                refine_iters and use_amp.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_sampling_steps = args.num_sampling_steps
        self.cfg_scale = args.cfg_scale
        self.seq_len = args.seq_len
        self.num_classes = args.num_classes
        self.beatmap_idx = get_beatmap_idx(args.beatmap_idx)
        self.style_id = args.style_id
        self.refine_iters = args.refine_iters
        # NOTE(review): use_amp is stored but not referenced anywhere in this
        # class — confirm whether AMP was meant to wrap the sampling loop.
        self.use_amp = args.use_amp

        # Map the requested style beatmap ID to its class index; index
        # num_classes is the "null"/default class used by CFG below.
        if self.style_id in self.beatmap_idx:
            self.class_label = self.beatmap_idx[self.style_id]
        else:
            print(f"Beatmap ID {self.style_id} not found in dataset, using default style.")
            self.class_label = self.num_classes

    def generate(self, model: DiT, events: list[Event], refine_model: DiT = None) -> list[Event]:
        """Generate position events for distance events in the Event list.

        Args:
            model: Trained model to use for inference.
            events: List of Event objects with distance events.
            refine_model: Optional model to refine the generated positions.

        Returns:
            events: List of Event objects with position events.
        """

        seq_o, seq_c, seq_len, seq_indices = self.events_to_sequence(events)

        seq_o = seq_o - seq_o[0]  # Normalize to relative time
        print(f"seq len {seq_len}")

        diffusion = create_diffusion(
            str(self.num_sampling_steps),
            noise_schedule="squaredcos_cap_v2",
        )

        # Create banded matrix attention mask for increased sequence length:
        # each position only attends within +/- self.seq_len of itself.
        attn_mask = torch.full((seq_len, seq_len), True, dtype=torch.bool, device=self.device)
        for i in range(seq_len):
            attn_mask[max(0, i - self.seq_len): min(seq_len, i + self.seq_len), i] = False

        class_labels = [self.class_label]

        # Create sampling noise: z has one (x, y) channel pair per event.
        n = len(class_labels)
        z = torch.randn(n, 2, seq_len, device=self.device)
        o = seq_o.repeat(n, 1).to(self.device)
        c = seq_c.repeat(n, 1, 1).to(self.device)
        y = torch.tensor(class_labels, device=self.device)

        # Setup classifier-free guidance: duplicate the batch, second half
        # conditioned on the null class (index num_classes).
        z = torch.cat([z, z], 0)
        o = torch.cat([o, o], 0)
        c = torch.cat([c, c], 0)
        y_null = torch.tensor([self.num_classes] * n, device=self.device)
        y = torch.cat([y, y_null], 0)
        model_kwargs = dict(o=o, c=c, y=y, cfg_scale=self.cfg_scale, attn_mask=attn_mask)

        def to_positions(samples):
            # Drop the CFG null-class half and scale from normalized
            # coordinates back to the osu! playfield (512 x 384).
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
            samples *= torch.tensor((512, 384), device=self.device).repeat(n, 1).unsqueeze(2)
            return samples.cpu()

        # Sample images:
        samples = diffusion.p_sample_loop(
            model.forward_with_cfg,
            z.shape,
            z,
            clip_denoised=True,
            model_kwargs=model_kwargs,
            progress=True,
            device=self.device,
        )

        if refine_model is not None:
            # Refine result with refine model
            # NOTE(review): refine_model is only used as a guard here — the
            # refinement steps still call model.forward_with_cfg; confirm
            # whether refine_model.forward_with_cfg was intended.
            for _ in tqdm(range(self.refine_iters)):
                t = torch.tensor([0] * samples.shape[0], device=self.device)
                with torch.no_grad():
                    out = diffusion.p_sample(
                        model.forward_with_cfg,
                        samples,
                        t,
                        clip_denoised=True,
                        model_kwargs=model_kwargs,
                    )
                samples = out["sample"]

        positions = to_positions(samples)
        return self.events_with_pos(events, positions.squeeze(0), seq_indices)

    @staticmethod
    def events_to_sequence(events: list[Event]) -> tuple[torch.Tensor, torch.Tensor, int, dict[int, int]]:
        """Vectorize an Event list into the osu-diffusion input format.

        Returns:
            seq_o: (L,) tensor of event times in milliseconds.
            seq_c: (144, L) tensor of context features — a 128-dim timestep
                embedding of the distance plus 16 one-hot type features.
            seq_len: The sequence length L (number of hit-object events).
            seq_indices: Maps each index in `events` to its column in the
                vectorized sequence.
        """
        # Calculate the time of every event and interpolate time for control point events
        event_times = []
        update_event_times(events, event_times)

        # Calculate the number of repeats for each slider end event
        # Convert to vectorized form for osu-diffusion
        nc_types = [EventType.CIRCLE, EventType.SLIDER_HEAD]
        # Offsets into the one-hot type features; NC (new combo) variants of
        # circle/slider head occupy the following index (+1).
        event_index = {
            EventType.CIRCLE: 0,
            EventType.SPINNER: 2,
            EventType.SPINNER_END: 3,
            EventType.SLIDER_HEAD: 4,
            EventType.BEZIER_ANCHOR: 6,
            EventType.PERFECT_ANCHOR: 7,
            EventType.CATMULL_ANCHOR: 8,
            EventType.RED_ANCHOR: 9,
            EventType.LAST_ANCHOR: 10,
            EventType.SLIDER_END: 11,
        }

        seq_indices = {}
        indices = []
        data_chunks = []
        distance = 0
        new_combo = False
        head_time = 0
        last_anchor_time = 0
        for i, event in enumerate(events):
            indices.append(i)
            if event.type == EventType.DISTANCE:
                # DISTANCE/NEW_COMBO are modifiers carried onto the next
                # hit-object event rather than emitted as their own column.
                distance = event.value
            elif event.type == EventType.NEW_COMBO:
                new_combo = True
            elif event.type in event_index:
                time = event_times[i]
                index = event_index[event.type]

                # Handle NC index offset
                if event.type in nc_types and new_combo:
                    index += 1
                    new_combo = False

                # Add slider end repeats index offset
                if event.type == EventType.SLIDER_END:
                    span_duration = last_anchor_time - head_time
                    total_duration = time - head_time
                    repeats = max(int(round(total_duration / span_duration)), 1) if span_duration > 0 else 1
                    index += repeat_type(repeats)
                elif event.type == EventType.SLIDER_HEAD:
                    head_time = time
                elif event.type == EventType.LAST_ANCHOR:
                    last_anchor_time = time

                # features = [time, distance, one-hot type (16 slots)]
                features = torch.zeros(18)
                features[0] = time
                features[1] = distance
                features[index + 2] = 1
                data_chunks.append(features)

                # All events since the previous hit object map to this column.
                for j in indices:
                    seq_indices[j] = len(data_chunks) - 1
                indices = []

        seq = torch.stack(data_chunks, 0)
        seq = torch.swapaxes(seq, 0, 1)
        seq_o = seq[0, :]
        seq_d = seq[1, :]
        # Embed the scalar distance into 128 dims and stack with type one-hots.
        seq_c = torch.concatenate(
            [
                timestep_embedding(seq_d, 128).T,
                seq[2:, :],
            ],
            0,
        )

        return seq_o, seq_c, seq.shape[1], seq_indices


    @staticmethod
    def events_with_pos(events: list[Event], sampled_seq: torch.Tensor, seq_indices: dict[int, int]) -> list[Event]:
        """Replace each DISTANCE event with POS_X/POS_Y events sampled by the model.

        Args:
            events: Original event list containing DISTANCE events.
            sampled_seq: (2, L) tensor of sampled x/y playfield coordinates.
            seq_indices: Event index -> sequence column mapping from
                events_to_sequence.

        Returns:
            New event list with positions substituted for distances.
        """
        new_events = []
        for i, event in enumerate(events):
            if event.type == EventType.DISTANCE:
                try:
                    index = seq_indices[i]
                    pos_x = sampled_seq[0, index].item()
                    pos_y = sampled_seq[1, index].item()
                    new_events.append(Event(EventType.POS_X, int(round(pos_x))))
                    new_events.append(Event(EventType.POS_Y, int(round(pos_y))))
                except KeyError:
                    print(f"Warning: Key {i} not found in seq_indices. Skipping event.")
            else:
                new_events.append(event)
        return new_events
osuT5/inference/path_approximator.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
# Maximum allowed deviation when flattening a Bezier curve into line segments.
BEZIER_TOLERANCE = 0.25
# Number of line segments generated per Catmull-Rom control-point span.
CATMULL_DETAIL = 50
# Maximum chord error allowed when approximating a circular arc.
CIRCULAR_ARC_TOLERANCE = 0.1


def length_squared(x):
    """Squared Euclidean length of vector ``x`` (avoids a sqrt).

    Replaces the original ``lambda`` assignment (PEP 8 E731: use ``def``
    for named functions); behaviour is identical.
    """
    return np.inner(x, x)
9
+
10
+
11
def approximate_bezier(control_points: np.ndarray) -> np.ndarray:
    """Piecewise-linear approximation of a Bezier curve.

    Delegates to ``approximate_b_spline`` with its default ``p=0``, which
    treats the whole control-point set as a single Bezier of degree
    ``len(control_points) - 1``.
    """
    return approximate_b_spline(control_points)
13
+
14
+
15
def approximate_b_spline(control_points: np.ndarray, p: int = 0) -> np.ndarray:
    """Piecewise-linear approximation of a B-spline of degree ``p``.

    Adaptive-subdivision approach: the spline is first converted into a chain
    of Bezier segments, each of which is recursively subdivided until flat
    enough, then flattened into output points.

    Args:
        control_points: (n+1, 2) array of control points.
        p: Spline degree; when ``p <= 0`` or ``p >= n`` the whole input is
            treated as a single Bezier of degree n.

    Returns:
        (m, 2) array of points along the curve.
        NOTE(review): for empty input this returns a plain empty list, not an
        ndarray — confirm callers tolerate that.
    """
    output = []
    n = len(control_points) - 1

    if n < 0:
        return output

    to_flatten = []
    free_buffers = []

    # Work on a copy: the Bezier extraction below mutates the points in place.
    points = control_points.copy()

    if 0 < p < n:
        # Subdivide the B-spline into Bezier segments of degree p.
        for i in range(n - p):
            sub_bezier = np.empty((p + 1, 2))
            sub_bezier[0] = points[i]

            for j in range(p - 1):
                sub_bezier[j + 1] = points[i + 1]

                for k in range(1, p - j):
                    l = np.min((k, n - p - i))
                    points[i + k] = (l * points[i + k] + points[i + k + 1]) / (l + 1)

            sub_bezier[p] = points[i + 1]
            to_flatten.append(sub_bezier)

        to_flatten.append(points[(n - p) :])
        # Reverse so segments are popped from the stack in curve order below.
        to_flatten.reverse()
    else:
        # Degenerate degree: treat the entire point set as one Bezier.
        p = n
        to_flatten.append(points)

    # Scratch buffers shared by every subdivision below.
    subdivision_buffer1 = np.empty([p + 1, 2])
    subdivision_buffer2 = np.empty([p * 2 + 1, 2])

    left_child = subdivision_buffer2

    # Adaptive flattening: pop a segment; emit it if flat enough, otherwise
    # split it in two and push both halves back.
    while len(to_flatten) > 0:
        parent = to_flatten.pop()

        if bezier_is_flat_enough(parent):
            bezier_approximate(
                parent,
                output,
                subdivision_buffer1,
                subdivision_buffer2,
                p + 1,
            )

            free_buffers.append(parent)
            continue

        # Reuse a previously freed buffer when possible to avoid reallocating.
        right_child = (
            free_buffers.pop() if len(free_buffers) > 0 else np.empty([p + 1, 2])
        )
        bezier_subdivide(parent, left_child, right_child, subdivision_buffer1, p + 1)

        # Copy the left half back into `parent` so it can be reprocessed.
        for i in range(p + 1):
            parent[i] = left_child[i]

        to_flatten.append(right_child)
        to_flatten.append(parent)

    # Always include the final control point exactly.
    output.append(control_points[n].copy())
    return np.vstack(output)
81
+
82
+
83
def approximate_catmull(control_points: np.ndarray) -> list[np.ndarray]:
    """Piecewise-linear approximation of a Catmull-Rom spline through the points.

    Each consecutive pair of control points becomes one spline segment sampled
    at ``CATMULL_DETAIL`` steps; both ends of every sub-step are emitted, so
    the result contains ``2 * CATMULL_DETAIL`` points per segment.
    """
    pieces = []
    count = len(control_points)

    for seg in range(count - 1):
        # Build the four-point window for this segment, mirroring at the ends.
        p1 = control_points[seg - 1] if seg > 0 else control_points[seg]
        p2 = control_points[seg]
        if seg < count - 1:
            p3 = control_points[seg + 1]
        else:
            p3 = p2 + p2 - p1
        p4 = control_points[seg + 2] if seg < count - 2 else p3 + p3 - p2

        for step in range(CATMULL_DETAIL):
            pieces.append(catmull_find_point(p1, p2, p3, p4, step / CATMULL_DETAIL))
            pieces.append(catmull_find_point(p1, p2, p3, p4, (step + 1) / CATMULL_DETAIL))

    return pieces
97
+
98
+
99
def approximate_circular_arc(control_points: np.ndarray) -> list[np.ndarray]:
    """Approximate the circular arc through three control points with line segments.

    Computes the circumcircle from barycentric weights, determines arc
    direction from which side the middle point lies on, and samples enough
    points to keep the chord error below ``CIRCULAR_ARC_TOLERANCE``.

    Args:
        control_points: (3, 2) array: start, pass-through, end point.

    Returns:
        List of (2,) position arrays along the arc; empty list when the three
        points are collinear/degenerate and no unique circle exists.
    """
    a = control_points[0]
    b = control_points[1]
    c = control_points[2]

    # Squared side lengths of the triangle (opposite each vertex).
    a_sq = length_squared(b - c)
    b_sq = length_squared(a - c)
    c_sq = length_squared(a - b)

    # Two coincident points: degenerate, no circle.
    if np.isclose(a_sq, 0) or np.isclose(b_sq, 0) or np.isclose(c_sq, 0):
        return []

    s = a_sq * (b_sq + c_sq - a_sq)
    t = b_sq * (a_sq + c_sq - b_sq)
    u = c_sq * (a_sq + b_sq - c_sq)

    # Renamed from `sum`: the original shadowed the builtin.
    weight_sum = s + t + u

    # Collinear points: barycentric weights cancel out.
    if np.isclose(weight_sum, 0):
        return []

    centre = (s * a + t * b + u * c) / weight_sum
    d_a = a - centre
    d_c = c - centre

    r = np.linalg.norm(d_a)

    theta_start = np.arctan2(d_a[1], d_a[0])
    theta_end = np.arctan2(d_c[1], d_c[0])

    while theta_end < theta_start:
        theta_end += 2 * np.pi

    direction = 1
    # Bug fix: original line read `theta_range = theta_range = theta_end - theta_start`
    # (a duplicated, if harmless, assignment).
    theta_range = theta_end - theta_start

    # Vector orthogonal to a->c; its dot product with a->b tells which side
    # the middle point is on, i.e. whether to take the major arc instead.
    ortho_a_to_c = c - a
    ortho_a_to_c = np.array([ortho_a_to_c[1], -ortho_a_to_c[0]])
    if np.dot(ortho_a_to_c, b - a) < 0:
        direction = -direction
        theta_range = 2 * np.pi - theta_range

    # Segment count chosen so the sagitta stays within tolerance.
    amount_points = (
        2
        if 2 * r <= CIRCULAR_ARC_TOLERANCE
        else int(
            max(
                2,
                np.ceil(theta_range / (2 * np.arccos(1 - CIRCULAR_ARC_TOLERANCE / r))),
            )
        )
    )

    output = []

    for i in range(amount_points):
        fract = i / (amount_points - 1)
        theta = theta_start + direction * fract * theta_range
        o = np.array([np.cos(theta), np.sin(theta)]) * r
        output.append(centre + o)

    return output
161
+
162
+
163
def approximate_linear(control_points: np.ndarray) -> list[np.ndarray]:
    """Return the control points themselves (copied) — a linear path needs no refinement."""
    return [point.copy() for point in control_points]
170
+
171
+
172
def bezier_is_flat_enough(control_points: np.ndarray) -> bool:
    """True when the curve's second differences are all within BEZIER_TOLERANCE.

    Uses the discrete second derivative at every interior control point as a
    curvature proxy; large values mean the segment still needs subdividing.
    """
    threshold = BEZIER_TOLERANCE * BEZIER_TOLERANCE * 4
    return all(
        length_squared(control_points[i - 1] - 2 * control_points[i] + control_points[i + 1]) <= threshold
        for i in range(1, len(control_points) - 1)
    )
179
+
180
+
181
def bezier_subdivide(
    control_points: np.ndarray,
    left: np.ndarray,
    right: np.ndarray,
    subdivision_buffer: np.ndarray,
    count: int,
) -> None:
    """Split a Bezier into two halves at t=0.5 (de Casteljau), writing into
    ``left`` and ``right`` in place.

    Args:
        control_points: (count, 2) control points of the parent curve.
        left: Output buffer for the first half's control points.
        right: Output buffer for the second half's control points.
        subdivision_buffer: Scratch buffer, overwritten; may alias other
            caller buffers (see ``bezier_approximate``).
        count: Number of control points (degree + 1).
    """
    # `midpoints` aliases the scratch buffer and is mutated in place.
    midpoints = subdivision_buffer

    for i in range(count):
        midpoints[i] = control_points[i]

    # Classic de Casteljau at t=0.5: each pass averages adjacent points.
    # After pass i, midpoints[0] is the i-th left control point and
    # midpoints[count-i-1] the matching right one.
    for i in range(count):
        left[i] = midpoints[0].copy()
        right[count - i - 1] = midpoints[count - i - 1]

        for j in range(count - i - 1):
            midpoints[j] = (midpoints[j] + midpoints[j + 1]) / 2
200
+
201
def bezier_approximate(
    control_points: np.ndarray,
    output: list[np.ndarray],
    subdivision_buffer1: np.ndarray,
    subdivision_buffer2: np.ndarray,
    count: int,
) -> None:
    """Append a piecewise-linear approximation of a flat-enough Bezier to ``output``.

    Subdivides once, joins both halves, then emits smoothed points from the
    combined control polygon. The final curve point is appended by the caller
    (``approximate_b_spline``), not here.

    Args:
        control_points: (count, 2) control points of the (flat) curve.
        output: List to append resulting points to.
        subdivision_buffer1: Scratch buffer of length count.
        subdivision_buffer2: Scratch buffer of length 2*count - 1.
        count: Number of control points (degree + 1).
    """
    left = subdivision_buffer2
    right = subdivision_buffer1

    # NOTE(review): `right` and the scratch argument are the same array here;
    # this aliasing appears intentional (the scratch copy is consumed before
    # `right` is written) — mirrors the upstream osu!framework implementation.
    bezier_subdivide(control_points, left, right, subdivision_buffer1, count)

    # Concatenate the two halves into `left` (total length 2*count - 1;
    # right[0] duplicates left[count-1] and is skipped).
    for i in range(count - 1):
        left[count + i] = right[i + 1]

    output.append(control_points[0].copy())

    # Emit smoothed interior points: weighted average of neighbours.
    for i in range(1, count - 1):
        index = 2 * i
        p = 0.25 * (left[index - 1] + 2 * left[index] + left[index + 1])
        output.append(p.copy())
222
+
223
+
224
def catmull_find_point(
    vec1: np.ndarray,
    vec2: np.ndarray,
    vec3: np.ndarray,
    vec4: np.ndarray,
    t: float,
) -> np.ndarray:
    """Evaluate the Catmull-Rom spline defined by four points at parameter ``t``.

    ``t = 0`` yields ``vec2`` and ``t = 1`` yields ``vec3``; ``vec1``/``vec4``
    act as tangent controls. Computed vectorized over both components instead
    of per-coordinate — same arithmetic, same result.
    """
    t2 = t * t
    t3 = t * t2

    return 0.5 * (
        2 * vec2
        + (-vec1 + vec3) * t
        + (2 * vec1 - 5 * vec2 + 4 * vec3 - vec4) * t2
        + (-vec1 + 3 * vec2 - 3 * vec3 + vec4) * t3
    )
osuT5/inference/pipeline.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from slider import Beatmap
8
+ from tqdm import tqdm
9
+
10
+ from omegaconf import DictConfig
11
+
12
+ from osuT5.dataset import OsuParser
13
+ from osuT5.dataset.data_utils import update_event_times
14
+ from osuT5.tokenizer import Event, EventType, Tokenizer
15
+ from osuT5.model import OsuT
16
+
17
+ MILISECONDS_PER_SECOND = 1000
18
+ MILISECONDS_PER_STEP = 10
19
+
20
def top_k_sampling(logits, k):
    """Sample one token id per row from the ``k`` highest-scoring entries of ``logits``.

    Softmax is applied over the selected top-k scores only; the returned
    tensor holds indices into the original vocabulary dimension.
    """
    values, indices = torch.topk(logits, k)
    choice = torch.multinomial(F.softmax(values, dim=-1), 1)
    return indices.gather(-1, choice)
26
+
27
def preprocess_event(event, frame_time):
    """Convert a TIME_SHIFT event's absolute time into steps relative to ``frame_time``.

    Non-TIME_SHIFT events are returned unchanged.
    """
    if event.type != EventType.TIME_SHIFT:
        return event
    relative_steps = int((event.value - frame_time) / MILISECONDS_PER_STEP)
    return Event(type=event.type, value=relative_steps)
31
+
32
+ class Pipeline(object):
33
    def __init__(self, args: DictConfig, tokenizer: Tokenizer):
        """Model inference stage that processes sequences.

        Args:
            args: Hydra/OmegaConf configuration; reads ``data``, ``model.spectrogram``
                and top-level generation settings (bpm, offset, beatmap_id, ...).
            tokenizer: Token <-> Event codec shared with training.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = tokenizer
        self.tgt_seq_len = args.data.tgt_seq_len
        # One frame is reserved relative to src_seq_len (matches Preprocessor).
        self.frame_seq_len = args.data.src_seq_len - 1
        self.frame_size = args.model.spectrogram.hop_length
        self.sample_rate = args.model.spectrogram.sample_rate
        self.samples_per_sequence = self.frame_seq_len * self.frame_size
        self.sequence_stride = int(self.samples_per_sequence * args.data.sequence_stride)
        # Window/stride lengths converted from samples to milliseconds.
        self.miliseconds_per_sequence = self.samples_per_sequence * MILISECONDS_PER_SECOND / self.sample_rate
        self.miliseconds_per_stride = self.sequence_stride * MILISECONDS_PER_SECOND / self.sample_rate
        self.beatmap_id = args.beatmap_id
        self.difficulty = args.difficulty
        self.center_pad_decoder = args.data.center_pad_decoder
        self.special_token_len = args.data.special_token_len
        self.diff_token_index = args.data.diff_token_index
        self.style_token_index = args.data.style_token_index
        self.max_pre_token_len = args.data.max_pre_token_len
        self.add_pre_tokens = args.data.add_pre_tokens
        self.add_gd_context = args.data.add_gd_context
        self.bpm = args.bpm
        self.offset = args.offset
        self.total_duration_ms = args.total_duration_ms

        print(f"Configuration: {args}")

        # Optionally parse another difficulty of the same set as extra context.
        if self.add_gd_context:
            other_beatmap_path = Path(args.other_beatmap_path)

            if not other_beatmap_path.is_file():
                raise FileNotFoundError(f"Beatmap file {other_beatmap_path} not found.")

            other_beatmap = Beatmap.from_path(other_beatmap_path)
            self.other_beatmap_id = other_beatmap.beatmap_id
            # Star rating is used as the difficulty conditioning value.
            self.other_difficulty = float(other_beatmap.stars())
            parser = OsuParser(tokenizer)
            self.other_events = parser.parse(other_beatmap)
            self.other_events, self.other_event_times = self._prepare_events(self.other_events)
73
+ def _calculate_time_shifts(self, bpm: float, duration_ms: float, tick_rate: int, offset: float = 0) -> list[float]:
74
+ """Calculate EventType.TIME_SHIFT events based on song's BPM and tick rate."""
75
+ events = []
76
+ ms_per_beat = 60000 / bpm # 60000 ms per minute
77
+ ms_per_tick = ms_per_beat / tick_rate
78
+ num_ticks = int(duration_ms // ms_per_tick)
79
+
80
+ for i in range(num_ticks):
81
+ events.append(float(int(i * ms_per_tick + offset)) )
82
+
83
+ return events
84
+
85
+ def generate_events(self, model, frames, tokens, encoder_outputs, beatmap_idx, total_steps):
86
+ temperature = 0.9
87
+ k = 10 # top-k sampling
88
+
89
+ for _ in range(total_steps):
90
+ out = model.forward(
91
+ frames=frames,
92
+ decoder_input_ids=tokens,
93
+ decoder_attention_mask=tokens.ne(self.tokenizer.pad_id),
94
+ encoder_outputs=encoder_outputs,
95
+ beatmap_idx=beatmap_idx,
96
+ )
97
+ encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions)
98
+ logits = out.logits
99
+ logits = logits[:, -1, :] / temperature
100
+ logits = self._filter(logits, 0.9)
101
+ probabilities = F.softmax(logits, dim=-1)
102
+ next_tokens = top_k_sampling(probabilities, k)
103
+
104
+ tokens = torch.cat([tokens, next_tokens], dim=-1)
105
+
106
+ eos_in_sentence = next_tokens == self.tokenizer.eos_id
107
+ if eos_in_sentence.all():
108
+ break
109
+
110
+ return tokens
111
+
112
    def generate(self, model: OsuT, sequences: torch.Tensor, top_k: int = 50) -> list[Event]:
        """
        Generate a list of Event objects given source audio sequences.

        Slides over the audio windows, conditioning each window's decoding on
        style/difficulty tokens plus previously generated events, with up to
        five resampling retries per window.

        Args:
            model: Trained model to use for inference.
            sequences: A list of batched source sequences.
            top_k: Number of top tokens to use for top-k sampling.
                NOTE(review): currently unused — the inner loop hard-codes
                top_k=60 in the `_filter` call.

        Returns:
            events: Flat list of generated Event objects (absolute times).
        """
        events = []
        event_times = []
        temperature = 0.95

        # Style/difficulty conditioning tokens; fall back to "unknown" tokens
        # for beatmap ids / difficulties outside the training vocabulary.
        idx_dict = self.tokenizer.beatmap_idx
        beatmap_idx = torch.tensor([idx_dict.get(self.beatmap_id, 6666)], dtype=torch.long, device=self.device)
        style_token = self.tokenizer.encode_style(self.beatmap_id) if self.beatmap_id in idx_dict else self.tokenizer.style_unk
        diff_token = self.tokenizer.encode_diff(self.difficulty) if self.difficulty != -1 else self.tokenizer.diff_unk

        special_tokens = torch.empty((1, self.special_token_len), dtype=torch.long, device=self.device)
        special_tokens[:, self.diff_token_index] = diff_token
        special_tokens[:, self.style_token_index] = style_token

        # Optional second-difficulty ("gd") context tokens.
        # NOTE(review): other_special_tokens is built but never concatenated
        # into the decoder input below — confirm whether this is intentional.
        if self.add_gd_context:
            other_style_token = self.tokenizer.encode_style(self.other_beatmap_id) if self.other_beatmap_id in idx_dict else self.tokenizer.style_unk
            other_special_tokens = torch.empty((1, self.special_token_len), dtype=torch.long, device=self.device)
            other_special_tokens[:, self.diff_token_index] = self.tokenizer.encode_diff(self.other_difficulty)
            other_special_tokens[:, self.style_token_index] = other_style_token
        else:
            other_special_tokens = torch.empty((1, 0), dtype=torch.long, device=self.device)

        for sequence_index, frames in enumerate(tqdm(sequences)):
            # Events already generated in the previous (overlapping) window
            # become decoder context for this window.
            frame_time = sequence_index * self.miliseconds_per_stride
            prev_events = self._get_events_time_range(
                events, event_times, frame_time - self.miliseconds_per_sequence, frame_time) if self.add_pre_tokens else []
            post_events = self._get_events_time_range(
                events, event_times, frame_time, frame_time + self.miliseconds_per_sequence)

            prev_tokens = self._encode(prev_events, frame_time)
            post_tokens = self._encode(post_events, frame_time)
            post_token_length = post_tokens.shape[1]

            # Cap the amount of pre-context, keeping the most recent tokens.
            if 0 <= self.max_pre_token_len < prev_tokens.shape[1]:
                prev_tokens = prev_tokens[:, -self.max_pre_token_len:]

            # Prefix = special conditioning tokens + previous-window tokens,
            # optionally left-padded so decoding starts mid-sequence.
            prefix = torch.cat([special_tokens, prev_tokens], dim=-1)
            if self.center_pad_decoder:
                prefix = F.pad(prefix, (self.tgt_seq_len // 2 - prefix.shape[1], 0), value=self.tokenizer.pad_id)
            prefix_length = prefix.shape[1]


            max_retries = 5
            attempt = 0
            result = []

            # Retry sampling until a usable result is produced (or retries run out).
            while attempt < max_retries and not result:
                attempt += 1
                try:
                    # Reset tokens: [prefix | SOS | tokens from the overlap region]
                    tokens = torch.tensor([[self.tokenizer.sos_id]], dtype=torch.long, device=self.device)
                    tokens = torch.cat([prefix, tokens, post_tokens], dim=-1)

                    # Ensure frames are properly reset for each retry
                    retry_frames = frames.clone().to(self.device).unsqueeze(0)
                    encoder_outputs = None

                    while tokens.shape[-1] < self.tgt_seq_len:
                        out = model.forward(
                            frames=retry_frames,
                            decoder_input_ids=tokens,
                            decoder_attention_mask=tokens.ne(self.tokenizer.pad_id),
                            encoder_outputs=encoder_outputs,
                            #beatmap_idx=beatmap_idx,
                        )
                        # Cache encoder outputs; only the decoder reruns per step.
                        encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions)

                        # Temperature + nucleus/top-k filtering, then sample.
                        logits = out.logits[:, -1, :]
                        logits = logits / temperature
                        logits = self._filter(logits, top_p=0.9, top_k=60)
                        probabilities = F.softmax(logits, dim=-1)
                        next_tokens = torch.multinomial(probabilities, 1)

                        tokens = torch.cat([tokens, next_tokens], dim=-1)

                        eos_in_sentence = next_tokens == self.tokenizer.eos_id
                        if eos_in_sentence.all():
                            break

                    # Strip prefix, SOS and overlap tokens; keep only new ones.
                    predicted_tokens = tokens[:, prefix_length + 1 + post_token_length:]
                    result = self._decode(predicted_tokens[0], frame_time)

                    # Heuristic sanity check: a long result without any new
                    # combo is considered degenerate and triggers a retry.
                    if len(result) > 10 and not any(event.type == EventType.NEW_COMBO for event in result):
                        #print("No new combo in result; retrying...")
                        result = []


                except Exception as e:
                    # NOTE(review): this broad except silently converts any
                    # runtime error into a retry; consider logging `e`.
                    #print(f"Attempt {attempt} encountered an error: {e}")
                    result = []  # Ensure result is empty to trigger retry

            events += result

            self._update_event_times(events, event_times, frame_time)

        return events
223
+
224
    def _prepare_events(self, events: list[Event]) -> tuple[list[Event], list[float]]:
        """Pre-process raw list of events for inference.

        Builds a parallel list of absolute times (each event inherits the most
        recent TIME_SHIFT value) and removes the TIME_SHIFT immediately
        preceding any slider anchor event, mutating ``events`` in place.
        """
        ct = 0
        event_times = []
        for event in events:
            if event.type == EventType.TIME_SHIFT:
                ct = event.value
            event_times.append(ct)

        # Loop through the events in reverse to remove any time shifts that occur before anchor events
        delete_next_time_shift = False
        for i in range(len(events) - 1, -1, -1):
            if events[i].type == EventType.TIME_SHIFT and delete_next_time_shift:
                delete_next_time_shift = False
                # Reverse iteration makes in-place deletion by index safe.
                del events[i]
                del event_times[i]
                continue
            elif events[i].type in [EventType.BEZIER_ANCHOR, EventType.PERFECT_ANCHOR, EventType.CATMULL_ANCHOR,
                                    EventType.RED_ANCHOR]:
                delete_next_time_shift = True



        return events, event_times
249
+
250
+ def _get_events_time_range(self, events: list[Event], event_times: list[float], start_time: float, end_time: float):
251
+ # Look from the end of the list
252
+ s = 0
253
+ for i in range(len(event_times) - 1, -1, -1):
254
+ if event_times[i] < start_time:
255
+ s = i + 1
256
+ break
257
+ e = 0
258
+ for i in range(len(event_times) - 1, -1, -1):
259
+ if event_times[i] < end_time:
260
+ e = i + 1
261
+ break
262
+ return events[s:e]
263
+
264
    def _update_event_times(self, events: list[Event], event_times: list[float], frame_time: float):
        """Recompute ``event_times`` in place for ``events``, clamped to the end of the current window.

        Thin wrapper over the shared dataset helper ``update_event_times``.
        """
        update_event_times(events, event_times, frame_time + self.miliseconds_per_sequence)
266
+
267
+
268
    def _encode(self, events: list[Event], frame_time: float) -> torch.Tensor:
        """Tokenize ``events``, converting TIME_SHIFT values to steps relative to ``frame_time``.

        Returns a (1, len(events)) long tensor on ``self.device``. On any
        encoding failure the whole batch is dropped and an empty (1, 0)
        tensor is returned — a deliberate best-effort fallback.
        """
        try:

            tokens = torch.empty((1, len(events)), dtype=torch.long)
            for i, event in enumerate(events):
                if event.type == EventType.TIME_SHIFT:
                    # Absolute ms -> relative step count for this window.
                    event = Event(type=event.type, value=int((event.value - frame_time) / MILISECONDS_PER_STEP))
                tokens[0, i] = self.tokenizer.encode(event)
            return tokens.to(self.device)
        except Exception as e:
            # NOTE(review): broad except silently discards all events of the
            # window (presumably out-of-vocabulary tokens); consider logging.
            #print(f"Error encoding events: {events}")
            #print(e)
            return torch.empty((1, 0), dtype=torch.long, device=self.device)
281
+ def _decode(self, tokens: torch.Tensor, frame_time: float) -> list[Event]:
282
+ """Converts a list of tokens into Event objects and converts to absolute time values.
283
+
284
+ Args:
285
+ tokens: List of tokens.
286
+ frame time: Start time of current source sequence.
287
+
288
+ Returns:
289
+ events: List of Event objects.
290
+ """
291
+ events = []
292
+ for token in tokens:
293
+ if token == self.tokenizer.eos_id:
294
+ break
295
+
296
+ try:
297
+ event = self.tokenizer.decode(token.item())
298
+ except:
299
+ continue
300
+
301
+ if event.type == EventType.TIME_SHIFT:
302
+ event.value = frame_time + event.value * MILISECONDS_PER_STEP
303
+
304
+ events.append(event)
305
+
306
+ return events
307
+
308
    def _filter(self, logits: torch.Tensor, top_p: float = 0.75, top_k: int = 1, filter_value: float = -float("Inf")) -> torch.Tensor:
        """Filter a distribution of logits using nucleus (top-p) and/or top-k filtering.

        Args:
            logits: (batch, vocab) logits; NOTE the top-p branch mutates this
                tensor in place before returning it.
            top_p: Keep the smallest token set whose cumulative probability
                exceeds ``top_p`` (disabled outside (0, 1)).
            top_k: Keep only the k highest logits (disabled when <= 0).
                NOTE(review): the default of 1 makes sampling greedy unless
                callers pass a larger value — confirm intended.
            filter_value: Value assigned to removed logits.
        """
        logits = top_k_logits(logits, top_k) if top_k > 0 else logits

        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Mask tokens beyond the nucleus; shift right so the first token
            # crossing the threshold is always kept.
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Map the mask back from sorted order to vocabulary order.
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value

        return logits
325
def top_k_logits(logits, k):
    """
    Keep only the top-k tokens with highest probabilities.

    Args:
        logits: Logits distribution of shape (batch size, vocabulary size).
        k: Number of top tokens to keep.

    Returns:
        A new tensor with every logit below the k-th largest (per row)
        replaced by negative infinity; ties with the threshold are kept.
    """
    kth_best = torch.topk(logits, k).values[:, -1].unsqueeze(-1)
    below_threshold = logits < kth_best.expand_as(logits)
    return torch.where(below_threshold, torch.full_like(logits, float("-Inf")), logits)
osuT5/inference/postprocessor.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import os
5
+ import pathlib
6
+ import uuid
7
+ from string import Template
8
+ import zipfile
9
+ import numpy as np
10
+ from omegaconf import DictConfig
11
+ import time as t
12
+ from osuT5.inference.slider_path import SliderPath
13
+ from osuT5.tokenizer import Event, EventType
14
+
15
+ OSU_FILE_EXTENSION = ".osu"
16
+ OSU_TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template.osu")
17
+ STEPS_PER_MILLISECOND = 0.1
18
+
19
+
20
@dataclasses.dataclass
class BeatmapConfig:
    """Values substituted into the .osu file template (see template.osu)."""

    # General
    audio_filename: str = ""  # filename of the audio file, relative to the .osu

    # Metadata
    title: str = ""
    title_unicode: str = ""
    artist: str = ""
    artist_unicode: str = ""
    creator: str = ""
    version: str = ""  # difficulty name

    # Difficulty
    hp_drain_rate: float = 5
    circle_size: float = 4
    overall_difficulty: float = 8
    approach_rate: float = 9
    # Base slider velocity; slider timing points scale relative to this.
    slider_multiplier: float = 1.8
39
+
40
+
41
def calculate_coordinates(last_pos, dist, num_samples, playfield_size):
    """Sample candidate positions at distance ``dist`` around ``last_pos`` that stay on the playfield.

    Args:
        last_pos: (x, y) of the previous object.
        dist: Desired distance from ``last_pos``.
        num_samples: Number of angles to sample on the circle.
        playfield_size: (width, height) bounds of the playfield.

    Returns:
        List of in-bounds (x, y) tuples; if none fit, a single fallback corner
        (far corner or origin, whichever is away from the current position).
    """
    angles = np.linspace(0, 2 * np.pi, num_samples)
    xs = last_pos[0] + dist * np.cos(angles)
    ys = last_pos[1] + dist * np.sin(angles)

    in_bounds = [
        (px, py)
        for px, py in zip(xs, ys)
        if 0 <= px <= playfield_size[0] and 0 <= py <= playfield_size[1]
    ]
    if in_bounds:
        return in_bounds

    # Nothing fits: bounce toward whichever corner is away from the current side.
    midpoint = (playfield_size[0] + playfield_size[1]) / 2
    return [playfield_size] if last_pos[0] + last_pos[1] > midpoint else [(0, 0)]
59
+
60
+
61
def position_to_progress(slider_path: SliderPath, pos: np.ndarray) -> np.ndarray:
    """Find the path progress t in [0, 1] whose position is closest to ``pos``.

    Performs up to 100 steps of finite-difference gradient descent on the
    distance between ``slider_path.position_at(t)`` and ``pos``, starting
    from the path end (t=1).

    NOTE(review): the local ``t`` shadows the module-level ``import time as t``
    alias inside this function — intentional but easy to misread.
    """
    eps = 1e-4   # finite-difference step for the gradient estimate
    lr = 1       # descent step size
    t = 1
    for i in range(100):
        # Approximate d(distance)/dt with a backward difference.
        grad = np.linalg.norm(slider_path.position_at(t) - pos) - np.linalg.norm(
            slider_path.position_at(t - eps) - pos,
        )
        t -= lr * grad

        # Stop on a flat gradient or when stepping outside [0, 1].
        if grad == 0 or t < 0 or t > 1:
            break

    return np.clip(t, 0, 1)
75
+
76
+
77
+
78
def quantize_to_beat(time, bpm, offset):
    """Snap ``time`` (ms) to the nearest 1/2-beat tick of the ``bpm`` grid anchored at ``offset``."""
    tick_rate = 0.5  # snapping resolution in beats (1/2 beat)
    beat_ms = 1000 / (bpm / 60.0)
    tick_ms = beat_ms * tick_rate
    return round((time - offset) / tick_ms) * tick_ms + offset
92
+
93
def quantize_to_beat_again(time, bpm, offset):
    """Snap ``time`` (ms) to the nearest 1/4-beat tick of the ``bpm`` grid anchored at ``offset``.

    Finer-grained sibling of ``quantize_to_beat`` (1/4 beat vs 1/2 beat).
    """
    tick_rate = 0.25  # snapping resolution in beats (1/4 beat)
    beat_ms = 1000 / (bpm / 60.0)
    tick_ms = beat_ms * tick_rate
    return round((time - offset) / tick_ms) * tick_ms + offset
107
+
108
def move_to_next_tick(time, bpm):
    """Advance ``time`` (ms) by one 1/4-beat tick of the ``bpm`` grid."""
    quarter_beat_ms = (1000 / (bpm / 60.0)) * 0.25
    return time + quarter_beat_ms
116
+
117
def move_to_prev_tick(time, bpm):
    """Move ``time`` (ms) back by one 1/4-beat tick of the ``bpm`` grid.

    Docstring fix: the original said "Move to the next tick" (copy-paste from
    ``move_to_next_tick``) while the code subtracts one tick.
    """
    tick_rate = 0.25
    beats_per_minute = bpm
    beats_per_second = beats_per_minute / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    quantized_time = time - milliseconds_per_beat * tick_rate
    return quantized_time
125
+
126
def adjust_hit_objects(hit_objects, bpm, offset):
    """Adjust the timing of hit objects to align with beats based on BPM and offset.

    Each TIME_SHIFT is snapped to the 1/2-beat grid; if two consecutive events
    would land on the same (integer) tick, the later one is pushed to the next
    1/4-beat tick — unless the previous event is a slider tail marker
    (LAST_ANCHOR / SLIDER_END), where coinciding times are legitimate.

    Args:
        hit_objects: Iterable of Event objects.
        bpm: Beats per minute of the timing grid.
        offset: Grid anchor in milliseconds.

    Returns:
        New list of events with quantized TIME_SHIFT values.
    """
    adjusted_hit_objects = []
    adjusted_times = []
    for hit_object in hit_objects:
        if hit_object.type == EventType.TIME_SHIFT:
            time = quantize_to_beat(hit_object.value, bpm, offset)

            # Bug fix: the original compared against
            # `(EventType.LAST_ANCHOR or EventType.SLIDER_END)`, which
            # short-circuits to LAST_ANCHOR alone and never exempted
            # SLIDER_END. Use a proper membership test.
            if (
                adjusted_times
                and int(time) == adjusted_times[-1]
                and adjusted_hit_objects[-1].type not in (EventType.LAST_ANCHOR, EventType.SLIDER_END)
            ):
                time = move_to_next_tick(time, bpm)

            adjusted_hit_objects.append(Event(EventType.TIME_SHIFT, time))
            adjusted_times.append(int(time))
        else:
            adjusted_hit_objects.append(hit_object)

    return adjusted_hit_objects
151
+
152
+
153
+
154
+ class Postprocessor(object):
155
    def __init__(self, args: DictConfig):
        """Postprocessing stage that converts a list of Event objects to a beatmap file.

        Args:
            args: Configuration providing audio/output paths, metadata
                (title, artist, creator, version), bpm, offset,
                slider_multiplier and the resnap_objects flag.
        """
        # .osu curve-type letters -> SliderPath curve names.
        self.curve_type_shorthand = {
            "B": "Bezier",
            "P": "PerfectCurve",
            "C": "Catmull",
        }

        self.output_path = args.output_path
        self.audio_path = args.audio_path
        # Audio file name without its (first) extension; used in osz/title naming.
        self.audio_filename = pathlib.Path(args.audio_path).name.split(".")[0]
        self.beatmap_config = BeatmapConfig(
            title=str(f"{self.audio_filename} ({args.title})"),
            artist=str(args.artist),
            title_unicode=str(args.title),
            artist_unicode=str(args.artist),
            audio_filename=pathlib.Path(args.audio_path).name,
            slider_multiplier=float(args.slider_multiplier),
            creator=str(args.creator),
            version=str(args.version),
        )
        self.offset = args.offset
        # Milliseconds per beat for the single uninherited timing point.
        self.beat_length = 60000 / args.bpm
        self.slider_multiplier = self.beatmap_config.slider_multiplier
        self.bpm = args.bpm
        # When set, TIME_SHIFT events are re-snapped to the beat grid.
        self.resnap_objects = args.resnap_objects
182
    def generate(self, generated_positions: list[Event]):
        """Generate a beatmap mapset (.osz) from generated event lists.

        Args:
            generated_positions: List of event lists; each inner list becomes
                one .osu difficulty in the output .osz archive.

        Returns:
            None. An .osz file containing the .osu file(s) and the audio is
            written to ``self.output_path``.
        """
        processed_events = []

        for events in generated_positions:
            # adjust hit objects to align with the beat grid
            if self.resnap_objects:
                events = adjust_hit_objects(events, self.bpm, self.offset)


            # State machine over the event stream: positional/timing events
            # update the cursor; object events flush hit-object strings.
            hit_object_strings = []
            time = 0
            dist = 0
            x = 256
            y = 192
            has_pos = False       # True once explicit POS_X/POS_Y was seen
            new_combo = 0         # 4 == new-combo bit in the .osu object type
            ho_info = []          # pending slider/spinner head data
            anchor_info = []      # pending slider anchors as (curve, x, y)

            # Uninherited timing point defining bpm/meter for the whole map.
            timing_point_strings = [
                f"{self.offset},{self.beat_length},4,2,0,100,1,0"
            ]

            for event in events:
                hit_type = event.type

                if hit_type == EventType.TIME_SHIFT:
                    time = event.value
                    continue
                elif hit_type == EventType.DISTANCE:
                    # Find a point which is dist away from the last point but still within the playfield
                    dist = event.value
                    coordinates = calculate_coordinates((x, y), dist, 500, (512, 384))
                    pos = coordinates[np.random.randint(len(coordinates))]
                    x, y = pos
                    continue
                elif hit_type == EventType.POS_X:
                    x = event.value
                    has_pos = True
                    continue
                elif hit_type == EventType.POS_Y:
                    y = event.value
                    has_pos = True
                    continue
                elif hit_type == EventType.NEW_COMBO:
                    new_combo = 4
                    continue

                if hit_type == EventType.CIRCLE:
                    # type 1 = circle; new_combo OR'd into the type bitfield.
                    hit_object_strings.append(f"{int(round(x))},{int(round(y))},{int(round(time))},{1 | new_combo},0")
                    ho_info = []

                elif hit_type == EventType.SPINNER:
                    ho_info = [time, new_combo]

                elif hit_type == EventType.SPINNER_END and len(ho_info) == 2:
                    # type 8 = spinner, always centred on the playfield.
                    hit_object_strings.append(
                        f"{256},{192},{int(round(ho_info[0]))},{8 | ho_info[1]},0,{int(round(time))}"
                    )
                    ho_info = []

                elif hit_type == EventType.SLIDER_HEAD:
                    ho_info = [x, y, time, new_combo]
                    anchor_info = []

                elif hit_type == EventType.BEZIER_ANCHOR:
                    anchor_info.append(('B', x, y))

                elif hit_type == EventType.PERFECT_ANCHOR:
                    anchor_info.append(('P', x, y))

                elif hit_type == EventType.CATMULL_ANCHOR:
                    anchor_info.append(('C', x, y))

                elif hit_type == EventType.RED_ANCHOR:
                    # Red anchors are encoded as a duplicated Bezier point.
                    anchor_info.append(('B', x, y))
                    anchor_info.append(('B', x, y))

                elif hit_type == EventType.LAST_ANCHOR:
                    # Records the single-span end time (ho_info[4]).
                    ho_info.append(time)
                    anchor_info.append(('B', x, y))

                elif hit_type == EventType.SLIDER_END and len(ho_info) == 5 and len(anchor_info) > 0:
                    curve_type = anchor_info[0][0]
                    span_duration = ho_info[4] - ho_info[2]
                    total_duration = time - ho_info[2]

                    if total_duration == 0 or span_duration == 0:
                        continue

                    # Repeat count = total duration / single-span duration.
                    slides = max(int(round(total_duration / span_duration)), 1)
                    control_points = "|".join(f"{int(round(cp[1]))}:{int(round(cp[2]))}" for cp in anchor_info)
                    slider_path = SliderPath(self.curve_type_shorthand[curve_type], np.array([(ho_info[0], ho_info[1])] + [(cp[1], cp[2]) for cp in anchor_info], dtype=float))
                    length = slider_path.get_distance()

                    # Trim path length to the generated end position when one
                    # was given explicitly; otherwise fall back to distance.
                    req_length = length * position_to_progress(
                        slider_path,
                        np.array((x, y)),
                    ) if has_pos else length - dist

                    if req_length < 1e-4:
                        continue

                    hit_object_strings.append(
                        f"{int(round(ho_info[0]))},{int(round(ho_info[1]))},{int(round(ho_info[2]))},{2 | ho_info[3]},0,{curve_type}|{control_points},{slides},{req_length}"
                    )

                    # Inherited timing point sets slider velocity so the slider
                    # lasts exactly span_duration.
                    sv = span_duration / req_length / self.beat_length * self.slider_multiplier * -10000
                    timing_point_strings.append(
                        f"{int(round(ho_info[2]))},{sv},4,2,0,100,0,0"
                    )

                new_combo = 0

            # Write .osu file from the template with all placeholders filled.
            with open(OSU_TEMPLATE_PATH, "r") as tf:
                template = Template(tf.read())
                hit_objects = {"hit_objects": "\n".join(hit_object_strings)}
                timing_points = {"timing_points": "\n".join(timing_point_strings)}
                beatmap_config = dataclasses.asdict(self.beatmap_config)
                result = template.safe_substitute({**beatmap_config, **hit_objects, **timing_points})
                processed_events.append(result)

        # Bundle all difficulties plus the audio into a timestamped .osz.
        osz_path = os.path.join(self.output_path, f"{self.audio_filename}_{t.time()}.osz")
        with zipfile.ZipFile(osz_path, "w") as z:
            for i, event in enumerate(processed_events):
                osu_path = os.path.join(self.output_path, f"{i}{OSU_FILE_EXTENSION}")
                with open(osu_path, "w") as osu_file:
                    osu_file.write(event)
                z.write(osu_path, os.path.basename(osu_path))
            z.write(self.audio_path, os.path.basename(self.audio_path))
            print(f"Mapset saved {osz_path}")
            # NOTE(review): close() is redundant — the `with` block already
            # closes the archive on exit.
            z.close()
osuT5/inference/preprocessor.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+ from omegaconf import DictConfig
9
+
10
+ from osuT5.dataset.data_utils import load_audio_file
11
+
12
+
13
class Preprocessor(object):
    """Preprocess raw audio into fixed-length, overlapping model input sequences."""

    def __init__(self, args: DictConfig):
        """Read sequence/frame geometry from the config.

        Args:
            args: Config providing ``data.src_seq_len``, ``data.hop_length``,
                ``data.sample_rate`` and ``data.sequence_stride`` (stride as a
                fraction of the sequence length).
        """
        # One frame of the source sequence is reserved (src_seq_len - 1) —
        # presumably for a special token added downstream; TODO confirm.
        self.frame_seq_len = args.data.src_seq_len - 1
        self.frame_size = args.data.hop_length
        self.sample_rate = args.data.sample_rate
        self.samples_per_sequence = self.frame_seq_len * self.frame_size
        self.sequence_stride = int(self.samples_per_sequence * args.data.sequence_stride)

    def load(self, path: Path) -> npt.ArrayLike:
        """Load an audio file as audio frames. Convert stereo to mono, normalize.

        Args:
            path: Path to audio file.

        Returns:
            samples: Audio time-series.
        """
        return load_audio_file(path, self.sample_rate)

    def segment(self, samples: npt.ArrayLike) -> torch.Tensor:
        """Segment audio samples into overlapping sequences of flattened frames.

        Fixes two defects of the previous padding formula:

        * audio shorter than one sequence is padded up to exactly one full
          sequence instead of producing a negative window count (which made
          the strided view blow up), and
        * audio whose tail already aligns with the stride no longer receives
          a full extra stride of silence (which produced a spurious trailing
          all-zero sequence).

        Args:
            samples: Audio time-series.

        Returns:
            sequences: Tensor of shape (num sequences, samples per sequence),
                dtype float32.
        """
        excess = len(samples) - self.samples_per_sequence
        if excess < 0:
            # Too short for even one sequence: pad to exactly one.
            pad_amount = -excess
        else:
            # Pad only what is needed to make the tail stride-aligned
            # (0 when it already is).
            pad_amount = -excess % self.sequence_stride
        if pad_amount:
            samples = np.pad(samples, [0, pad_amount])
        sequences = self.window(samples, self.samples_per_sequence, self.sequence_stride)
        return torch.from_numpy(sequences).to(torch.float32)

    @staticmethod
    def window(a, w, o, copy=False):
        """Return a strided view of all length-``w`` windows of ``a`` spaced ``o`` apart.

        Args:
            a: 1-D numpy array.
            w: Window length in samples.
            o: Offset (stride) between consecutive windows, in samples.
            copy: When True, materialize the view into its own memory.
        """
        sh = (a.size - w + 1, w)
        st = a.strides * 2
        # Build every unit-shifted window, then keep each o-th one.
        view = np.lib.stride_tricks.as_strided(a, strides=st, shape=sh)[0::o]
        if copy:
            return view.copy()
        else:
            return view
osuT5/inference/slider_path.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import numpy as np
4
+ from numpy.linalg import norm
5
+
6
+ import osuT5.inference.path_approximator as path_approximator
7
+
8
+
9
def binary_search(array, target):
    """Locate ``target`` in a sorted sequence.

    Returns:
        The index of an element equal to ``target`` if present; otherwise the
        bitwise complement (``~i``) of the insertion point — the same
        convention as .NET's ``Array.BinarySearch``, which callers decode
        with ``i = ~i`` when ``i < 0``.
    """
    from bisect import bisect_left  # stdlib replaces the hand-rolled loop

    i = bisect_left(array, target)
    if i < len(array) and array[i] == target:
        return i
    return ~i
24
+
25
+
26
class SliderPath:
    """Piecewise-linear approximation of an osu! slider path.

    Converts a curve type plus control points into a flattened polyline
    (``calculated_path``) with per-vertex arc lengths (``cumulative_length``),
    optionally trimmed/extended to ``expected_distance``.
    """

    # __slots__: these are the only attributes instances carry.
    __slots__ = (
        "control_points",
        "path_type",
        "expected_distance",
        "calculated_path",
        "cumulative_length",
        "is_initialised",
    )

    def __init__(
        self,
        path_type: str,
        control_points: np.array,
        expected_distance: float | None = None,
    ) -> None:
        """Store the path description and eagerly compute the polyline.

        Args:
            path_type: One of "Linear", "PerfectCurve", "Catmull"; anything
                else falls back to Bezier (see calculate_subpath).
            control_points: Array of 2-D control points.
            expected_distance: Target pixel length to trim/extend to, or None
                to keep the natural curve length.
        """
        self.control_points = control_points
        self.path_type = path_type
        self.expected_distance = expected_distance

        self.calculated_path = None
        self.cumulative_length = None

        self.is_initialised = None

        # Compute immediately; later accessors still guard via ensure_initialised.
        self.ensure_initialised()

    def get_control_points(self) -> np.array:
        """Return the control points, initialising the path first if needed."""
        self.ensure_initialised()
        return self.control_points

    def get_distance(self) -> float:
        """Return the total arc length of the computed path (0 if empty)."""
        self.ensure_initialised()
        return 0 if len(self.cumulative_length) == 0 else self.cumulative_length[-1]

    def get_path_to_progress(self, path, p0, p1) -> None:
        """Fill ``path`` (cleared first) with the vertices between progress p0 and p1.

        Endpoints are interpolated so the sub-path starts/ends exactly at the
        requested progress values; interior polyline vertices are copied as-is.
        """
        self.ensure_initialised()

        d0 = self.progress_to_distance(p0)
        d1 = self.progress_to_distance(p1)

        path.clear()

        # Skip vertices strictly before the start distance.
        i = 0
        while i < len(self.calculated_path) and self.cumulative_length[i] < d0:
            i += 1

        path.append(self.interpolate_vertices(i, d0))

        # Copy vertices up to the end distance, then append the interpolated end.
        while i < len(self.calculated_path) and self.cumulative_length[i] < d1:
            path.append(self.calculated_path[i])
            i += 1

        path.append(self.interpolate_vertices(i, d1))

    def position_at(self, progress) -> np.array:
        """Return the 2-D position at ``progress`` in [0, 1] along the path."""
        self.ensure_initialised()

        d = self.progress_to_distance(progress)
        return self.interpolate_vertices(self.index_of_distance(d), d)

    def ensure_initialised(self) -> None:
        """Compute the polyline and cumulative lengths exactly once."""
        if self.is_initialised:
            return
        self.is_initialised = True

        self.control_points = [] if self.control_points is None else self.control_points
        self.calculated_path = []
        self.cumulative_length = []

        self.calculate_path()
        self.calculate_cumulative_length()

    def calculate_subpath(self, sub_control_points) -> list:
        """Approximate one control-point span with the approximator for path_type.

        "PerfectCurve" only applies when the whole path is a single 3-point
        arc; degenerate arcs (empty approximation) fall back to Bezier, as
        does any unrecognised path type.
        """
        if self.path_type == "Linear":
            return path_approximator.approximate_linear(sub_control_points)
        elif self.path_type == "PerfectCurve":
            if len(self.get_control_points()) != 3 or len(sub_control_points) != 3:
                return path_approximator.approximate_bezier(sub_control_points)

            subpath = path_approximator.approximate_circular_arc(sub_control_points)

            # Degenerate arc (e.g. collinear points): approximator returns nothing.
            if len(subpath) == 0:
                return path_approximator.approximate_bezier(sub_control_points)

            return subpath
        elif self.path_type == "Catmull":
            return path_approximator.approximate_catmull(sub_control_points)
        else:
            return path_approximator.approximate_bezier(sub_control_points)

    def calculate_path(self) -> None:
        """Flatten all control-point spans into ``calculated_path``.

        A repeated control point (p[i] == p[i+1]) marks the end of a span,
        osu!-style; consecutive duplicate output vertices are dropped.
        """
        self.calculated_path.clear()

        start = 0
        end = 0

        for i in range(len(self.get_control_points())):
            end += 1

            # Span boundary: last control point, or a doubled control point.
            if (
                i == len(self.get_control_points()) - 1
                or (
                    self.get_control_points()[i] == self.get_control_points()[i + 1]
                ).all()
            ):
                cp_span = self.get_control_points()[start:end]

                for t in self.calculate_subpath(cp_span):
                    # Skip vertices identical to the previous one.
                    if (
                        len(self.calculated_path) == 0
                        or (self.calculated_path[-1] != t).any()
                    ):
                        self.calculated_path.append(t)

                start = end

    def calculate_cumulative_length(self) -> None:
        """Build per-vertex arc lengths, trimming/extending to expected_distance.

        If the natural path overshoots expected_distance, the overshooting
        segment is shortened and the remaining vertices are discarded.  If it
        falls short, the final segment is stretched along its direction to
        reach expected_distance.
        """
        length = 0

        self.cumulative_length.clear()
        self.cumulative_length.append(length)

        for i in range(len(self.calculated_path) - 1):
            diff = self.calculated_path[i + 1] - self.calculated_path[i]
            d = norm(diff)

            # This segment would overshoot the expected distance: cut it short.
            if (
                self.expected_distance is not None
                and self.expected_distance - length < d
            ):
                self.calculated_path[i + 1] = (
                    self.calculated_path[i]
                    + diff * (self.expected_distance - length) / d
                )
                # NOTE(review): intended to drop everything past the trimmed
                # vertex, but the slice end `len - 2 - i` looks off for large i
                # (it can leave trailing vertices) — verify against the
                # original osu! implementation.
                del self.calculated_path[i + 2 : len(self.calculated_path) - 2 - i]

                length = self.expected_distance
                self.cumulative_length.append(length)
                break

            length += d
            self.cumulative_length.append(length)

        # Natural path is shorter than expected: stretch the last segment.
        if (
            self.expected_distance is not None
            and length < self.expected_distance
            and len(self.calculated_path) > 1
        ):
            diff = self.calculated_path[-1] - self.calculated_path[-2]
            d = norm(diff)

            # Zero-length final segment: no direction to extend along.
            if d <= 0:
                return

            self.calculated_path[-1] += (
                diff * (self.expected_distance - self.cumulative_length[-1]) / d
            )
            self.cumulative_length[-1] = self.expected_distance

    def index_of_distance(self, d) -> int:
        """Return the vertex index for distance ``d``.

        binary_search returns ~insertion_point for absent values, so a
        negative result is decoded with ``~`` back to the insertion index.
        """
        i = binary_search(self.cumulative_length, d)
        if i < 0:
            i = ~i

        return i

    def progress_to_distance(self, progress) -> float:
        """Map progress in [0, 1] (clamped) to an arc-length distance."""
        return np.clip(progress, 0, 1) * self.get_distance()

    def interpolate_vertices(self, i, d) -> np.array:
        """Linearly interpolate the position at distance ``d`` around vertex ``i``.

        Out-of-range indices clamp to the path endpoints; coincident
        cumulative lengths (zero-length segment) return the earlier vertex
        to avoid dividing by zero.
        """
        if len(self.calculated_path) == 0:
            return np.zeros([2])

        if i <= 0:
            return self.calculated_path[0]
        if i >= len(self.calculated_path):
            return self.calculated_path[-1]

        p0 = self.calculated_path[i - 1]
        p1 = self.calculated_path[i]

        d0 = self.cumulative_length[i - 1]
        d1 = self.cumulative_length[i]

        if np.isclose(d0, d1):
            return p0

        w = (d - d0) / (d1 - d0)
        return p0 + (p1 - p0) * w
216
+
217
+
218
if __name__ == "__main__":
    # Visual smoke test: flatten a sample Bezier slider and plot its polyline.
    anchors = 100 * np.array([[0, 0], [1, 1], [1, -1], [2, 0], [2, 0], [3, -1], [2, -2]])
    demo_path = SliderPath("Bezier", anchors)
    points = np.vstack(demo_path.calculated_path)
    logging.info(points.shape)

    import matplotlib.pyplot as plt

    plt.axis("equal")
    plt.plot(points[:, 0], points[:, 1], color="green")
    plt.show()
osuT5/inference/template.osu ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ osu file format v14
2
+
3
+ [General]
4
+ AudioFilename: $audio_filename
5
+ AudioLeadIn: 0
6
+ PreviewTime: -1
7
+ Countdown: 0
8
+ SampleSet: Soft
9
+ StackLeniency: 0.7
10
+ Mode: 0
11
+ LetterboxInBreaks: 0
12
+ WidescreenStoryboard: 1
13
+
14
+ [Editor]
15
+ DistanceSpacing: 1.0
16
+ BeatDivisor: 4
17
+ GridSize: 8
18
+ TimelineZoom: 1
19
+
20
+ [Metadata]
21
+ Title:$title
22
+ TitleUnicode:$title_unicode
23
+ Artist:$artist
24
+ ArtistUnicode:$artist_unicode
25
+ Creator:$creator
26
+ Version:$version
27
+ Source:
28
+ Tags:
29
+
30
+ [Difficulty]
31
+ HPDrainRate:$hp_drain_rate
32
+ CircleSize:$circle_size
33
+ OverallDifficulty:$overall_difficulty
34
+ ApproachRate:$approach_rate
35
+ SliderMultiplier:$slider_multiplier
36
+ SliderTickRate:1
37
+
38
+ [Events]
39
+ //Background and Video events
40
+ //Break Periods
41
+ //Storyboard Layer 0 (Background)
42
+ //Storyboard Layer 1 (Fail)
43
+ //Storyboard Layer 2 (Pass)
44
+ //Storyboard Layer 3 (Foreground)
45
+ //Storyboard Layer 4 (Overlay)
46
+ //Storyboard Sound Samples
47
+
48
+ [TimingPoints]
49
+ $timing_points
50
+
51
+ [Colours]
52
+
53
+ [HitObjects]
54
+ $hit_objects
osuT5/model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .osu_t import OsuT
osuT5/model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (203 Bytes). View file
 
osuT5/model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (171 Bytes). View file