# osu_mapper2/osuT5/inference/pipeline.py
from __future__ import annotations

from pathlib import Path

import torch
import torch.nn.functional as F
from omegaconf import DictConfig
from slider import Beatmap
from tqdm import tqdm

from osuT5.dataset import OsuParser
from osuT5.dataset.data_utils import update_event_times
from osuT5.tokenizer import Event, EventType, Tokenizer
from osuT5.model import OsuT

MILISECONDS_PER_SECOND = 1000
# Each TIME_SHIFT token advances time by one 10 ms step.
MILISECONDS_PER_STEP = 10

def top_k_sampling(logits, k):
    """Sample one token id from the k highest-scoring logits."""
    top_k_logits, top_k_indices = torch.topk(logits, k)
    top_k_probs = F.softmax(top_k_logits, dim=-1)
    sampled_index = torch.multinomial(top_k_probs, 1)
    sampled_token = top_k_indices.gather(-1, sampled_index)
    return sampled_token
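# Illustrative example (values are made up): for
#   logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
# top_k_sampling(logits, k=2) softmaxes over the two largest logits (2.0, 1.0)
# and returns a (1, 1) tensor holding vocabulary index 0 or 1.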

def preprocess_event(event, frame_time):
    """Convert an absolute TIME_SHIFT value (ms) into steps relative to frame_time."""
    if event.type == EventType.TIME_SHIFT:
        event = Event(type=event.type, value=int((event.value - frame_time) / MILISECONDS_PER_STEP))
    return event
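# Example: a TIME_SHIFT at 1234 ms with frame_time=1000 ms becomes
# int((1234 - 1000) / 10) = 23 steps.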

class Pipeline:
    """Model inference stage that processes source sequences into Event lists."""

    def __init__(self, args: DictConfig, tokenizer: Tokenizer):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = tokenizer
        self.tgt_seq_len = args.data.tgt_seq_len
        self.frame_seq_len = args.data.src_seq_len - 1
        self.frame_size = args.model.spectrogram.hop_length
        self.sample_rate = args.model.spectrogram.sample_rate
        self.samples_per_sequence = self.frame_seq_len * self.frame_size
        self.sequence_stride = int(self.samples_per_sequence * args.data.sequence_stride)
        self.miliseconds_per_sequence = self.samples_per_sequence * MILISECONDS_PER_SECOND / self.sample_rate
        self.miliseconds_per_stride = self.sequence_stride * MILISECONDS_PER_SECOND / self.sample_rate
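        # Illustrative sizing (assumed values, not this project's defaults): with
        # src_seq_len=512, hop_length=128, and sample_rate=16000, frame_seq_len is
        # 511, samples_per_sequence is 511 * 128 = 65408, and
        # miliseconds_per_sequence is 65408 * 1000 / 16000 = 4088.0 ms.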
        self.beatmap_id = args.beatmap_id
        self.difficulty = args.difficulty
        self.center_pad_decoder = args.data.center_pad_decoder
        self.special_token_len = args.data.special_token_len
        self.diff_token_index = args.data.diff_token_index
        self.style_token_index = args.data.style_token_index
        self.max_pre_token_len = args.data.max_pre_token_len
        self.add_pre_tokens = args.data.add_pre_tokens
        self.add_gd_context = args.data.add_gd_context
        self.bpm = args.bpm
        self.offset = args.offset
        self.total_duration_ms = args.total_duration_ms
        print(f"Configuration: {args}")

        if self.add_gd_context:
            other_beatmap_path = Path(args.other_beatmap_path)
            if not other_beatmap_path.is_file():
                raise FileNotFoundError(f"Beatmap file {other_beatmap_path} not found.")
            other_beatmap = Beatmap.from_path(other_beatmap_path)
            self.other_beatmap_id = other_beatmap.beatmap_id
            self.other_difficulty = float(other_beatmap.stars())
            parser = OsuParser(tokenizer)
            self.other_events = parser.parse(other_beatmap)
            self.other_events, self.other_event_times = self._prepare_events(self.other_events)

    def _calculate_time_shifts(self, bpm: float, duration_ms: float, tick_rate: int, offset: float = 0) -> list[float]:
        """Calculate TIME_SHIFT timestamps in milliseconds from the song's BPM and tick rate."""
        events = []
        ms_per_beat = 60000 / bpm  # 60000 ms per minute
        ms_per_tick = ms_per_beat / tick_rate
        num_ticks = int(duration_ms // ms_per_tick)
        for i in range(num_ticks):
            events.append(float(int(i * ms_per_tick + offset)))
        return events
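    # Example: bpm=180 gives ms_per_beat = 60000 / 180 ≈ 333.33; with tick_rate=4,
    # ms_per_tick ≈ 83.33, so a 10-second song yields int(10000 // 83.33) = 120 ticks.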

    def generate_events(self, model, frames, tokens, encoder_outputs, beatmap_idx, total_steps):
        temperature = 0.9
        k = 10  # top-k sampling
        for _ in range(total_steps):
            out = model.forward(
                frames=frames,
                decoder_input_ids=tokens,
                decoder_attention_mask=tokens.ne(self.tokenizer.pad_id),
                encoder_outputs=encoder_outputs,
                beatmap_idx=beatmap_idx,
            )
            # Cache encoder outputs so subsequent steps skip the encoder pass.
            encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions)
            logits = out.logits[:, -1, :] / temperature
            # Nucleus filtering only here (top_k=0 skips the top-k cut);
            # top_k_sampling applies its own top-k restriction and softmax.
            logits = self._filter(logits, top_p=0.9, top_k=0)
            next_tokens = top_k_sampling(logits, k)
            tokens = torch.cat([tokens, next_tokens], dim=-1)
            eos_in_sentence = next_tokens == self.tokenizer.eos_id
            if eos_in_sentence.all():
                break
        return tokens

    def generate(self, model: OsuT, sequences: torch.Tensor, top_k: int = 50) -> list[Event]:
        """Generate a list of Event objects given batched source sequences.

        Args:
            model: Trained model to use for inference.
            sequences: A list of batched source sequences.
            top_k: Number of top tokens to keep during top-k filtering.

        Returns:
            events: List of Event objects with absolute times in milliseconds.
        """
        events = []
        event_times = []
        temperature = 0.95
        idx_dict = self.tokenizer.beatmap_idx
        # Fall back to a fixed index when the requested beatmap id is unknown.
        beatmap_idx = torch.tensor([idx_dict.get(self.beatmap_id, 6666)], dtype=torch.long, device=self.device)
        style_token = self.tokenizer.encode_style(self.beatmap_id) if self.beatmap_id in idx_dict else self.tokenizer.style_unk
        diff_token = self.tokenizer.encode_diff(self.difficulty) if self.difficulty != -1 else self.tokenizer.diff_unk
        special_tokens = torch.empty((1, self.special_token_len), dtype=torch.long, device=self.device)
        special_tokens[:, self.diff_token_index] = diff_token
        special_tokens[:, self.style_token_index] = style_token

        if self.add_gd_context:
            other_style_token = self.tokenizer.encode_style(self.other_beatmap_id) if self.other_beatmap_id in idx_dict else self.tokenizer.style_unk
            other_special_tokens = torch.empty((1, self.special_token_len), dtype=torch.long, device=self.device)
            other_special_tokens[:, self.diff_token_index] = self.tokenizer.encode_diff(self.other_difficulty)
            other_special_tokens[:, self.style_token_index] = other_style_token
        else:
            other_special_tokens = torch.empty((1, 0), dtype=torch.long, device=self.device)

        for sequence_index, frames in enumerate(tqdm(sequences)):
            # Get tokens of the previous frame
            frame_time = sequence_index * self.miliseconds_per_stride
            prev_events = self._get_events_time_range(
                events, event_times, frame_time - self.miliseconds_per_sequence, frame_time) if self.add_pre_tokens else []
            post_events = self._get_events_time_range(
                events, event_times, frame_time, frame_time + self.miliseconds_per_sequence)
            prev_tokens = self._encode(prev_events, frame_time)
            post_tokens = self._encode(post_events, frame_time)
            post_token_length = post_tokens.shape[1]
            if 0 <= self.max_pre_token_len < prev_tokens.shape[1]:
                prev_tokens = prev_tokens[:, -self.max_pre_token_len:]

            # Get prefix tokens
            prefix = torch.cat([special_tokens, prev_tokens], dim=-1)
            if self.center_pad_decoder:
                prefix = F.pad(prefix, (self.tgt_seq_len // 2 - prefix.shape[1], 0), value=self.tokenizer.pad_id)
            prefix_length = prefix.shape[1]

            max_retries = 5
            attempt = 0
            result = []
            while attempt < max_retries and not result:
                attempt += 1
                try:
                    # Reset tokens
                    tokens = torch.tensor([[self.tokenizer.sos_id]], dtype=torch.long, device=self.device)
                    tokens = torch.cat([prefix, tokens, post_tokens], dim=-1)
                    # Ensure frames are properly reset for each retry
                    retry_frames = frames.clone().to(self.device).unsqueeze(0)
                    encoder_outputs = None
                    while tokens.shape[-1] < self.tgt_seq_len:
                        out = model.forward(
                            frames=retry_frames,
                            decoder_input_ids=tokens,
                            decoder_attention_mask=tokens.ne(self.tokenizer.pad_id),
                            encoder_outputs=encoder_outputs,
                            # beatmap_idx=beatmap_idx,
                        )
                        # Cache encoder outputs so later steps skip the encoder pass.
                        encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions)
                        logits = out.logits[:, -1, :] / temperature
                        logits = self._filter(logits, top_p=0.9, top_k=top_k)
                        probabilities = F.softmax(logits, dim=-1)
                        next_tokens = torch.multinomial(probabilities, 1)
                        tokens = torch.cat([tokens, next_tokens], dim=-1)
                        eos_in_sentence = next_tokens == self.tokenizer.eos_id
                        if eos_in_sentence.all():
                            break
                    # Skip the prefix, the SOS token, and the post-context tokens.
                    predicted_tokens = tokens[:, prefix_length + 1 + post_token_length:]
                    result = self._decode(predicted_tokens[0], frame_time)
                    # Retry when a long result contains no new-combo event.
                    if len(result) > 10 and not any(event.type == EventType.NEW_COMBO for event in result):
                        result = []
                except Exception:
                    # Any decoding error leaves result empty, which triggers a retry.
                    result = []
            events += result
            self._update_event_times(events, event_times, frame_time)
        return events

    def _prepare_events(self, events: list[Event]) -> tuple[list[Event], list[float]]:
        """Pre-process raw events for inference: calculate event times and remove redundant time shifts."""
        ct = 0
        event_times = []
        for event in events:
            if event.type == EventType.TIME_SHIFT:
                ct = event.value
            event_times.append(ct)

        # Walk the events in reverse and drop any time shift that directly precedes an anchor event.
        delete_next_time_shift = False
        for i in range(len(events) - 1, -1, -1):
            if events[i].type == EventType.TIME_SHIFT and delete_next_time_shift:
                delete_next_time_shift = False
                del events[i]
                del event_times[i]
                continue
            elif events[i].type in [EventType.BEZIER_ANCHOR, EventType.PERFECT_ANCHOR, EventType.CATMULL_ANCHOR,
                                    EventType.RED_ANCHOR]:
                delete_next_time_shift = True

        return events, event_times
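    # Example: [TIME_SHIFT(100), CIRCLE, TIME_SHIFT(200), BEZIER_ANCHOR] becomes
    # [TIME_SHIFT(100), CIRCLE, BEZIER_ANCHOR]; the shift just before the anchor is
    # deleted, while its time (200) is still recorded for the anchor in event_times.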

    def _get_events_time_range(self, events: list[Event], event_times: list[float], start_time: float, end_time: float):
        """Return the slice of events whose times fall in [start_time, end_time)."""
        # Scan from the end of the list to find the slice boundaries.
        s = 0
        for i in range(len(event_times) - 1, -1, -1):
            if event_times[i] < start_time:
                s = i + 1
                break
        e = 0
        for i in range(len(event_times) - 1, -1, -1):
            if event_times[i] < end_time:
                e = i + 1
                break
        return events[s:e]

    def _update_event_times(self, events: list[Event], event_times: list[float], frame_time: float):
        update_event_times(events, event_times, frame_time + self.miliseconds_per_sequence)

    def _encode(self, events: list[Event], frame_time: float) -> torch.Tensor:
        try:
            tokens = torch.empty((1, len(events)), dtype=torch.long)
            for i, event in enumerate(events):
                # Convert absolute TIME_SHIFT values to steps relative to frame_time.
                tokens[0, i] = self.tokenizer.encode(preprocess_event(event, frame_time))
            return tokens.to(self.device)
        except Exception:
            # Fall back to an empty token sequence if any event fails to encode.
            return torch.empty((1, 0), dtype=torch.long, device=self.device)

    def _decode(self, tokens: torch.Tensor, frame_time: float) -> list[Event]:
        """Convert a list of tokens into Event objects with absolute time values.

        Args:
            tokens: List of tokens.
            frame_time: Start time of the current source sequence.

        Returns:
            events: List of Event objects.
        """
        events = []
        for token in tokens:
            if token == self.tokenizer.eos_id:
                break
            try:
                event = self.tokenizer.decode(token.item())
            except Exception:
                # Skip tokens that do not decode to a valid event.
                continue
            if event.type == EventType.TIME_SHIFT:
                event.value = frame_time + event.value * MILISECONDS_PER_STEP
            events.append(event)
        return events
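    # Example: a decoded TIME_SHIFT of 23 steps with frame_time=1000 ms maps back
    # to 1000 + 23 * 10 = 1230 ms, inverting the quantization done in _encode.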

    def _filter(self, logits: torch.Tensor, top_p: float = 0.75, top_k: int = 1, filter_value: float = -float("Inf")) -> torch.Tensor:
        """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering."""
        logits = top_k_logits(logits, top_k) if top_k > 0 else logits
        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Mask tokens with cumulative probability above the threshold, shifting
            # the mask right so the first token above the threshold is kept.
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value
        return logits


def top_k_logits(logits, k):
    """Keep only the top-k tokens with the highest logits.

    Args:
        logits: Logits distribution of shape (batch size, vocabulary size).
        k: Number of top tokens to keep.

    Returns:
        Logits with non-top-k elements set to negative infinity.
    """
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(-1).expand_as(logits)
    return torch.where(logits < min_values, torch.full_like(logits, float("-Inf")), logits)
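

if __name__ == "__main__":
    # Minimal sanity check for the sampling helpers (illustrative only; the
    # tensor values below are made up and not part of the inference pipeline).
    demo_logits = torch.tensor([[4.0, 3.0, 2.0, 1.0]])
    print(top_k_logits(demo_logits.clone(), 2))  # tensor([[4., 3., -inf, -inf]])
    print(top_k_sampling(demo_logits, 2))        # a (1, 1) tensor: index 0 or 1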