import pickle
from pathlib import Path

import torch
from omegaconf import DictConfig
from tqdm import tqdm

from osudiffusion import DiT, create_diffusion, repeat_type, timestep_embedding
from osuT5.dataset.data_utils import update_event_times
from osuT5.tokenizer import Event, EventType


def get_beatmap_idx(path) -> dict[int, int]:
    """Load the mapping from beatmap ID to style class index from a pickle file."""
    p = Path(path)
    with p.open("rb") as f:
        beatmap_idx = pickle.load(f)
    return beatmap_idx


class DiffusionPipeline:
    """Model inference stage that generates positions for distance events."""

    def __init__(self, args: DictConfig):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_sampling_steps = args.num_sampling_steps
        self.cfg_scale = args.cfg_scale
        self.seq_len = args.seq_len
        self.num_classes = args.num_classes
        self.beatmap_idx = get_beatmap_idx(args.beatmap_idx)
        self.style_id = args.style_id
        self.refine_iters = args.refine_iters
        self.use_amp = args.use_amp

        if self.style_id in self.beatmap_idx:
            self.class_label = self.beatmap_idx[self.style_id]
        else:
            print(f"Beatmap ID {self.style_id} not found in dataset, using default style.")
            # num_classes doubles as the unconditional (null) class label.
            self.class_label = self.num_classes

    def generate(self, model: DiT, events: list[Event], refine_model: DiT = None) -> list[Event]:
        """Generate position events for the distance events in the event list.

        Args:
            model: Trained model to use for inference.
            events: List of Event objects with distance events.
            refine_model: Optional model to refine the generated positions.

        Returns:
            events: List of Event objects with position events.
        """
        seq_o, seq_c, seq_len, seq_indices = self.events_to_sequence(events)

        # Normalize onset times so the sequence starts at zero.
        seq_o = seq_o - seq_o[0]
        print(f"Sequence length: {seq_len}")

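        # Build the diffusion sampler with the requested number of respaced
        # sampling steps and a squared-cosine noise schedule.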
        diffusion = create_diffusion(
            str(self.num_sampling_steps),
            noise_schedule="squaredcos_cap_v2",
        )

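        # Banded attention mask; True appears to mark pairs that may not attend
        # (the nn.MultiheadAttention convention), so each position only attends
        # within a window of self.seq_len positions on either side.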
        attn_mask = torch.full((seq_len, seq_len), True, dtype=torch.bool, device=self.device)
        for i in range(seq_len):
            attn_mask[max(0, i - self.seq_len): min(seq_len, i + self.seq_len), i] = False

        class_labels = [self.class_label]

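        # Draw initial Gaussian noise for the two position channels and move the
        # onset, conditioning, and class-label tensors onto the device.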
        n = len(class_labels)
        z = torch.randn(n, 2, seq_len, device=self.device)
        o = seq_o.repeat(n, 1).to(self.device)
        c = seq_c.repeat(n, 1, 1).to(self.device)
        y = torch.tensor(class_labels, device=self.device)

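        # Classifier-free guidance setup in the DiT style: duplicate the batch
        # and give the second half the null class label so forward_with_cfg can
        # blend conditional and unconditional predictions with cfg_scale.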
        z = torch.cat([z, z], 0)
        o = torch.cat([o, o], 0)
        c = torch.cat([c, c], 0)
        y_null = torch.tensor([self.num_classes] * n, device=self.device)
        y = torch.cat([y, y_null], 0)
        model_kwargs = dict(o=o, c=c, y=y, cfg_scale=self.cfg_scale, attn_mask=attn_mask)

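        # Helper to drop the unconditional half of the batch and rescale the
        # normalized coordinates to the 512x384 osu! playfield.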
        def to_positions(samples):
            samples, _ = samples.chunk(2, dim=0)
            samples *= torch.tensor((512, 384), device=self.device).repeat(n, 1).unsqueeze(2)
            return samples.cpu()

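        # Run the full reverse-diffusion loop, starting from pure noise z.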
        samples = diffusion.p_sample_loop(
            model.forward_with_cfg,
            z.shape,
            z,
            clip_denoised=True,
            model_kwargs=model_kwargs,
            progress=True,
            device=self.device,
        )

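        # Optional refinement: repeatedly re-apply the final (t=0) denoising step
        # to polish the sampled positions.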
        if refine_model is not None:
            for _ in tqdm(range(self.refine_iters)):
                t = torch.tensor([0] * samples.shape[0], device=self.device)
                with torch.no_grad():
                    out = diffusion.p_sample(
                        # Use the refinement model here; the original passed
                        # `model`, leaving refine_model unused.
                        refine_model.forward_with_cfg,
                        samples,
                        t,
                        clip_denoised=True,
                        model_kwargs=model_kwargs,
                    )
                samples = out["sample"]

        positions = to_positions(samples)
        return self.events_with_pos(events, positions.squeeze(0), seq_indices)

    @staticmethod
    def events_to_sequence(events: list[Event]) -> tuple[torch.Tensor, torch.Tensor, int, dict[int, int]]:
        """Convert an event list to tensors for the diffusion model.

        Returns the onset times, the conditioning features, the sequence length,
        and a mapping from event list index to sequence position.
        """
        event_times = []
        update_event_times(events, event_times)

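        # nc_types lists the event types that can carry a new combo. event_index
        # gives the base one-hot index per type; indices 1 and 5 are reserved for
        # the new-combo variants of CIRCLE and SLIDER_HEAD, and SLIDER_END is
        # offset further below by the slider's repeat type.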
        nc_types = [EventType.CIRCLE, EventType.SLIDER_HEAD]
        event_index = {
            EventType.CIRCLE: 0,
            EventType.SPINNER: 2,
            EventType.SPINNER_END: 3,
            EventType.SLIDER_HEAD: 4,
            EventType.BEZIER_ANCHOR: 6,
            EventType.PERFECT_ANCHOR: 7,
            EventType.CATMULL_ANCHOR: 8,
            EventType.RED_ANCHOR: 9,
            EventType.LAST_ANCHOR: 10,
            EventType.SLIDER_END: 11,
        }

        seq_indices = {}
        indices = []
        data_chunks = []
        distance = 0
        new_combo = False
        head_time = 0
        last_anchor_time = 0
        for i, event in enumerate(events):
            indices.append(i)
            if event.type == EventType.DISTANCE:
                distance = event.value
            elif event.type == EventType.NEW_COMBO:
                new_combo = True
            elif event.type in event_index:
                time = event_times[i]
                index = event_index[event.type]

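                # A pending NEW_COMBO upgrades a circle or slider head to its
                # new-combo variant.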
                if event.type in nc_types and new_combo:
                    index += 1
                    new_combo = False

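                # For slider ends, estimate the repeat count from the ratio of
                # total slide duration to a single span's duration and fold it
                # into the type index.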
                if event.type == EventType.SLIDER_END:
                    span_duration = last_anchor_time - head_time
                    total_duration = time - head_time
                    repeats = max(int(round(total_duration / span_duration)), 1) if span_duration > 0 else 1
                    index += repeat_type(repeats)
                elif event.type == EventType.SLIDER_HEAD:
                    head_time = time
                elif event.type == EventType.LAST_ANCHOR:
                    last_anchor_time = time

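                # Feature layout per event: [time, distance, one-hot type].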
                features = torch.zeros(18)
                features[0] = time
                features[1] = distance
                features[index + 2] = 1
                data_chunks.append(features)

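                # Map every raw event index consumed since the previous chunk to
                # this sequence position, so positions can be written back later.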
                for j in indices:
                    seq_indices[j] = len(data_chunks) - 1
                indices = []

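        # Assemble the sequence tensor: row 0 holds times, row 1 distances; the
        # distance row is expanded with a 128-dim sinusoidal embedding and
        # concatenated with the one-hot type rows as conditioning.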
        seq = torch.stack(data_chunks, 0)
        seq = torch.swapaxes(seq, 0, 1)
        seq_o = seq[0, :]
        seq_d = seq[1, :]
        seq_c = torch.concatenate(
            [
                timestep_embedding(seq_d, 128).T,
                seq[2:, :],
            ],
            0,
        )

        return seq_o, seq_c, seq.shape[1], seq_indices

    @staticmethod
    def events_with_pos(events: list[Event], sampled_seq: torch.Tensor, seq_indices: dict[int, int]) -> list[Event]:
        """Replace each distance event with sampled POS_X and POS_Y events."""
        new_events = []
        for i, event in enumerate(events):
            if event.type == EventType.DISTANCE:
                try:
                    index = seq_indices[i]
                    pos_x = sampled_seq[0, index].item()
                    pos_y = sampled_seq[1, index].item()
                    new_events.append(Event(EventType.POS_X, int(round(pos_x))))
                    new_events.append(Event(EventType.POS_Y, int(round(pos_y))))
                except KeyError:
                    print(f"Warning: key {i} not found in seq_indices, skipping event.")
            else:
                new_events.append(event)
        return new_events
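

# Illustrative usage sketch (not part of the original module): the config keys
# mirror the DictConfig fields read in __init__, but every concrete value here
# (paths, IDs, counts) is a placeholder assumption.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    args = OmegaConf.create({
        "num_sampling_steps": 100,            # number of respaced sampling steps
        "cfg_scale": 4.0,                     # classifier-free guidance strength
        "seq_len": 128,                       # attention window radius
        "num_classes": 1000,                  # hypothetical number of style classes
        "beatmap_idx": "beatmap_idx.pickle",  # hypothetical path to the index pickle
        "style_id": 123456,                   # hypothetical beatmap ID
        "refine_iters": 10,
        "use_amp": False,
    })
    pipeline = DiffusionPipeline(args)
    # model = DiT(...)   # load a trained DiT checkpoint here
    # events = [...]     # osuT5 event list containing DISTANCE events
    # events = pipeline.generate(model, events)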