import pickle
from pathlib import Path
from typing import Optional

import torch
from omegaconf import DictConfig
from tqdm import tqdm

from osudiffusion import DiT, create_diffusion, repeat_type, timestep_embedding
from osuT5.dataset.data_utils import update_event_times
from osuT5.tokenizer import Event, EventType


def get_beatmap_idx(path) -> dict[int, int]:
    p = Path(path)
    with p.open("rb") as f:
        beatmap_idx = pickle.load(f)
    return beatmap_idx


class DiffusionPipeline:
    """Model inference stage that generates positions for distance events."""

    def __init__(self, args: DictConfig):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_sampling_steps = args.num_sampling_steps
        self.cfg_scale = args.cfg_scale
        self.seq_len = args.seq_len
        self.num_classes = args.num_classes
        self.beatmap_idx = get_beatmap_idx(args.beatmap_idx)
        self.style_id = args.style_id
        self.refine_iters = args.refine_iters
        self.use_amp = args.use_amp

        if self.style_id in self.beatmap_idx:
            self.class_label = self.beatmap_idx[self.style_id]
        else:
            print(f"Beatmap ID {self.style_id} not found in dataset, using default style.")
            # The class index num_classes is reserved for the null (unconditional) label.
            self.class_label = self.num_classes

    def generate(self, model: DiT, events: list[Event], refine_model: Optional[DiT] = None) -> list[Event]:
        """Generate position events for distance events in the Event list.

        Args:
            model: Trained model to use for inference.
            events: List of Event objects with distance events.
            refine_model: Optional model to refine the generated positions.

        Returns:
            events: List of Event objects with position events.
        """
        seq_o, seq_c, seq_len, seq_indices = self.events_to_sequence(events)
        seq_o = seq_o - seq_o[0]  # Normalize to relative time
        print(f"Sequence length: {seq_len}")

        diffusion = create_diffusion(
            str(self.num_sampling_steps),
            noise_schedule="squaredcos_cap_v2",
        )

        # Create a banded attention mask so each event only interacts with events
        # within self.seq_len positions of it, which lets inference handle
        # sequences longer than the training sequence length.
        attn_mask = torch.full((seq_len, seq_len), True, dtype=torch.bool, device=self.device)
        for i in range(seq_len):
            attn_mask[max(0, i - self.seq_len): min(seq_len, i + self.seq_len), i] = False

        class_labels = [self.class_label]

        # Create sampling noise:
        n = len(class_labels)
        z = torch.randn(n, 2, seq_len, device=self.device)
        o = seq_o.repeat(n, 1).to(self.device)
        c = seq_c.repeat(n, 1, 1).to(self.device)
        y = torch.tensor(class_labels, device=self.device)

        # Set up classifier-free guidance by duplicating the batch: the first half
        # is conditioned on the style label, the second half on the null class.
        z = torch.cat([z, z], 0)
        o = torch.cat([o, o], 0)
        c = torch.cat([c, c], 0)
        y_null = torch.tensor([self.num_classes] * n, device=self.device)
        y = torch.cat([y, y_null], 0)
        model_kwargs = dict(o=o, c=c, y=y, cfg_scale=self.cfg_scale, attn_mask=attn_mask)

        def to_positions(samples):
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
            # Scale normalized coordinates to the osu! playfield (512 x 384).
            samples *= torch.tensor((512, 384), device=self.device).repeat(n, 1).unsqueeze(2)
            return samples.cpu()

        # Sample positions:
        samples = diffusion.p_sample_loop(
            model.forward_with_cfg,
            z.shape,
            z,
            clip_denoised=True,
            model_kwargs=model_kwargs,
            progress=True,
            device=self.device,
        )

        if refine_model is not None:
            # Refine the result by repeatedly denoising at t=0 with the refine model.
            for _ in tqdm(range(self.refine_iters)):
                t = torch.tensor([0] * samples.shape[0], device=self.device)
                with torch.no_grad():
                    out = diffusion.p_sample(
                        refine_model.forward_with_cfg,
                        samples,
                        t,
                        clip_denoised=True,
                        model_kwargs=model_kwargs,
                    )
                    samples = out["sample"]

        positions = to_positions(samples)

        return self.events_with_pos(events, positions.squeeze(0), seq_indices)

    @staticmethod
    def events_to_sequence(events: list[Event]) -> tuple[torch.Tensor, torch.Tensor, int, dict[int, int]]:
        # Calculate the time of every event and interpolate time for control point events
        event_times = []
        update_event_times(events, event_times)

        # Convert to vectorized form for osu-diffusion
        nc_types = [EventType.CIRCLE, EventType.SLIDER_HEAD]
        event_index = {
            EventType.CIRCLE: 0,
            EventType.SPINNER: 2,
            EventType.SPINNER_END: 3,
            EventType.SLIDER_HEAD: 4,
            EventType.BEZIER_ANCHOR: 6,
            EventType.PERFECT_ANCHOR: 7,
            EventType.CATMULL_ANCHOR: 8,
            EventType.RED_ANCHOR: 9,
            EventType.LAST_ANCHOR: 10,
            EventType.SLIDER_END: 11,
        }

        seq_indices = {}
        indices = []
        data_chunks = []
        distance = 0
        new_combo = False
        head_time = 0
        last_anchor_time = 0
        for i, event in enumerate(events):
            indices.append(i)
            if event.type == EventType.DISTANCE:
                distance = event.value
            elif event.type == EventType.NEW_COMBO:
                new_combo = True
            elif event.type in event_index:
                time = event_times[i]
                index = event_index[event.type]

                # Handle new-combo index offset
                if event.type in nc_types and new_combo:
                    index += 1
                    new_combo = False

                # Calculate the number of repeats for each slider end event
                # and add the slider end repeats index offset
                if event.type == EventType.SLIDER_END:
                    span_duration = last_anchor_time - head_time
                    total_duration = time - head_time
                    repeats = max(int(round(total_duration / span_duration)), 1) if span_duration > 0 else 1
                    index += repeat_type(repeats)
                elif event.type == EventType.SLIDER_HEAD:
                    head_time = time
                elif event.type == EventType.LAST_ANCHOR:
                    last_anchor_time = time

                features = torch.zeros(18)
                features[0] = time
                features[1] = distance
                features[index + 2] = 1
                data_chunks.append(features)

                # Map every raw event index consumed so far to this feature vector,
                # so distance events can later be matched to sampled positions.
                for j in indices:
                    seq_indices[j] = len(data_chunks) - 1
                indices = []

        seq = torch.stack(data_chunks, 0)
        seq = torch.swapaxes(seq, 0, 1)
        seq_o = seq[0, :]
        seq_d = seq[1, :]
        seq_c = torch.cat(
            [
                timestep_embedding(seq_d, 128).T,
                seq[2:, :],
            ],
            0,
        )

        return seq_o, seq_c, seq.shape[1], seq_indices

    @staticmethod
    def events_with_pos(events: list[Event], sampled_seq: torch.Tensor, seq_indices: dict[int, int]) -> list[Event]:
        new_events = []

        for i, event in enumerate(events):
            if event.type == EventType.DISTANCE:
                try:
                    index = seq_indices[i]
                    pos_x = sampled_seq[0, index].item()
                    pos_y = sampled_seq[1, index].item()
                    new_events.append(Event(EventType.POS_X, int(round(pos_x))))
                    new_events.append(Event(EventType.POS_Y, int(round(pos_y))))
                except KeyError:
                    print(f"Warning: Key {i} not found in seq_indices. Skipping event.")
            else:
                new_events.append(event)

        return new_events
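

# A minimal usage sketch, under assumptions: the config values below are
# placeholders, "model" is a trained DiT loaded elsewhere (e.g. from a
# checkpoint), and "events" is an Event list produced by the osuT5 tokenizer.
# This is not a confirmed entry point of the repository.
#
#   from omegaconf import OmegaConf
#
#   args = OmegaConf.create({
#       "num_sampling_steps": 100,
#       "cfg_scale": 1.0,
#       "seq_len": 128,
#       "num_classes": 1000,
#       "beatmap_idx": "path/to/beatmap_idx.pickle",
#       "style_id": 123456,
#       "refine_iters": 10,
#       "use_amp": False,
#   })
#   pipeline = DiffusionPipeline(args)
#   events_with_positions = pipeline.generate(model, events)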