File size: 8,232 Bytes
7ef7abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import pickle
from pathlib import Path

import torch
from omegaconf import DictConfig
from tqdm import tqdm

from osudiffusion import timestep_embedding
from osudiffusion import repeat_type
from osudiffusion import create_diffusion
from osudiffusion import DiT
from osuT5.dataset.data_utils import update_event_times
from osuT5.tokenizer import Event, EventType


def get_beatmap_idx(path) -> dict[int, int]:
    """Load the beatmap-ID -> class-index mapping from a pickle file.

    Args:
        path: Path (str or os.PathLike) to a pickle file containing a
            dict that maps beatmap IDs to model class indices.

    Returns:
        The deserialized ``dict[int, int]`` mapping.

    NOTE(review): ``pickle.load`` executes arbitrary code during
    deserialization — only ever point this at trusted, locally-produced
    index files, never at user-supplied data.
    """
    p = Path(path)
    with p.open("rb") as f:
        beatmap_idx = pickle.load(f)
    return beatmap_idx


class DiffisionPipeline(object):
    """Inference stage that samples (x, y) positions for distance events
    with an osu-diffusion DiT model using classifier-free guidance.

    NOTE(review): the class name keeps the original "Diffision" spelling
    because renaming it would break external callers.
    """

    def __init__(self, args: DictConfig):
        """Model inference stage that generates positions for distance events.

        Args:
            args: Configuration providing ``num_sampling_steps``,
                ``cfg_scale``, ``seq_len``, ``num_classes``, ``beatmap_idx``
                (path to the pickled beatmap-ID index), ``style_id``,
                ``refine_iters`` and ``use_amp``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_sampling_steps = args.num_sampling_steps
        self.cfg_scale = args.cfg_scale
        self.seq_len = args.seq_len
        self.num_classes = args.num_classes
        self.beatmap_idx = get_beatmap_idx(args.beatmap_idx)
        self.style_id = args.style_id
        self.refine_iters = args.refine_iters
        self.use_amp = args.use_amp

        # Map the requested beatmap style to its class index; unknown styles
        # fall back to the null (unconditional) class ``num_classes``.
        if self.style_id in self.beatmap_idx:
            self.class_label = self.beatmap_idx[self.style_id]
        else:
            print(f"Beatmap ID {self.style_id} not found in dataset, using default style.")
            self.class_label = self.num_classes

    def generate(self, model: DiT, events: list[Event], refine_model: DiT = None) -> list[Event]:
        """Generate position events for distance events in the Event list.

        Args:
            model: Trained model to use for inference.
            events: List of Event objects with distance events.
            refine_model: Optional model to refine the generated positions.

        Returns:
            events: List of Event objects with position events.
        """
        seq_o, seq_c, seq_len, seq_indices = self.events_to_sequence(events)

        seq_o = seq_o - seq_o[0]  # Normalize to relative time
        print(f"seq len {seq_len}")

        diffusion = create_diffusion(
            str(self.num_sampling_steps),
            noise_schedule="squaredcos_cap_v2",
        )

        # Create banded matrix attention mask for increased sequence length:
        # each position only attends within +/- self.seq_len of itself
        # (False = attend, True = masked).
        attn_mask = torch.full((seq_len, seq_len), True, dtype=torch.bool, device=self.device)
        for i in range(seq_len):
            attn_mask[max(0, i - self.seq_len): min(seq_len, i + self.seq_len), i] = False

        class_labels = [self.class_label]

        # Create sampling noise:
        n = len(class_labels)
        z = torch.randn(n, 2, seq_len, device=self.device)
        o = seq_o.repeat(n, 1).to(self.device)
        c = seq_c.repeat(n, 1, 1).to(self.device)
        y = torch.tensor(class_labels, device=self.device)

        # Setup classifier-free guidance: duplicate the batch so the second
        # half is conditioned on the null class label.
        z = torch.cat([z, z], 0)
        o = torch.cat([o, o], 0)
        c = torch.cat([c, c], 0)
        y_null = torch.tensor([self.num_classes] * n, device=self.device)
        y = torch.cat([y, y_null], 0)
        model_kwargs = dict(o=o, c=c, y=y, cfg_scale=self.cfg_scale, attn_mask=attn_mask)

        def to_positions(samples):
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
            # Scale normalized coordinates to the osu! playfield (512 x 384).
            samples *= torch.tensor((512, 384), device=self.device).repeat(n, 1).unsqueeze(2)
            return samples.cpu()

        # Sample positions:
        samples = diffusion.p_sample_loop(
            model.forward_with_cfg,
            z.shape,
            z,
            clip_denoised=True,
            model_kwargs=model_kwargs,
            progress=True,
            device=self.device,
        )

        if refine_model is not None:
            # Refine result with the refine model by repeatedly denoising at
            # timestep 0.
            # FIX: the original passed model.forward_with_cfg here, which left
            # refine_model entirely unused despite the guard above.
            for _ in tqdm(range(self.refine_iters)):
                t = torch.tensor([0] * samples.shape[0], device=self.device)
                with torch.no_grad():
                    out = diffusion.p_sample(
                        refine_model.forward_with_cfg,
                        samples,
                        t,
                        clip_denoised=True,
                        model_kwargs=model_kwargs,
                    )
                    samples = out["sample"]

        positions = to_positions(samples)
        return self.events_with_pos(events, positions.squeeze(0), seq_indices)

    @staticmethod
    def events_to_sequence(events: list[Event]) -> tuple[torch.Tensor, torch.Tensor, int, dict[int, int]]:
        """Vectorize an Event list into the osu-diffusion input format.

        Returns a tuple ``(seq_o, seq_c, seq_len, seq_indices)`` where
        ``seq_o`` holds event times, ``seq_c`` the per-event conditioning
        features (distance embedding + one-hot event type), ``seq_len`` the
        number of vectorized events, and ``seq_indices`` maps each original
        event index to its position in the vectorized sequence.
        """
        # Calculate the time of every event and interpolate time for control point events
        event_times = []
        update_event_times(events, event_times)

        # Calculate the number of repeats for each slider end event
        # Convert to vectorized form for osu-diffusion
        nc_types = [EventType.CIRCLE, EventType.SLIDER_HEAD]
        # Base one-hot index per event type; gaps (1, 5) are reserved for the
        # new-combo variants of CIRCLE and SLIDER_HEAD.
        event_index = {
            EventType.CIRCLE: 0,
            EventType.SPINNER: 2,
            EventType.SPINNER_END: 3,
            EventType.SLIDER_HEAD: 4,
            EventType.BEZIER_ANCHOR: 6,
            EventType.PERFECT_ANCHOR: 7,
            EventType.CATMULL_ANCHOR: 8,
            EventType.RED_ANCHOR: 9,
            EventType.LAST_ANCHOR: 10,
            EventType.SLIDER_END: 11,
        }

        seq_indices = {}
        indices = []
        data_chunks = []
        distance = 0
        new_combo = False
        head_time = 0
        last_anchor_time = 0
        for i, event in enumerate(events):
            indices.append(i)
            if event.type == EventType.DISTANCE:
                distance = event.value
            elif event.type == EventType.NEW_COMBO:
                new_combo = True
            elif event.type in event_index:
                time = event_times[i]
                index = event_index[event.type]

                # Handle NC index offset
                if event.type in nc_types and new_combo:
                    index += 1
                    new_combo = False

                # Add slider end repeats index offset
                if event.type == EventType.SLIDER_END:
                    span_duration = last_anchor_time - head_time
                    total_duration = time - head_time
                    # Guard against zero-length spans to avoid division by zero.
                    repeats = max(int(round(total_duration / span_duration)), 1) if span_duration > 0 else 1
                    index += repeat_type(repeats)
                elif event.type == EventType.SLIDER_HEAD:
                    head_time = time
                elif event.type == EventType.LAST_ANCHOR:
                    last_anchor_time = time

                # Feature layout: [time, distance, 16-way one-hot event type].
                features = torch.zeros(18)
                features[0] = time
                features[1] = distance
                features[index + 2] = 1
                data_chunks.append(features)

                # All original events accumulated since the last vectorized
                # event map to this sequence position.
                for j in indices:
                    seq_indices[j] = len(data_chunks) - 1
                indices = []

        seq = torch.stack(data_chunks, 0)
        seq = torch.swapaxes(seq, 0, 1)
        seq_o = seq[0, :]
        seq_d = seq[1, :]
        # Conditioning = sinusoidal embedding of distance + one-hot type rows.
        seq_c = torch.concatenate(
            [
                timestep_embedding(seq_d, 128).T,
                seq[2:, :],
            ],
            0,
        )

        return seq_o, seq_c, seq.shape[1], seq_indices

    @staticmethod
    def events_with_pos(events: list[Event], sampled_seq: torch.Tensor, seq_indices: dict[int, int]) -> list[Event]:
        """Replace each DISTANCE event with sampled POS_X/POS_Y events.

        Args:
            events: Original event list containing DISTANCE events.
            sampled_seq: Tensor of shape (2, seq_len) with x in row 0 and
                y in row 1.
            seq_indices: Mapping from original event index to sequence column.

        Returns:
            New event list with DISTANCE events replaced by position pairs;
            all other events are passed through unchanged.
        """
        new_events = []
        for i, event in enumerate(events):
            if event.type == EventType.DISTANCE:
                try:
                    index = seq_indices[i]
                    pos_x = sampled_seq[0, index].item()
                    pos_y = sampled_seq[1, index].item()
                    new_events.append(Event(EventType.POS_X, int(round(pos_x))))
                    new_events.append(Event(EventType.POS_Y, int(round(pos_y))))
                except KeyError:
                    # Best-effort: a missing mapping drops the event rather
                    # than aborting the whole conversion.
                    print(f"Warning: Key {i} not found in seq_indices. Skipping event.")
            else:
                new_events.append(event)
        return new_events