File size: 12,247 Bytes
7ef7abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
from __future__ import annotations

import dataclasses
import os
import pathlib
import uuid
from string import Template
import zipfile
import numpy as np
from omegaconf import DictConfig
import time as t
from osuT5.inference.slider_path import SliderPath
from osuT5.tokenizer import Event, EventType

# Extension given to every generated difficulty file.
OSU_FILE_EXTENSION = ".osu"
# The beatmap template lives next to this module; os.path.split(...)[0] is the
# directory component of __file__ (the same value os.path.dirname returns).
OSU_TEMPLATE_PATH = os.path.join(os.path.split(__file__)[0], "template.osu")
# Time resolution of 0.1 steps per millisecond, i.e. one step every 10 ms.
# NOTE(review): not referenced by the code visible in this file; presumably
# consumed by importers -- confirm before changing.
STEPS_PER_MILLISECOND = 1 / 10


@dataclasses.dataclass
class BeatmapConfig:
    """Values substituted into the .osu template for one beatmap.

    ``Postprocessor.generate`` passes ``dataclasses.asdict`` of this object to
    ``Template.safe_substitute``, so each field name doubles as a template
    placeholder name. Field order therefore also fixes the ``asdict`` order.
    """

    # --- General ---
    audio_filename: str = dataclasses.field(default="")

    # --- Metadata ---
    title: str = dataclasses.field(default="")
    title_unicode: str = dataclasses.field(default="")
    artist: str = dataclasses.field(default="")
    artist_unicode: str = dataclasses.field(default="")
    creator: str = dataclasses.field(default="")
    version: str = dataclasses.field(default="")

    # --- Difficulty ---
    # Defaults are kept exactly as written (5, 4, 8, 9 are ints) because they
    # are rendered verbatim into the template text.
    hp_drain_rate: float = dataclasses.field(default=5)
    circle_size: float = dataclasses.field(default=4)
    overall_difficulty: float = dataclasses.field(default=8)
    approach_rate: float = dataclasses.field(default=9)
    slider_multiplier: float = dataclasses.field(default=1.8)


def calculate_coordinates(last_pos, dist, num_samples, playfield_size):
    """Return candidate points ``dist`` away from ``last_pos`` that lie inside the playfield.

    ``num_samples`` angles are spread evenly around the full circle centred on
    ``last_pos``. Points outside the rectangle
    ``[0, playfield_size[0]] x [0, playfield_size[1]]`` (bounds inclusive) are
    dropped.

    Args:
        last_pos: ``(x, y)`` centre of the circle.
        dist: Circle radius.
        num_samples: Number of angles to sample.
        playfield_size: ``(width, height)`` of the playfield.

    Returns:
        A non-empty list of ``(x, y)`` tuples. If no sampled point is inside the
        playfield, a single corner is returned instead:
        ``playfield_size`` when ``last_pos[0] + last_pos[1]`` exceeds half of
        ``width + height``, otherwise ``(0, 0)``.
    """
    # endpoint=False: an angle of 2*pi is the same direction as 0. Including
    # it would sample that point twice and bias a uniform random pick.
    angles = np.linspace(0, 2 * np.pi, num_samples, endpoint=False)

    x_coords = last_pos[0] + dist * np.cos(angles)
    y_coords = last_pos[1] + dist * np.sin(angles)

    inside = (
        (x_coords >= 0)
        & (x_coords <= playfield_size[0])
        & (y_coords >= 0)
        & (y_coords <= playfield_size[1])
    )
    coordinates = list(zip(x_coords[inside], y_coords[inside]))

    if not coordinates:
        if last_pos[0] + last_pos[1] > (playfield_size[0] + playfield_size[1]) / 2:
            return [playfield_size]
        return [(0, 0)]

    return coordinates


def position_to_progress(slider_path: SliderPath, pos: np.ndarray) -> np.ndarray:
    """Estimate the progress value in [0, 1] whose path point is closest to ``pos``.

    The search starts at progress 1 and runs at most 100 descent steps on the
    distance ``|slider_path.position_at(progress) - pos|``.

    The "gradient" is a backward difference over a 1e-4 step. It is NOT divided
    by that step, so each update moves progress by roughly
    ``1e-4 * d(distance)/d(progress)``.

    The loop stops early when the difference is exactly zero or progress leaves
    [0, 1]. The returned value is clipped to [0, 1].
    """
    step = 1e-4
    progress = 1
    for _ in range(100):
        here = np.linalg.norm(slider_path.position_at(progress) - pos)
        behind = np.linalg.norm(slider_path.position_at(progress - step) - pos)
        delta = here - behind
        progress -= delta

        if delta == 0 or progress < 0 or progress > 1:
            break

    return np.clip(progress, 0, 1)



def quantize_to_beat(time, bpm, offset, tick_rate=0.5):
    """Snap ``time`` to the nearest tick of the beat grid defined by ``bpm`` and ``offset``.

    Args:
        time: Time to snap, in milliseconds.
        bpm: Tempo in beats per minute; must be non-zero.
        offset: Time of the grid origin (first beat), in milliseconds.
        tick_rate: Grid spacing as a fraction of a beat, e.g. 0.5 for 1/2 beat,
            0.25 for 1/4 beat, 0.125 for 1/8 beat. Defaults to 0.5, the value
            this function always used.

    Returns:
        The snapped time as a float. Python's ``round`` is used, so an exact
        half-tick tie snaps to the even tick.
    """
    beats_per_second = bpm / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    tick_length = milliseconds_per_beat * tick_rate
    return round((time - offset) / tick_length) * tick_length + offset

def quantize_to_beat_again(time, bpm, offset, tick_rate=0.25):
    """Snap ``time`` to the nearest tick of a (by default) 1/4-beat grid.

    Identical to ``quantize_to_beat`` except for the default grid spacing.

    Args:
        time: Time to snap, in milliseconds.
        bpm: Tempo in beats per minute; must be non-zero.
        offset: Time of the grid origin (first beat), in milliseconds.
        tick_rate: Grid spacing as a fraction of a beat. Defaults to 0.25, the
            value this function always used.

    Returns:
        The snapped time as a float. Python's ``round`` is used, so an exact
        half-tick tie snaps to the even tick.
    """
    beats_per_second = bpm / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    tick_length = milliseconds_per_beat * tick_rate
    return round((time - offset) / tick_length) * tick_length + offset

def move_to_next_tick(time, bpm, tick_rate=0.25):
    """Shift ``time`` forward by exactly one tick of the beat grid.

    No snapping is performed, so no grid offset is involved.

    Args:
        time: Time in milliseconds.
        bpm: Tempo in beats per minute; must be non-zero.
        tick_rate: Tick length as a fraction of a beat. Defaults to 0.25
            (1/4 beat), the value this function always used.

    Returns:
        ``time`` plus one tick length, in milliseconds.
    """
    beats_per_second = bpm / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    return time + milliseconds_per_beat * tick_rate

def move_to_prev_tick(time, bpm, tick_rate=0.25):
    """Shift ``time`` backward by exactly one tick of the beat grid.

    No snapping is performed, so no grid offset is involved.

    Args:
        time: Time in milliseconds.
        bpm: Tempo in beats per minute; must be non-zero.
        tick_rate: Tick length as a fraction of a beat. Defaults to 0.25
            (1/4 beat), the value this function always used.

    Returns:
        ``time`` minus one tick length, in milliseconds.
    """
    beats_per_second = bpm / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    return time - milliseconds_per_beat * tick_rate

def adjust_hit_objects(hit_objects, bpm, offset):
    """Snap every TIME_SHIFT event to the beat grid, nudging collisions one tick later.

    Each TIME_SHIFT value is snapped with ``quantize_to_beat``. Suppose the
    snapped time truncates (``int``) to the same millisecond as the previous
    TIME_SHIFT. In that case it is moved one tick forward with
    ``move_to_next_tick``, unless the event emitted just before it is a
    LAST_ANCHOR or SLIDER_END (presumably because those may legitimately share
    a timestamp with what follows -- confirm).

    Only the immediately preceding TIME_SHIFT is compared, so a bumped time is
    not re-checked against earlier ones. All non-TIME_SHIFT events pass through
    unchanged.

    Args:
        hit_objects: Iterable of Event objects.
        bpm: Tempo in beats per minute.
        offset: Time of the grid origin, in milliseconds.

    Returns:
        A new list of Event objects; the input is not modified.
    """
    # A tuple + ``not in`` is required here: ``x != (A or B)`` evaluates
    # ``A or B`` to just ``A`` and so would never exempt SLIDER_END.
    exempt_previous_types = (EventType.LAST_ANCHOR, EventType.SLIDER_END)

    adjusted_hit_objects = []
    previous_time = None  # int(time) of the last emitted TIME_SHIFT
    for hit_object in hit_objects:
        if hit_object.type != EventType.TIME_SHIFT:
            adjusted_hit_objects.append(hit_object)
            continue

        time = quantize_to_beat(hit_object.value, bpm, offset)
        if (
            previous_time is not None
            and int(time) == previous_time
            and adjusted_hit_objects[-1].type not in exempt_previous_types
        ):
            time = move_to_next_tick(time, bpm)

        adjusted_hit_objects.append(Event(EventType.TIME_SHIFT, time))
        previous_time = int(time)

    return adjusted_hit_objects



class Postprocessor(object):
    """Turn generated Event sequences into .osu difficulties packed in an .osz archive."""

    def __init__(self, args: DictConfig):
        """Postprocessing stage that converts lists of Event objects to beatmap files.

        Args:
            args: Config from which ``output_path``, ``audio_path``, ``title``,
                ``artist``, ``creator``, ``version``, ``slider_multiplier``,
                ``offset``, ``bpm`` and ``resnap_objects`` are read.
        """
        # Curve-type letter used in .osu slider strings -> name given to SliderPath.
        self.curve_type_shorthand = {
            "B": "Bezier",
            "P": "PerfectCurve",
            "C": "Catmull",
        }

        self.output_path = args.output_path
        self.audio_path = args.audio_path
        # ``stem`` drops only the final suffix ("my.song.mp3" -> "my.song");
        # splitting on the first "." would truncate it to "my".
        self.audio_filename = pathlib.Path(args.audio_path).stem
        self.beatmap_config = BeatmapConfig(
            title=f"{self.audio_filename} ({args.title})",
            artist=str(args.artist),
            title_unicode=str(args.title),
            artist_unicode=str(args.artist),
            audio_filename=pathlib.Path(args.audio_path).name,
            slider_multiplier=float(args.slider_multiplier),
            creator=str(args.creator),
            version=str(args.version),
        )
        self.offset = args.offset
        # Milliseconds per beat; written as the uninherited timing point's beat length.
        self.beat_length = 60000 / args.bpm
        self.slider_multiplier = self.beatmap_config.slider_multiplier
        self.bpm = args.bpm
        self.resnap_objects = args.resnap_objects

    def generate(self, generated_positions: list[list[Event]]):
        """Generate a beatmap archive.

        Args:
            generated_positions: One Event sequence per difficulty; each
                sequence becomes its own .osu file.

        Returns:
            None. Each difficulty is written to ``<output_path>/<i>.osu``. All
            of them, plus the audio file, are packed into
            ``<output_path>/<audio_filename>_<unix time>.osz``.
        """
        # The template and metadata are identical for every difficulty: load once.
        with open(OSU_TEMPLATE_PATH, "r", encoding="utf-8") as tf:
            template = Template(tf.read())
        beatmap_config = dataclasses.asdict(self.beatmap_config)

        processed_events = []
        for events in generated_positions:
            if self.resnap_objects:
                # Snap TIME_SHIFT events onto the beat grid (see adjust_hit_objects).
                events = adjust_hit_objects(events, self.bpm, self.offset)

            hit_object_strings, timing_point_strings = self._events_to_osu_lines(events)
            processed_events.append(
                template.safe_substitute({
                    **beatmap_config,
                    "hit_objects": "\n".join(hit_object_strings),
                    "timing_points": "\n".join(timing_point_strings),
                })
            )

        self._write_osz(processed_events)

    def _events_to_osu_lines(self, events):
        """Translate one Event sequence into .osu hit-object and timing-point lines.

        TIME_SHIFT, DISTANCE, POS_X, POS_Y and NEW_COMBO only update the cursor
        state (time, position, new-combo bit). The remaining event types build
        or emit objects from that state.

        Returns:
            ``(hit_object_strings, timing_point_strings)``. The timing points
            start with one uninherited point at ``self.offset``, followed by one
            inherited (slider-velocity) point per emitted slider.
        """
        hit_object_strings = []
        time = 0
        dist = 0
        x = 256
        y = 192
        has_pos = False
        new_combo = 0
        # Slider: [head_x, head_y, head_time, combo_bits, last_anchor_time].
        # Spinner: [start_time, combo_bits].
        ho_info = []
        # (curve letter, x, y) for each slider control point after the head.
        anchor_info = []

        timing_point_strings = [
            f"{self.offset},{self.beat_length},4,2,0,100,1,0"
        ]

        for event in events:
            hit_type = event.type

            if hit_type == EventType.TIME_SHIFT:
                time = event.value
                continue
            elif hit_type == EventType.DISTANCE:
                # Random point ``dist`` away from the last one, inside the 512x384 playfield.
                dist = event.value
                coordinates = calculate_coordinates((x, y), dist, 500, (512, 384))
                pos = coordinates[np.random.randint(len(coordinates))]
                x, y = pos
                continue
            elif hit_type == EventType.POS_X:
                x = event.value
                has_pos = True
                continue
            elif hit_type == EventType.POS_Y:
                y = event.value
                has_pos = True
                continue
            elif hit_type == EventType.NEW_COMBO:
                new_combo = 4  # new-combo bit of the hit-object type field
                continue

            if hit_type == EventType.CIRCLE:
                hit_object_strings.append(
                    f"{int(round(x))},{int(round(y))},{int(round(time))},{1 | new_combo},0"
                )
                ho_info = []

            elif hit_type == EventType.SPINNER:
                ho_info = [time, new_combo]

            elif hit_type == EventType.SPINNER_END and len(ho_info) == 2:
                hit_object_strings.append(
                    f"256,192,{int(round(ho_info[0]))},{8 | ho_info[1]},0,{int(round(time))}"
                )
                ho_info = []

            elif hit_type == EventType.SLIDER_HEAD:
                ho_info = [x, y, time, new_combo]
                anchor_info = []

            elif hit_type == EventType.BEZIER_ANCHOR:
                anchor_info.append(('B', x, y))

            elif hit_type == EventType.PERFECT_ANCHOR:
                anchor_info.append(('P', x, y))

            elif hit_type == EventType.CATMULL_ANCHOR:
                anchor_info.append(('C', x, y))

            elif hit_type == EventType.RED_ANCHOR:
                # A repeated Bezier point marks a sharp (red) corner.
                anchor_info.append(('B', x, y))
                anchor_info.append(('B', x, y))

            elif hit_type == EventType.LAST_ANCHOR:
                ho_info.append(time)  # end time of the first span
                anchor_info.append(('B', x, y))

            elif hit_type == EventType.SLIDER_END and len(ho_info) == 5 and len(anchor_info) > 0:
                curve_type = anchor_info[0][0]
                span_duration = ho_info[4] - ho_info[2]
                total_duration = time - ho_info[2]

                if total_duration == 0 or span_duration == 0:
                    continue

                slides = max(int(round(total_duration / span_duration)), 1)
                control_points = "|".join(
                    f"{int(round(cp[1]))}:{int(round(cp[2]))}" for cp in anchor_info
                )
                slider_path = SliderPath(
                    self.curve_type_shorthand[curve_type],
                    np.array(
                        [(ho_info[0], ho_info[1])] + [(cp[1], cp[2]) for cp in anchor_info],
                        dtype=float,
                    ),
                )
                length = slider_path.get_distance()

                # With absolute positions, trim the path at the point closest to
                # the end position; otherwise shorten it by the last DISTANCE value.
                if has_pos:
                    req_length = length * position_to_progress(slider_path, np.array((x, y)))
                else:
                    req_length = length - dist

                if req_length < 1e-4:
                    continue

                hit_object_strings.append(
                    f"{int(round(ho_info[0]))},{int(round(ho_info[1]))},{int(round(ho_info[2]))},"
                    f"{2 | ho_info[3]},0,{curve_type}|{control_points},{slides},{req_length}"
                )

                # Inherited beat length (-100 / SV) chosen so one span of
                # ``req_length`` takes ``span_duration`` milliseconds.
                sv = span_duration / req_length / self.beat_length * self.slider_multiplier * -10000
                timing_point_strings.append(
                    f"{int(round(ho_info[2]))},{sv},4,2,0,100,0,0"
                )

            new_combo = 0

        return hit_object_strings, timing_point_strings

    def _write_osz(self, beatmaps):
        """Write each beatmap to ``<output_path>/<i>.osu`` and pack them with the audio into an .osz."""
        osz_path = os.path.join(self.output_path, f"{self.audio_filename}_{t.time()}.osz")
        with zipfile.ZipFile(osz_path, "w") as z:
            for i, beatmap in enumerate(beatmaps):
                osu_path = os.path.join(self.output_path, f"{i}{OSU_FILE_EXTENSION}")
                # .osu files are UTF-8. The platform default encoding (e.g.
                # cp1252 on Windows) can fail on unicode title/artist metadata.
                with open(osu_path, "w", encoding="utf-8") as osu_file:
                    osu_file.write(beatmap)
                z.write(osu_path, os.path.basename(osu_path))
            z.write(self.audio_path, os.path.basename(self.audio_path))
        # ``with`` has closed and finalized the archive at this point.
        print(f"Mapset saved {osz_path}")