from __future__ import annotations

from pathlib import Path

import torch
import torch.nn.functional as F
from omegaconf import DictConfig
from slider import Beatmap
from tqdm import tqdm

from osuT5.dataset import OsuParser
from osuT5.dataset.data_utils import update_event_times
from osuT5.tokenizer import Event, EventType, Tokenizer
from osuT5.model import OsuT

MILLISECONDS_PER_SECOND = 1000
MILLISECONDS_PER_STEP = 10


def top_k_sampling(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Sample one token id per row from the k highest-probability tokens."""
    values, indices = torch.topk(logits, k)
    probs = F.softmax(values, dim=-1)
    sampled_index = torch.multinomial(probs, 1)
    return indices.gather(-1, sampled_index)
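
# A minimal usage sketch (the shapes are illustrative, not taken from the
# original code): for logits of shape (batch, vocab), top_k_sampling keeps
# the k highest logits per row, renormalizes them with a softmax, and
# returns one sampled token id per row with shape (batch, 1).
#
#   logits = torch.randn(2, 1024)
#   token_ids = top_k_sampling(logits, k=10)  # tensor of shape (2, 1)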


def preprocess_event(event: Event, frame_time: float) -> Event:
    """Convert an absolute TIME_SHIFT value into steps relative to frame_time."""
    if event.type == EventType.TIME_SHIFT:
        event = Event(type=event.type, value=int((event.value - frame_time) / MILLISECONDS_PER_STEP))
    return event
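
# Worked example: with frame_time = 1000 ms and an incoming
# TIME_SHIFT(value=1234), the relative value becomes
# int((1234 - 1000) / 10) = 23 steps of 10 ms each.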


class Pipeline:
    """Model inference stage that processes source sequences into events."""

    def __init__(self, args: DictConfig, tokenizer: Tokenizer):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = tokenizer
        self.tgt_seq_len = args.data.tgt_seq_len
        self.frame_seq_len = args.data.src_seq_len - 1
        self.frame_size = args.model.spectrogram.hop_length
        self.sample_rate = args.model.spectrogram.sample_rate
        self.samples_per_sequence = self.frame_seq_len * self.frame_size
        self.sequence_stride = int(self.samples_per_sequence * args.data.sequence_stride)
        self.milliseconds_per_sequence = self.samples_per_sequence * MILLISECONDS_PER_SECOND / self.sample_rate
        self.milliseconds_per_stride = self.sequence_stride * MILLISECONDS_PER_SECOND / self.sample_rate
        self.beatmap_id = args.beatmap_id
        self.difficulty = args.difficulty
        self.center_pad_decoder = args.data.center_pad_decoder
        self.special_token_len = args.data.special_token_len
        self.diff_token_index = args.data.diff_token_index
        self.style_token_index = args.data.style_token_index
        self.max_pre_token_len = args.data.max_pre_token_len
        self.add_pre_tokens = args.data.add_pre_tokens
        self.add_gd_context = args.data.add_gd_context
        self.bpm = args.bpm
        self.offset = args.offset
        self.total_duration_ms = args.total_duration_ms

        print(f"Configuration: {args}")

        if self.add_gd_context:
            other_beatmap_path = Path(args.other_beatmap_path)

            if not other_beatmap_path.is_file():
                raise FileNotFoundError(f"Beatmap file {other_beatmap_path} not found.")

            other_beatmap = Beatmap.from_path(other_beatmap_path)
            self.other_beatmap_id = other_beatmap.beatmap_id
            self.other_difficulty = float(other_beatmap.stars())
            parser = OsuParser(tokenizer)
            self.other_events = parser.parse(other_beatmap)
            self.other_events, self.other_event_times = self._prepare_events(self.other_events)

    def _calculate_time_shifts(self, bpm: float, duration_ms: float, tick_rate: int, offset: float = 0) -> list[float]:
        """Calculate tick times (in milliseconds) for TIME_SHIFT events from the song's BPM and tick rate."""
        events = []
        ms_per_beat = 60000 / bpm
        ms_per_tick = ms_per_beat / tick_rate
        num_ticks = int(duration_ms // ms_per_tick)

        for i in range(num_ticks):
            # Truncate each tick time to whole milliseconds.
            events.append(float(int(i * ms_per_tick + offset)))

        return events
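
    # Worked example: at bpm = 120 and tick_rate = 4, ms_per_beat = 500 and
    # ms_per_tick = 125, so duration_ms = 1000 yields 8 ticks at
    # 0, 125, 250, ..., 875 ms (shifted by `offset` if one is given).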

    def generate_events(self, model, frames, tokens, encoder_outputs, beatmap_idx, total_steps):
        temperature = 0.9
        k = 10

        for _ in range(total_steps):
            out = model.forward(
                frames=frames,
                decoder_input_ids=tokens,
                decoder_attention_mask=tokens.ne(self.tokenizer.pad_id),
                encoder_outputs=encoder_outputs,
                beatmap_idx=beatmap_idx,
            )
            # Cache the encoder pass so subsequent decoding steps can skip it.
            encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions)
            logits = out.logits[:, -1, :] / temperature
            # Apply nucleus filtering only; top-k filtering and the softmax are
            # handled inside top_k_sampling, so it must receive logits rather
            # than probabilities (feeding it a softmax output would normalize
            # the distribution twice).
            logits = self._filter(logits, top_p=0.9, top_k=0)
            next_tokens = top_k_sampling(logits, k)

            tokens = torch.cat([tokens, next_tokens], dim=-1)

            # Stop once every sequence in the batch has produced EOS.
            eos_in_sentence = next_tokens == self.tokenizer.eos_id
            if eos_in_sentence.all():
                break

        return tokens

    def generate(self, model: OsuT, sequences: torch.Tensor, top_k: int = 50) -> list[Event]:
        """Generate a list of Event objects with absolute timestamps from the source sequences.

        Args:
            model: Trained model to use for inference.
            sequences: Batched source sequences, one window per step.
            top_k: Number of top tokens to use for top-k sampling.

        Returns:
            events: List of Event objects in chronological order.
        """
        events = []
        event_times = []
        temperature = 0.95

        idx_dict = self.tokenizer.beatmap_idx
        # Fall back to a fixed index when the beatmap id is not in the vocabulary.
        beatmap_idx = torch.tensor([idx_dict.get(self.beatmap_id, 6666)], dtype=torch.long, device=self.device)
        style_token = self.tokenizer.encode_style(self.beatmap_id) if self.beatmap_id in idx_dict else self.tokenizer.style_unk
        diff_token = self.tokenizer.encode_diff(self.difficulty) if self.difficulty != -1 else self.tokenizer.diff_unk

        special_tokens = torch.empty((1, self.special_token_len), dtype=torch.long, device=self.device)
        special_tokens[:, self.diff_token_index] = diff_token
        special_tokens[:, self.style_token_index] = style_token

        if self.add_gd_context:
            other_style_token = self.tokenizer.encode_style(self.other_beatmap_id) if self.other_beatmap_id in idx_dict else self.tokenizer.style_unk
            other_special_tokens = torch.empty((1, self.special_token_len), dtype=torch.long, device=self.device)
            other_special_tokens[:, self.diff_token_index] = self.tokenizer.encode_diff(self.other_difficulty)
            other_special_tokens[:, self.style_token_index] = other_style_token
        else:
            other_special_tokens = torch.empty((1, 0), dtype=torch.long, device=self.device)

        for sequence_index, frames in enumerate(tqdm(sequences)):
            frame_time = sequence_index * self.milliseconds_per_stride
            # Context windows: events generated in the previous window (pre
            # tokens) and events already known for the upcoming window (post tokens).
            prev_events = self._get_events_time_range(
                events, event_times, frame_time - self.milliseconds_per_sequence, frame_time) if self.add_pre_tokens else []
            post_events = self._get_events_time_range(
                events, event_times, frame_time, frame_time + self.milliseconds_per_sequence)

            prev_tokens = self._encode(prev_events, frame_time)
            post_tokens = self._encode(post_events, frame_time)
            post_token_length = post_tokens.shape[1]

            if 0 <= self.max_pre_token_len < prev_tokens.shape[1]:
                prev_tokens = prev_tokens[:, -self.max_pre_token_len:]

            prefix = torch.cat([special_tokens, prev_tokens], dim=-1)
            if self.center_pad_decoder:
                prefix = F.pad(prefix, (self.tgt_seq_len // 2 - prefix.shape[1], 0), value=self.tokenizer.pad_id)
            prefix_length = prefix.shape[1]

            max_retries = 5
            attempt = 0
            result = []

            while attempt < max_retries and not result:
                attempt += 1
                try:
                    tokens = torch.tensor([[self.tokenizer.sos_id]], dtype=torch.long, device=self.device)
                    tokens = torch.cat([prefix, tokens, post_tokens], dim=-1)

                    retry_frames = frames.clone().to(self.device).unsqueeze(0)
                    encoder_outputs = None

                    while tokens.shape[-1] < self.tgt_seq_len:
                        out = model.forward(
                            frames=retry_frames,
                            decoder_input_ids=tokens,
                            decoder_attention_mask=tokens.ne(self.tokenizer.pad_id),
                            encoder_outputs=encoder_outputs,
                            beatmap_idx=beatmap_idx,
                        )
                        # Reuse the encoder outputs on subsequent decoding steps.
                        encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions)

                        logits = out.logits[:, -1, :]
                        logits = logits / temperature
                        logits = self._filter(logits, top_p=0.9, top_k=top_k)
                        probabilities = F.softmax(logits, dim=-1)
                        next_tokens = torch.multinomial(probabilities, 1)

                        tokens = torch.cat([tokens, next_tokens], dim=-1)

                        eos_in_sentence = next_tokens == self.tokenizer.eos_id
                        if eos_in_sentence.all():
                            break

                    # Everything before this offset is the prefix, SOS, and post context.
                    predicted_tokens = tokens[:, prefix_length + 1 + post_token_length:]
                    result = self._decode(predicted_tokens[0], frame_time)

                    # Heuristic: a long run of objects without a single NEW_COMBO
                    # event is almost certainly degenerate output, so discard it
                    # and retry.
                    if len(result) > 10 and not any(event.type == EventType.NEW_COMBO for event in result):
                        result = []

                except Exception:
                    # Decoding failed for this window; clear the result and retry.
                    result = []

            events += result

            self._update_event_times(events, event_times, frame_time)

        return events

    def _prepare_events(self, events: list[Event]) -> tuple[list[Event], list[float]]:
        """Pre-process a raw list of events for inference: calculate event times and remove redundant time shifts."""
        ct = 0
        event_times = []
        for event in events:
            if event.type == EventType.TIME_SHIFT:
                ct = event.value
            event_times.append(ct)

        # Scan backwards and drop the TIME_SHIFT that immediately precedes each
        # slider anchor event, along with its entry in event_times; those
        # shifts are redundant for inference.
        delete_next_time_shift = False
        for i in range(len(events) - 1, -1, -1):
            if events[i].type == EventType.TIME_SHIFT and delete_next_time_shift:
                delete_next_time_shift = False
                del events[i]
                del event_times[i]
                continue
            elif events[i].type in [EventType.BEZIER_ANCHOR, EventType.PERFECT_ANCHOR, EventType.CATMULL_ANCHOR,
                                    EventType.RED_ANCHOR]:
                delete_next_time_shift = True

        return events, event_times
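
    # Example: in [..., TIME_SHIFT(t), BEZIER_ANCHOR, ...] the TIME_SHIFT
    # directly before the anchor is removed, while time shifts that precede
    # non-anchor events are kept.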

    def _get_events_time_range(self, events: list[Event], event_times: list[float], start_time: float, end_time: float):
        """Return the slice of events whose times fall in [start_time, end_time)."""
        # Find the first index at or after start_time by locating the last
        # event strictly before it.
        s = 0
        for i in range(len(event_times) - 1, -1, -1):
            if event_times[i] < start_time:
                s = i + 1
                break
        # Find one past the last index strictly before end_time.
        e = 0
        for i in range(len(event_times) - 1, -1, -1):
            if event_times[i] < end_time:
                e = i + 1
                break
        return events[s:e]
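
    # Example: for event_times [90, 100, 150, 200] and the window [100, 200),
    # the returned slice covers the events at 100 and 150.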

    def _update_event_times(self, events: list[Event], event_times: list[float], frame_time: float):
        update_event_times(events, event_times, frame_time + self.milliseconds_per_sequence)

    def _encode(self, events: list[Event], frame_time: float) -> torch.Tensor:
        try:
            tokens = torch.empty((1, len(events)), dtype=torch.long)
            for i, event in enumerate(events):
                # Convert absolute TIME_SHIFT values to steps relative to frame_time.
                tokens[0, i] = self.tokenizer.encode(preprocess_event(event, frame_time))
            return tokens.to(self.device)
        except Exception:
            # Fall back to an empty context if any event fails to encode.
            return torch.empty((1, 0), dtype=torch.long, device=self.device)

    def _decode(self, tokens: torch.Tensor, frame_time: float) -> list[Event]:
        """Convert a list of tokens into Event objects with absolute time values.

        Args:
            tokens: List of tokens.
            frame_time: Start time of the current source sequence.

        Returns:
            events: List of Event objects.
        """
        events = []
        for token in tokens:
            if token == self.tokenizer.eos_id:
                break

            try:
                event = self.tokenizer.decode(token.item())
            except Exception:
                # Skip tokens that do not decode to a valid event.
                continue

            if event.type == EventType.TIME_SHIFT:
                event.value = frame_time + event.value * MILLISECONDS_PER_STEP

            events.append(event)

        return events

    def _filter(self, logits: torch.Tensor, top_p: float = 0.75, top_k: int = 1, filter_value: float = -float("Inf")) -> torch.Tensor:
        """Filter a distribution of logits using nucleus (top-p) and/or top-k filtering."""
        logits = top_k_logits(logits, top_k) if top_k > 0 else logits

        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens once the cumulative probability exceeds top_p,
            # shifted right by one so the first token above the threshold stays.
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Map the removal mask back from sorted order to vocabulary order.
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value

        return logits
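
    # Example: with row probabilities [0.5, 0.3, 0.15, 0.05] and top_p = 0.9,
    # the cumulative sums are [0.5, 0.8, 0.95, 1.0]; the shift-by-one keeps
    # the first three tokens and masks only the last one to -inf.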


def top_k_logits(logits: torch.Tensor, k: int) -> torch.Tensor:
    """
    Keep only the top-k tokens with the highest probabilities.

    Args:
        logits: Logits distribution of shape (batch size, vocabulary size).
        k: Number of top tokens to keep.

    Returns:
        logits with non-top-k elements set to negative infinity.
    """
    values, _ = torch.topk(logits, k)
    # Everything below the k-th largest value per row is masked to -inf.
    min_values = values[:, -1].unsqueeze(-1).expand_as(logits)
    return torch.where(logits < min_values, torch.full_like(logits, float("-Inf")), logits)
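

if __name__ == "__main__":
    # Minimal smoke test for the module-level sampling helpers (illustrative
    # only; the project's real entry point constructs a Pipeline from a
    # DictConfig and a trained OsuT model, which is out of scope here).
    demo_logits = torch.randn(2, 32)

    kept = top_k_logits(demo_logits.clone(), k=5)
    # Exactly k entries per row should survive the top-k mask.
    assert torch.isfinite(kept).sum(dim=-1).eq(5).all()

    sampled = top_k_sampling(demo_logits, k=5)
    print("sampled token ids:", sampled.squeeze(-1).tolist())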