| | from __future__ import annotations
|
| |
|
| | import dataclasses
|
| | import os
|
| | import pathlib
|
| | import uuid
|
| | from string import Template
|
| | import zipfile
|
| | import numpy as np
|
| | from omegaconf import DictConfig
|
| | import time as t
|
| | from osuT5.inference.slider_path import SliderPath
|
| | from osuT5.tokenizer import Event, EventType
|
| |
|
# File extension given to each generated difficulty file.
OSU_FILE_EXTENSION = ".osu"
# NOTE(review): not referenced in this module's visible code; presumably the
# tokenizer's time resolution (0.1 steps per millisecond) — confirm.
STEPS_PER_MILLISECOND = 0.1
# Blank .osu skeleton stored next to this module; it is filled in with
# string.Template by Postprocessor.
OSU_TEMPLATE_PATH = os.path.join(os.path.split(__file__)[0], "template.osu")
|
| |
|
| |
|
@dataclasses.dataclass
class BeatmapConfig:
    """Metadata and difficulty settings written into a generated .osu file.

    Postprocessor passes ``dataclasses.asdict(config)`` to
    ``Template.safe_substitute``. The field names presumably match
    ``$placeholders`` in template.osu, so they should not be renamed without
    checking the template.
    """

    # --- audio ---
    # File name (basename) of the song's audio file.
    audio_filename: str = ""

    # --- song / map metadata ---
    title: str = ""
    title_unicode: str = ""
    artist: str = ""
    artist_unicode: str = ""
    creator: str = ""
    # Difficulty name.
    version: str = ""

    # --- difficulty settings ---
    # The int defaults are kept as ints on purpose, so they render as "5" and
    # not "5.0" when substituted into the template.
    hp_drain_rate: float = 5
    circle_size: float = 4
    overall_difficulty: float = 8
    approach_rate: float = 9
    slider_multiplier: float = 1.8
|
| |
|
| |
|
def calculate_coordinates(last_pos, dist, num_samples, playfield_size):
    """Sample in-bounds points on a circle of radius *dist* around *last_pos*.

    Args:
        last_pos: ``(x, y)`` centre of the circle.
        dist: circle radius.
        num_samples: number of evenly spaced angles to try (all distinct).
        playfield_size: ``(width, height)``. A point is kept only if
            ``0 <= x <= width`` and ``0 <= y <= height``.

    Returns:
        A list of ``(x, y)`` tuples inside the playfield. If no sampled point
        is inside, a single fallback point is returned: ``playfield_size``
        when ``last_pos[0] + last_pos[1]`` exceeds half of
        ``width + height``, otherwise ``(0, 0)``.
    """
    # endpoint=False: 0 and 2*pi are the same angle, so including both would
    # sample that point twice and bias the random choice callers make.
    angles = np.linspace(0, 2 * np.pi, num_samples, endpoint=False)
    xs = last_pos[0] + dist * np.cos(angles)
    ys = last_pos[1] + dist * np.sin(angles)

    inside = (xs >= 0) & (xs <= playfield_size[0]) & (ys >= 0) & (ys <= playfield_size[1])
    coordinates = list(zip(xs[inside], ys[inside]))

    if not coordinates:
        if last_pos[0] + last_pos[1] > (playfield_size[0] + playfield_size[1]) / 2:
            return [playfield_size]
        return [(0, 0)]
    return coordinates
|
| |
|
| |
|
def position_to_progress(slider_path: SliderPath, pos: np.ndarray) -> np.ndarray:
    """Approximate the progress along *slider_path* whose point is nearest *pos*.

    Performs up to 100 steps of a crude local descent on the distance
    ``|slider_path.position_at(progress) - pos|``, starting from progress 1.
    Each step subtracts the backward difference of that distance over a
    1e-4 interval. The difference is not divided by the interval, so steps
    are small. The search stops early when the difference is exactly zero
    or the progress leaves [0, 1].

    Args:
        slider_path: object exposing ``position_at(progress)``.
        pos: target point, compared with ``position_at`` results.

    Returns:
        The final progress clipped to [0, 1], as a NumPy scalar.
    """
    delta = 1e-4
    step_scale = 1
    progress = 1

    for _ in range(100):
        here = np.linalg.norm(slider_path.position_at(progress) - pos)
        behind = np.linalg.norm(slider_path.position_at(progress - delta) - pos)
        slope = here - behind
        progress -= step_scale * slope

        # Written as two comparisons (not a chained range test) so that a
        # NaN progress does not trigger the early exit.
        if slope == 0 or progress < 0 or progress > 1:
            break

    return np.clip(progress, 0, 1)
|
| |
|
| |
|
| |
|
def quantize_to_beat(time, bpm, offset, tick_rate=0.5):
    """Snap *time* to the nearest tick of a beat grid anchored at *offset*.

    Args:
        time: time to snap, in milliseconds.
        bpm: tempo in beats per minute (must be non-zero).
        offset: time of a beat on the grid, in milliseconds.
        tick_rate: grid spacing as a fraction of one beat. The default 0.5
            snaps to half beats.

    Returns:
        The snapped time in milliseconds. Exact ties follow Python's
        ``round`` (half-to-even).
    """
    # 60000 / bpm is the beat length in ms, the same formula as
    # Postprocessor.beat_length.
    tick_length = 60000 / bpm * tick_rate
    return round((time - offset) / tick_length) * tick_length + offset
|
| |
|
def quantize_to_beat_again(time, bpm, offset, tick_rate=0.25):
    """Snap *time* to the nearest tick of a finer beat grid anchored at *offset*.

    This is the same operation as ``quantize_to_beat``, but the default grid
    is a quarter beat instead of a half beat.

    Args:
        time: time to snap, in milliseconds.
        bpm: tempo in beats per minute (must be non-zero).
        offset: time of a beat on the grid, in milliseconds.
        tick_rate: grid spacing as a fraction of one beat (default 0.25).

    Returns:
        The snapped time in milliseconds. Exact ties follow Python's
        ``round`` (half-to-even).
    """
    # 60000 / bpm is the beat length in ms, the same formula as
    # Postprocessor.beat_length.
    tick_length = 60000 / bpm * tick_rate
    return round((time - offset) / tick_length) * tick_length + offset
|
| |
|
def move_to_next_tick(time, bpm, tick_rate=0.25):
    """Return *time* moved forward by one tick.

    Args:
        time: time in milliseconds.
        bpm: tempo in beats per minute (must be non-zero).
        tick_rate: tick length as a fraction of one beat (default: a quarter beat).

    Returns:
        ``time`` plus one tick, in milliseconds. No grid alignment or offset
        is applied.
    """
    return time + 60000 / bpm * tick_rate
|
| |
|
def move_to_prev_tick(time, bpm, tick_rate=0.25):
    """Return *time* moved back by one tick.

    Args:
        time: time in milliseconds.
        bpm: tempo in beats per minute (must be non-zero).
        tick_rate: tick length as a fraction of one beat (default: a quarter beat).

    Returns:
        ``time`` minus one tick, in milliseconds. No grid alignment or offset
        is applied.
    """
    return time - 60000 / bpm * tick_rate
|
| |
|
def adjust_hit_objects(hit_objects, bpm, offset):
    """Snap every TIME_SHIFT event to the beat grid and leave other events unchanged.

    Each TIME_SHIFT value is snapped with ``quantize_to_beat``. If the snapped
    time truncates to the same integer millisecond as the previous TIME_SHIFT,
    it is advanced with ``move_to_next_tick`` so the two do not coincide. The
    exception is when the event emitted just before it is a LAST_ANCHOR or
    SLIDER_END.

    Args:
        hit_objects: iterable of Event objects.
        bpm: tempo in beats per minute.
        offset: grid anchor in milliseconds.

    Returns:
        A new list of events, in the same order as the input.
    """
    adjusted_hit_objects = []
    adjusted_times = []  # int() of every emitted TIME_SHIFT value, in order

    for hit_object in hit_objects:
        if hit_object.type != EventType.TIME_SHIFT:
            adjusted_hit_objects.append(hit_object)
            continue

        time = quantize_to_beat(hit_object.value, bpm, offset)

        # Membership test: `!= (LAST_ANCHOR or SLIDER_END)` would compare only
        # against LAST_ANCHOR, because `A or B` evaluates to A.
        collides = (
            bool(adjusted_times)
            and int(time) == adjusted_times[-1]
            and adjusted_hit_objects[-1].type not in (EventType.LAST_ANCHOR, EventType.SLIDER_END)
        )
        if collides:
            time = move_to_next_tick(time, bpm)

        adjusted_hit_objects.append(Event(EventType.TIME_SHIFT, time))
        adjusted_times.append(int(time))

    return adjusted_hit_objects
|
| |
|
| |
|
| |
|
class Postprocessor:
    """Postprocessing stage that converts lists of Event objects into an .osz mapset."""

    def __init__(self, args: DictConfig):
        """Store output settings, beatmap metadata and timing from *args*.

        Args:
            args: config providing ``output_path``, ``audio_path``, ``title``,
                ``artist``, ``creator``, ``version``, ``slider_multiplier``,
                ``offset``, ``bpm`` and ``resnap_objects``.
        """
        # Single-letter curve codes used in anchor tuples -> SliderPath curve names.
        self.curve_type_shorthand = {
            "B": "Bezier",
            "P": "PerfectCurve",
            "C": "Catmull",
        }

        self.output_path = args.output_path
        self.audio_path = args.audio_path
        # Path.stem strips only the final suffix: "my.song.mp3" -> "my.song".
        # Splitting on the first "." would truncate it to "my".
        self.audio_filename = pathlib.Path(args.audio_path).stem
        self.beatmap_config = BeatmapConfig(
            title=str(f"{self.audio_filename} ({args.title})"),
            artist=str(args.artist),
            title_unicode=str(args.title),
            artist_unicode=str(args.artist),
            audio_filename=pathlib.Path(args.audio_path).name,
            slider_multiplier=float(args.slider_multiplier),
            creator=str(args.creator),
            version=str(args.version),
        )
        self.offset = args.offset
        # Milliseconds per beat.
        self.beat_length = 60000 / args.bpm
        self.slider_multiplier = self.beatmap_config.slider_multiplier
        self.bpm = args.bpm
        self.resnap_objects = args.resnap_objects

    def generate(self, generated_positions: list[list[Event]]):
        """Generate a beatmap set (.osz) from one or more generated event sequences.

        Args:
            generated_positions: one list of Event objects per difficulty.

        Returns:
            None. Each sequence is rendered to ``{i}.osu`` in ``output_path``.
            All of them are then packed, together with the audio file, into
            ``{audio_filename}_{timestamp}.osz`` in ``output_path``.
        """
        # The template is the same for every difficulty, so read it once.
        with open(OSU_TEMPLATE_PATH, "r", encoding="utf-8") as tf:
            template = Template(tf.read())

        processed_events = []
        for events in generated_positions:
            if self.resnap_objects:
                events = adjust_hit_objects(events, self.bpm, self.offset)
            processed_events.append(self._render_beatmap(events, template))

        self._write_osz(processed_events)

    def _render_beatmap(self, events: list[Event], template: Template) -> str:
        """Render one event sequence into the full text of an .osu file."""
        hit_object_strings = []
        time = 0
        dist = 0
        x = 256
        y = 192
        has_pos = False
        new_combo = 0
        ho_info = []
        anchor_info = []

        # First (uninherited) timing point: offset and beat length.
        timing_point_strings = [
            f"{self.offset},{self.beat_length},4,2,0,100,1,0"
        ]

        for event in events:
            hit_type = event.type

            # Time / position / combo events only update state that is used
            # by the next hit-object event.
            if hit_type == EventType.TIME_SHIFT:
                time = event.value
                continue
            elif hit_type == EventType.DISTANCE:
                # Choose a random in-bounds point at the given distance from
                # the previous position on the 512x384 playfield.
                dist = event.value
                coordinates = calculate_coordinates((x, y), dist, 500, (512, 384))
                pos = coordinates[np.random.randint(len(coordinates))]
                x, y = pos
                continue
            elif hit_type == EventType.POS_X:
                x = event.value
                has_pos = True
                continue
            elif hit_type == EventType.POS_Y:
                y = event.value
                has_pos = True
                continue
            elif hit_type == EventType.NEW_COMBO:
                # Bit 2 (value 4) of the .osu hit-object type field is the new-combo flag.
                new_combo = 4
                continue

            if hit_type == EventType.CIRCLE:
                hit_object_strings.append(f"{int(round(x))},{int(round(y))},{int(round(time))},{1 | new_combo},0")
                ho_info = []

            elif hit_type == EventType.SPINNER:
                # Remember start time and combo flag until SPINNER_END.
                ho_info = [time, new_combo]

            elif hit_type == EventType.SPINNER_END and len(ho_info) == 2:
                hit_object_strings.append(
                    f"{256},{192},{int(round(ho_info[0]))},{8 | ho_info[1]},0,{int(round(time))}"
                )
                ho_info = []

            elif hit_type == EventType.SLIDER_HEAD:
                ho_info = [x, y, time, new_combo]
                anchor_info = []

            elif hit_type == EventType.BEZIER_ANCHOR:
                anchor_info.append(('B', x, y))

            elif hit_type == EventType.PERFECT_ANCHOR:
                anchor_info.append(('P', x, y))

            elif hit_type == EventType.CATMULL_ANCHOR:
                anchor_info.append(('C', x, y))

            elif hit_type == EventType.RED_ANCHOR:
                # A repeated Bezier control point marks a red (sharp) anchor.
                anchor_info.append(('B', x, y))
                anchor_info.append(('B', x, y))

            elif hit_type == EventType.LAST_ANCHOR:
                # ho_info becomes [head_x, head_y, head_time, new_combo, last_anchor_time].
                ho_info.append(time)
                anchor_info.append(('B', x, y))

            elif hit_type == EventType.SLIDER_END and len(ho_info) == 5 and len(anchor_info) > 0:
                curve_type = anchor_info[0][0]
                span_duration = ho_info[4] - ho_info[2]
                total_duration = time - ho_info[2]

                # Degenerate slider: drop it. `continue` skips the new_combo
                # reset below, so the flag carries over to the next object.
                if total_duration == 0 or span_duration == 0:
                    continue

                # Number of spans (at least one).
                slides = max(int(round(total_duration / span_duration)), 1)
                control_points = "|".join(f"{int(round(cp[1]))}:{int(round(cp[2]))}" for cp in anchor_info)
                slider_path = SliderPath(
                    self.curve_type_shorthand[curve_type],
                    np.array([(ho_info[0], ho_info[1])] + [(cp[1], cp[2]) for cp in anchor_info], dtype=float),
                )
                length = slider_path.get_distance()

                # With explicit positions, cut the path where the current (x, y)
                # projects onto it; otherwise shorten it by the last distance.
                req_length = (
                    length * position_to_progress(slider_path, np.array((x, y)))
                ) if has_pos else (length - dist)

                if req_length < 1e-4:
                    continue

                hit_object_strings.append(
                    f"{int(round(ho_info[0]))},{int(round(ho_info[1]))},{int(round(ho_info[2]))},{2 | ho_info[3]},0,{curve_type}|{control_points},{slides},{req_length}"
                )

                # Inherited timing point (value -100/SV) chosen so that one
                # span of req_length pixels lasts span_duration ms.
                sv = span_duration / req_length / self.beat_length * self.slider_multiplier * -10000
                timing_point_strings.append(
                    f"{int(round(ho_info[2]))},{sv},4,2,0,100,0,0"
                )

            new_combo = 0

        hit_objects = {"hit_objects": "\n".join(hit_object_strings)}
        timing_points = {"timing_points": "\n".join(timing_point_strings)}
        beatmap_config = dataclasses.asdict(self.beatmap_config)
        return template.safe_substitute({**beatmap_config, **hit_objects, **timing_points})

    def _write_osz(self, beatmaps: list[str]) -> None:
        """Write each rendered beatmap to ``{i}.osu`` and pack them with the audio into an .osz."""
        osz_path = os.path.join(self.output_path, f"{self.audio_filename}_{t.time()}.osz")
        with zipfile.ZipFile(osz_path, "w") as z:
            for i, beatmap in enumerate(beatmaps):
                osu_path = os.path.join(self.output_path, f"{i}{OSU_FILE_EXTENSION}")
                # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
                # Windows) can fail on unicode titles/artists.
                with open(osu_path, "w", encoding="utf-8") as osu_file:
                    osu_file.write(beatmap)
                z.write(osu_path, os.path.basename(osu_path))
            z.write(self.audio_path, os.path.basename(self.audio_path))
        print(f"Mapset saved {osz_path}")