# Source: osu_mapper/osuT5/inference/postprocessor.py
# (Tiger14n, commit cb8e42f: "edit")
from __future__ import annotations
import dataclasses
import os
import pathlib
import uuid
from string import Template
import zipfile
import numpy as np
from omegaconf import DictConfig
import time as t
from osuT5.inference.slider_path import SliderPath
from osuT5.tokenizer import Event, EventType
# Extension given to each generated difficulty file before it is zipped.
OSU_FILE_EXTENSION = ".osu"
# Template beatmap stored alongside this module; Postprocessor fills its
# placeholders with string.Template.safe_substitute.
OSU_TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template" + OSU_FILE_EXTENSION)
# Presumably the number of model time steps per millisecond (not used in
# this chunk) — confirm against the tokenizer.
STEPS_PER_MILLISECOND = 0.1
@dataclasses.dataclass
class BeatmapConfig:
    """Metadata and difficulty values written into a generated .osu file.

    Postprocessor merges ``dataclasses.asdict`` of an instance into the
    ``string.Template`` substitution mapping, so field names presumably
    mirror ``$...`` placeholders in template.osu (template not visible
    here — confirm).
    """

    # --- General ---
    audio_filename: str = dataclasses.field(default="")

    # --- Metadata ---
    title: str = dataclasses.field(default="")
    title_unicode: str = dataclasses.field(default="")
    artist: str = dataclasses.field(default="")
    artist_unicode: str = dataclasses.field(default="")
    creator: str = dataclasses.field(default="")
    version: str = dataclasses.field(default="")

    # --- Difficulty ---
    # Integer defaults are kept as ints so they render as e.g. "5", not "5.0".
    hp_drain_rate: float = dataclasses.field(default=5)
    circle_size: float = dataclasses.field(default=4)
    overall_difficulty: float = dataclasses.field(default=8)
    approach_rate: float = dataclasses.field(default=9)
    slider_multiplier: float = dataclasses.field(default=1.8)
def calculate_coordinates(last_pos, dist, num_samples, playfield_size):
    """Return candidate points ``dist`` away from ``last_pos`` that lie in the playfield.

    ``num_samples`` angles are spread evenly around the full circle. The
    ``2*pi`` endpoint is excluded: it coincides with angle 0, and sampling
    it too would give that direction double weight when a caller picks one
    candidate at random.

    Args:
        last_pos: ``(x, y)`` centre of the circle.
        dist: Circle radius, in the same units as ``playfield_size``.
        num_samples: Number of angles to sample.
        playfield_size: ``(width, height)``; valid points satisfy
            ``0 <= x <= width`` and ``0 <= y <= height``.

    Returns:
        List of ``(x, y)`` tuples inside the playfield. If none qualify, a
        single fallback point: ``playfield_size`` when ``last_pos`` lies in
        the larger half by ``x + y``, otherwise ``(0, 0)``.
    """
    angles = np.linspace(0, 2 * np.pi, num_samples, endpoint=False)
    x_coords = last_pos[0] + dist * np.cos(angles)
    y_coords = last_pos[1] + dist * np.sin(angles)
    # Vectorised bounds check instead of a per-point Python filter.
    inside = (
        (x_coords >= 0)
        & (x_coords <= playfield_size[0])
        & (y_coords >= 0)
        & (y_coords <= playfield_size[1])
    )
    coordinates = list(zip(x_coords[inside], y_coords[inside]))
    if not coordinates:
        # No sampled point fits: fall back to the corner on last_pos's side.
        if last_pos[0] + last_pos[1] > (playfield_size[0] + playfield_size[1]) / 2:
            return [playfield_size]
        return [(0, 0)]
    return coordinates
def position_to_progress(slider_path: SliderPath, pos: np.ndarray) -> np.ndarray:
    """Estimate the progress value along ``slider_path`` whose point is nearest ``pos``.

    The search starts at the end of the path (progress 1). Each iteration
    moves progress by the change in distance-to-``pos`` over a small
    backward step. It stops after 100 iterations, when that change is
    exactly zero, or when progress leaves ``[0, 1]``.

    The raw distance difference is used as the step (it is not divided by
    the step size). This behaves like gradient descent with a very small
    effective learning rate.

    Returns:
        The progress clipped to ``[0, 1]`` by ``np.clip``. Despite the
        annotation, this is a scalar.
    """
    step = 1e-4
    learning_rate = 1
    # Named ``progress`` rather than ``t`` to avoid shadowing the module's
    # ``time as t`` import.
    progress = 1
    for _ in range(100):
        here = np.linalg.norm(slider_path.position_at(progress) - pos)
        behind = np.linalg.norm(slider_path.position_at(progress - step) - pos)
        delta = here - behind
        progress -= learning_rate * delta
        if delta == 0 or progress < 0 or progress > 1:
            break
    return np.clip(progress, 0, 1)
def quantize_to_beat(time, bpm, offset, tick_rate=0.5):
    """Snap ``time`` to the nearest tick of the beat grid anchored at ``offset``.

    Args:
        time: Time in milliseconds to snap.
        bpm: Beats per minute (must be non-zero).
        offset: Time in milliseconds of one grid line.
        tick_rate: Tick spacing as a fraction of a beat. The default 0.5
            snaps to 1/2-beat ticks; use 0.25 for 1/4 or 0.125 for 1/8.

    Returns:
        The snapped time in milliseconds. Exact ties follow Python's
        ``round`` (round half to even).
    """
    beats_per_second = bpm / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    tick_length = milliseconds_per_beat * tick_rate
    return round((time - offset) / tick_length) * tick_length + offset
def quantize_to_beat_again(time, bpm, offset, tick_rate=0.25):
    """Snap ``time`` to the nearest tick of a finer beat grid anchored at ``offset``.

    This is the same computation as ``quantize_to_beat``, but it defaults to
    1/4-beat ticks.

    Args:
        time: Time in milliseconds to snap.
        bpm: Beats per minute (must be non-zero).
        offset: Time in milliseconds of one grid line.
        tick_rate: Tick spacing as a fraction of a beat (default 0.25).

    Returns:
        The snapped time in milliseconds. Exact ties follow Python's
        ``round`` (round half to even).
    """
    beats_per_second = bpm / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    tick_length = milliseconds_per_beat * tick_rate
    return round((time - offset) / tick_length) * tick_length + offset
def move_to_next_tick(time, bpm, tick_rate=0.25):
    """Return ``time`` shifted forward by exactly one tick.

    No snapping is done. The result is ``time`` plus ``tick_rate`` of one beat.

    Args:
        time: Time in milliseconds.
        bpm: Beats per minute (must be non-zero).
        tick_rate: Tick length as a fraction of a beat (default 1/4).

    Returns:
        The shifted time in milliseconds.
    """
    beats_per_second = bpm / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    return time + milliseconds_per_beat * tick_rate
def move_to_prev_tick(time, bpm, tick_rate=0.25):
    """Return ``time`` shifted backward by exactly one tick.

    No snapping is done. The result is ``time`` minus ``tick_rate`` of one beat.

    Args:
        time: Time in milliseconds.
        bpm: Beats per minute (must be non-zero).
        tick_rate: Tick length as a fraction of a beat (default 1/4).

    Returns:
        The shifted time in milliseconds.
    """
    beats_per_second = bpm / 60.0
    milliseconds_per_beat = 1000 / beats_per_second
    return time - milliseconds_per_beat * tick_rate
def adjust_hit_objects(hit_objects, bpm, offset):
    """Snap every TIME_SHIFT event to the beat grid given by ``bpm`` and ``offset``.

    Times are snapped with ``quantize_to_beat``. Non-time events pass through
    unchanged. Sometimes a snapped time truncates to the same integer
    millisecond as the previous snapped time. In that case it is moved one
    tick later with ``move_to_next_tick``, so objects do not collapse onto
    one timestamp. The exception is when the event emitted just before this
    TIME_SHIFT is a LAST_ANCHOR or SLIDER_END.

    Args:
        hit_objects: Sequence of Event objects.
        bpm: Beats per minute of the grid.
        offset: Time of one grid line.

    Returns:
        A new list of events; the input sequence is not modified.
    """
    # Coincident times directly after these event types are kept as-is.
    # (The former ``!= (LAST_ANCHOR or SLIDER_END)`` only ever compared
    # against LAST_ANCHOR, because ``or`` returns its first truthy operand.)
    keep_coincident_after = (EventType.LAST_ANCHOR, EventType.SLIDER_END)
    adjusted_hit_objects = []
    last_time = None  # int(time) of the most recently emitted TIME_SHIFT
    for hit_object in hit_objects:
        if hit_object.type != EventType.TIME_SHIFT:
            adjusted_hit_objects.append(hit_object)
            continue
        time = quantize_to_beat(hit_object.value, bpm, offset)
        if (
            last_time is not None
            and int(time) == last_time
            and adjusted_hit_objects[-1].type not in keep_coincident_after
        ):
            time = move_to_next_tick(time, bpm)
        adjusted_hit_objects.append(Event(EventType.TIME_SHIFT, time))
        last_time = int(time)
    return adjusted_hit_objects
class Postprocessor(object):
    """Turn generated event sequences into .osu difficulties packed in an .osz mapset."""

    def __init__(self, args: DictConfig):
        """Postprocessing stage that converts lists of Event objects to a beatmap mapset.

        Args:
            args: Config read via attribute access for ``output_path``,
                ``audio_path``, ``title``, ``artist``, ``slider_multiplier``,
                ``creator``, ``version``, ``offset``, ``bpm`` and
                ``resnap_objects``.
        """
        # Maps the single-letter curve codes stored in anchor_info to the
        # curve names passed to SliderPath.
        self.curve_type_shorthand = {
            "B": "Bezier",
            "P": "PerfectCurve",
            "C": "Catmull",
        }
        self.output_path = args.output_path
        self.audio_path = args.audio_path
        audio_file = pathlib.Path(args.audio_path)
        # ``stem`` drops only the final suffix ("my.song.mp3" -> "my.song").
        # Splitting on the first "." truncated such names to "my".
        self.audio_filename = audio_file.stem
        self.beatmap_config = BeatmapConfig(
            title=f"{self.audio_filename} ({args.title})",
            artist=str(args.artist),
            title_unicode=str(args.title),
            artist_unicode=str(args.artist),
            audio_filename=audio_file.name,
            slider_multiplier=float(args.slider_multiplier),
            creator=str(args.creator),
            version=str(args.version),
        )
        self.offset = args.offset
        # Milliseconds per beat, written into the uninherited timing point.
        self.beat_length = 60000 / args.bpm
        self.slider_multiplier = self.beatmap_config.slider_multiplier
        self.bpm = args.bpm
        self.resnap_objects = args.resnap_objects

    def generate(self, generated_positions: list[list[Event]]):
        """Generate a mapset (.osz) containing one .osu difficulty per event list.

        Args:
            generated_positions: Sequence of event lists. Each list becomes
                one .osu difficulty.

        Returns:
            None. Difficulty ``i`` is written to ``<output_path>/<i>.osu``.
            All difficulties and the audio file are packed into
            ``<output_path>/<audio stem>_<timestamp>.osz``.
        """
        # Read the template once rather than once per difficulty.
        with open(OSU_TEMPLATE_PATH, "r", encoding="utf-8") as tf:
            template = Template(tf.read())

        processed_events = []
        for events in generated_positions:
            # Optionally snap object times to the beat grid.
            if self.resnap_objects:
                events = adjust_hit_objects(events, self.bpm, self.offset)
            processed_events.append(self._events_to_osu(events, template))

        self._write_mapset(processed_events)

    def _events_to_osu(self, events, template: Template) -> str:
        """Convert one event list into the full text of a .osu file."""
        hit_object_strings = []
        time = 0
        dist = 0
        x = 256
        y = 192
        # Once any absolute position is seen, slider lengths are derived from
        # the end position instead of from ``dist``.
        has_pos = False
        new_combo = 0
        # Spinner: [start_time, new_combo].
        # Slider: [x, y, start_time, new_combo, span_end_time].
        ho_info = []
        anchor_info = []  # (curve code, x, y) per slider control point
        timing_point_strings = [
            f"{self.offset},{self.beat_length},4,2,0,100,1,0"
        ]
        for event in events:
            hit_type = event.type
            if hit_type == EventType.TIME_SHIFT:
                time = event.value
                continue
            elif hit_type == EventType.DISTANCE:
                # Pick a random point ``dist`` away from the last position
                # that still lies inside the 512x384 playfield.
                dist = event.value
                coordinates = calculate_coordinates((x, y), dist, 500, (512, 384))
                pos = coordinates[np.random.randint(len(coordinates))]
                x, y = pos
                continue
            elif hit_type == EventType.POS_X:
                x = event.value
                has_pos = True
                continue
            elif hit_type == EventType.POS_Y:
                y = event.value
                has_pos = True
                continue
            elif hit_type == EventType.NEW_COMBO:
                # Bit 4 of the hit-object type field flags a new combo.
                new_combo = 4
                continue

            if hit_type == EventType.CIRCLE:
                hit_object_strings.append(
                    f"{int(round(x))},{int(round(y))},{int(round(time))},{1 | new_combo},0"
                )
                ho_info = []
            elif hit_type == EventType.SPINNER:
                ho_info = [time, new_combo]
            elif hit_type == EventType.SPINNER_END and len(ho_info) == 2:
                hit_object_strings.append(
                    f"{256},{192},{int(round(ho_info[0]))},{8 | ho_info[1]},0,{int(round(time))}"
                )
                ho_info = []
            elif hit_type == EventType.SLIDER_HEAD:
                ho_info = [x, y, time, new_combo]
                anchor_info = []
            elif hit_type == EventType.BEZIER_ANCHOR:
                anchor_info.append(('B', x, y))
            elif hit_type == EventType.PERFECT_ANCHOR:
                anchor_info.append(('P', x, y))
            elif hit_type == EventType.CATMULL_ANCHOR:
                anchor_info.append(('C', x, y))
            elif hit_type == EventType.RED_ANCHOR:
                # A doubled Bezier point creates a sharp (red) corner.
                anchor_info.append(('B', x, y))
                anchor_info.append(('B', x, y))
            elif hit_type == EventType.LAST_ANCHOR:
                ho_info.append(time)  # end time of the first span
                anchor_info.append(('B', x, y))
            elif hit_type == EventType.SLIDER_END and len(ho_info) == 5 and len(anchor_info) > 0:
                curve_type = anchor_info[0][0]
                span_duration = ho_info[4] - ho_info[2]
                total_duration = time - ho_info[2]
                if total_duration == 0 or span_duration == 0:
                    continue
                slides = max(int(round(total_duration / span_duration)), 1)
                control_points = "|".join(
                    f"{int(round(cp[1]))}:{int(round(cp[2]))}" for cp in anchor_info
                )
                slider_path = SliderPath(
                    self.curve_type_shorthand[curve_type],
                    np.array(
                        [(ho_info[0], ho_info[1])] + [(cp[1], cp[2]) for cp in anchor_info],
                        dtype=float,
                    ),
                )
                length = slider_path.get_distance()
                if has_pos:
                    req_length = length * position_to_progress(slider_path, np.array((x, y)))
                else:
                    req_length = length - dist
                if req_length < 1e-4:
                    continue
                hit_object_strings.append(
                    f"{int(round(ho_info[0]))},{int(round(ho_info[1]))},{int(round(ho_info[2]))},{2 | ho_info[3]},0,{curve_type}|{control_points},{slides},{req_length}"
                )
                # Negative inherited-timing-point value chosen so that one
                # span of req_length lasts span_duration.
                sv = span_duration / req_length / self.beat_length * self.slider_multiplier * -10000
                timing_point_strings.append(
                    f"{int(round(ho_info[2]))},{sv},4,2,0,100,0,0"
                )
            new_combo = 0

        hit_objects = {"hit_objects": "\n".join(hit_object_strings)}
        timing_points = {"timing_points": "\n".join(timing_point_strings)}
        beatmap_config = dataclasses.asdict(self.beatmap_config)
        return template.safe_substitute({**beatmap_config, **hit_objects, **timing_points})

    def _write_mapset(self, beatmaps: list[str]) -> None:
        """Write each .osu text to disk and pack them, with the audio, into an .osz."""
        osz_path = os.path.join(self.output_path, f"{self.audio_filename}_{t.time()}.osz")
        with zipfile.ZipFile(osz_path, "w") as z:
            for i, beatmap in enumerate(beatmaps):
                osu_path = os.path.join(self.output_path, f"{i}{OSU_FILE_EXTENSION}")
                # Explicit UTF-8 so non-ASCII metadata (e.g. title_unicode)
                # does not depend on the platform's default encoding.
                with open(osu_path, "w", encoding="utf-8") as osu_file:
                    osu_file.write(beatmap)
                z.write(osu_path, os.path.basename(osu_path))
            z.write(self.audio_path, os.path.basename(self.audio_path))
        # Report only once the archive has been finalised by closing it.
        print(f"Mapset saved {osz_path}")