from __future__ import annotations import dataclasses import os import pathlib import uuid from string import Template import zipfile import numpy as np from omegaconf import DictConfig import time as t from osuT5.inference.slider_path import SliderPath from osuT5.tokenizer import Event, EventType OSU_FILE_EXTENSION = ".osu" OSU_TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template.osu") STEPS_PER_MILLISECOND = 0.1 @dataclasses.dataclass class BeatmapConfig: # General audio_filename: str = "" # Metadata title: str = "" title_unicode: str = "" artist: str = "" artist_unicode: str = "" creator: str = "" version: str = "" # Difficulty hp_drain_rate: float = 5 circle_size: float = 4 overall_difficulty: float = 8 approach_rate: float = 9 slider_multiplier: float = 1.8 def calculate_coordinates(last_pos, dist, num_samples, playfield_size): # Generate a set of angles angles = np.linspace(0, 2*np.pi, num_samples) # Calculate the x and y coordinates for each angle x_coords = last_pos[0] + dist * np.cos(angles) y_coords = last_pos[1] + dist * np.sin(angles) # Combine the x and y coordinates into a list of tuples coordinates = list(zip(x_coords, y_coords)) # Filter out coordinates that are outside the playfield coordinates = [(x, y) for x, y in coordinates if 0 <= x <= playfield_size[0] and 0 <= y <= playfield_size[1]] if len(coordinates) == 0: return [playfield_size] if last_pos[0] + last_pos[1] > (playfield_size[0] + playfield_size[1]) / 2 else [(0, 0)] return coordinates def position_to_progress(slider_path: SliderPath, pos: np.ndarray) -> np.ndarray: eps = 1e-4 lr = 1 t = 1 for i in range(100): grad = np.linalg.norm(slider_path.position_at(t) - pos) - np.linalg.norm( slider_path.position_at(t - eps) - pos, ) t -= lr * grad if grad == 0 or t < 0 or t > 1: break return np.clip(t, 0, 1) def quantize_to_beat(time, bpm, offset): """Quantize a given time to the nearest beat based on the BPM and offset.""" # tick rate is 1/4 #tick_rate = 0.25 # tick rate is 1/8 # tick_rate = 0.125 # tick rate is 1/2 #tick_rate = 0.5 tick_rate = 0.5 beats_per_minute = bpm beats_per_second = beats_per_minute / 60.0 milliseconds_per_beat = 1000 / beats_per_second quantized_time = round((time - offset) / (milliseconds_per_beat * tick_rate)) * (milliseconds_per_beat * tick_rate) + offset return quantized_time def quantize_to_beat_again(time, bpm, offset): """Quantize a given time to the nearest beat based on the BPM and offset.""" # tick rate is 1/4 #tick_rate = 0.25 # tick rate is 1/8 # tick_rate = 0.125 # tick rate is 1/2 #tick_rate = 0.5 tick_rate = 0.25 beats_per_minute = bpm beats_per_second = beats_per_minute / 60.0 milliseconds_per_beat = 1000 / beats_per_second quantized_time = round((time - offset) / (milliseconds_per_beat * tick_rate)) * (milliseconds_per_beat * tick_rate) + offset return quantized_time def move_to_next_tick(time, bpm): """Move to the next tick based on the BPM and offset.""" tick_rate = 0.25 beats_per_minute = bpm beats_per_second = beats_per_minute / 60.0 milliseconds_per_beat = 1000 / beats_per_second quantized_time = time + milliseconds_per_beat * tick_rate return quantized_time def move_to_prev_tick(time, bpm): """Move to the next tick based on the BPM and offset.""" tick_rate = 0.25 beats_per_minute = bpm beats_per_second = beats_per_minute / 60.0 milliseconds_per_beat = 1000 / beats_per_second quantized_time = time - milliseconds_per_beat * tick_rate return quantized_time def adjust_hit_objects(hit_objects, bpm, offset): """Adjust the timing of hit objects to align with beats based on BPM and offset.""" adjusted_hit_objects = [] adjusted_times = [] to_be_adjusted = [] for hit_object in hit_objects: hit_type = hit_object.type if hit_type == EventType.TIME_SHIFT: time = quantize_to_beat(hit_object.value, bpm, offset) if len(adjusted_times) > 0 and int(time) == adjusted_times[-1] and adjusted_hit_objects[-1].type != (EventType.LAST_ANCHOR or EventType.SLIDER_END): time = move_to_next_tick(time, bpm) adjusted_hit_objects.append(Event(EventType.TIME_SHIFT, time)) adjusted_times.append(int(time)) else: adjusted_hit_objects.append(Event(EventType.TIME_SHIFT, time)) adjusted_times.append(int(time)) else: adjusted_hit_objects.append(hit_object) return adjusted_hit_objects class Postprocessor(object): def __init__(self, args: DictConfig): """Postprocessing stage that converts a list of Event objects to a beatmap file.""" self.curve_type_shorthand = { "B": "Bezier", "P": "PerfectCurve", "C": "Catmull", } self.output_path = args.output_path self.audio_path = args.audio_path self.audio_filename = pathlib.Path(args.audio_path).name.split(".")[0] self.beatmap_config = BeatmapConfig( title=str(f"{self.audio_filename} ({args.title})"), artist=str(args.artist), title_unicode=str(args.title), artist_unicode=str(args.artist), audio_filename=pathlib.Path(args.audio_path).name, slider_multiplier=float(args.slider_multiplier), creator=str(args.creator), version=str(args.version), ) self.offset = args.offset self.beat_length = 60000 / args.bpm self.slider_multiplier = self.beatmap_config.slider_multiplier self.bpm = args.bpm self.resnap_objects = args.resnap_objects def generate(self, generated_positions: list[Event]): """Generate a beatmap file. Args: events: List of Event objects. Returns: None. An .osu file will be generated. """ processed_events = [] for events in generated_positions: # adjust hit objects to align with 1/4 beats if self.resnap_objects: events = adjust_hit_objects(events, self.bpm, self.offset) hit_object_strings = [] time = 0 dist = 0 x = 256 y = 192 has_pos = False new_combo = 0 ho_info = [] anchor_info = [] timing_point_strings = [ f"{self.offset},{self.beat_length},4,2,0,100,1,0" ] for event in events: hit_type = event.type if hit_type == EventType.TIME_SHIFT: time = event.value continue elif hit_type == EventType.DISTANCE: # Find a point which is dist away from the last point but still within the playfield dist = event.value coordinates = calculate_coordinates((x, y), dist, 500, (512, 384)) pos = coordinates[np.random.randint(len(coordinates))] x, y = pos continue elif hit_type == EventType.POS_X: x = event.value has_pos = True continue elif hit_type == EventType.POS_Y: y = event.value has_pos = True continue elif hit_type == EventType.NEW_COMBO: new_combo = 4 continue if hit_type == EventType.CIRCLE: hit_object_strings.append(f"{int(round(x))},{int(round(y))},{int(round(time))},{1 | new_combo},0") ho_info = [] elif hit_type == EventType.SPINNER: ho_info = [time, new_combo] elif hit_type == EventType.SPINNER_END and len(ho_info) == 2: hit_object_strings.append( f"{256},{192},{int(round(ho_info[0]))},{8 | ho_info[1]},0,{int(round(time))}" ) ho_info = [] elif hit_type == EventType.SLIDER_HEAD: ho_info = [x, y, time, new_combo] anchor_info = [] elif hit_type == EventType.BEZIER_ANCHOR: anchor_info.append(('B', x, y)) elif hit_type == EventType.PERFECT_ANCHOR: anchor_info.append(('P', x, y)) elif hit_type == EventType.CATMULL_ANCHOR: anchor_info.append(('C', x, y)) elif hit_type == EventType.RED_ANCHOR: anchor_info.append(('B', x, y)) anchor_info.append(('B', x, y)) elif hit_type == EventType.LAST_ANCHOR: ho_info.append(time) anchor_info.append(('B', x, y)) elif hit_type == EventType.SLIDER_END and len(ho_info) == 5 and len(anchor_info) > 0: curve_type = anchor_info[0][0] span_duration = ho_info[4] - ho_info[2] total_duration = time - ho_info[2] if total_duration == 0 or span_duration == 0: continue slides = max(int(round(total_duration / span_duration)), 1) control_points = "|".join(f"{int(round(cp[1]))}:{int(round(cp[2]))}" for cp in anchor_info) slider_path = SliderPath(self.curve_type_shorthand[curve_type], np.array([(ho_info[0], ho_info[1])] + [(cp[1], cp[2]) for cp in anchor_info], dtype=float)) length = slider_path.get_distance() req_length = length * position_to_progress( slider_path, np.array((x, y)), ) if has_pos else length - dist if req_length < 1e-4: continue hit_object_strings.append( f"{int(round(ho_info[0]))},{int(round(ho_info[1]))},{int(round(ho_info[2]))},{2 | ho_info[3]},0,{curve_type}|{control_points},{slides},{req_length}" ) sv = span_duration / req_length / self.beat_length * self.slider_multiplier * -10000 timing_point_strings.append( f"{int(round(ho_info[2]))},{sv},4,2,0,100,0,0" ) new_combo = 0 # Write .osu file with open(OSU_TEMPLATE_PATH, "r") as tf: template = Template(tf.read()) hit_objects = {"hit_objects": "\n".join(hit_object_strings)} timing_points = {"timing_points": "\n".join(timing_point_strings)} beatmap_config = dataclasses.asdict(self.beatmap_config) result = template.safe_substitute({**beatmap_config, **hit_objects, **timing_points}) processed_events.append(result) osz_path = os.path.join(self.output_path, f"{self.audio_filename}_{t.time()}.osz") with zipfile.ZipFile(osz_path, "w") as z: for i, event in enumerate(processed_events): osu_path = os.path.join(self.output_path, f"{i}{OSU_FILE_EXTENSION}") with open(osu_path, "w") as osu_file: osu_file.write(event) z.write(osu_path, os.path.basename(osu_path)) z.write(self.audio_path, os.path.basename(self.audio_path)) print(f"Mapset saved {osz_path}") z.close()