File size: 5,708 Bytes
12a8e0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import json
from pathlib import Path
import numpy.typing as npt

import hydra
import torch
from omegaconf import DictConfig
from slider import Beatmap
from torch.utils.data import IterableDataset

from classifier.libs.dataset import OsuParser
from classifier.libs.dataset.data_utils import load_audio_file
from classifier.libs.dataset.ors_dataset import STEPS_PER_MILLISECOND
from classifier.libs.model.model import OsuClassifierOutput
from classifier.libs.tokenizer import Tokenizer, Event, EventType
from classifier.libs.utils import load_ckpt


def iterate_examples(
        beatmap: Beatmap,
        audio: npt.NDArray,
        model_args: DictConfig,
        tokenizer: Tokenizer,
        device: torch.device
):
    """Lazily yield one model example per consecutive audio window.

    The beatmap is parsed once up front; the audio is then walked in
    non-overlapping windows of ``(src_seq_len - 1) * hop_length`` samples and
    each window is turned into an example by ``create_example``.

    Args:
        beatmap: Beatmap to parse into events.
        audio: Audio samples; indexed by sample position.
        model_args: Config providing ``data.src_seq_len``, ``data.hop_length``
            and ``data.sample_rate`` (and whatever ``create_example`` reads).
        tokenizer: Tokenizer handed to the parser and to ``create_example``.
        device: Device the example tensors are moved to.

    Yields:
        The dict returned by ``create_example`` for each window start.
    """
    data_cfg = model_args.data
    window_samples = (data_cfg.src_seq_len - 1) * data_cfg.hop_length
    sample_rate = data_cfg.sample_rate

    events, event_times = OsuParser(model_args, tokenizer).parse(beatmap)

    # Window starts in samples; the range end is exclusive, so a start is only
    # produced while it is strictly below len(audio) - window_samples.
    window_starts = range(0, len(audio) - window_samples, window_samples)
    yield from (
        # create_example expects the window start in seconds.
        create_example(events, event_times, audio, start / sample_rate, model_args, tokenizer, device)
        for start in window_starts
    )


class ExampleDataset(IterableDataset):
    """Iterable dataset producing examples for a single beatmap/audio pair.

    The constructor only stores its arguments; all parsing and slicing happens
    when the dataset is iterated, via ``iterate_examples``.
    """

    def __init__(self, beatmap, audio, classifier_args, classifier_tokenizer, device):
        # Keep references only; nothing is computed until iteration starts.
        self.beatmap, self.audio = beatmap, audio
        self.classifier_args, self.classifier_tokenizer = classifier_args, classifier_tokenizer
        self.device = device

    def __iter__(self):
        """Return a fresh generator over the examples of the stored beatmap."""
        example_stream = iterate_examples(
            self.beatmap,
            self.audio,
            self.classifier_args,
            self.classifier_tokenizer,
            self.device,
        )
        return example_stream


def create_example(
        events: list[Event],
        event_times: list[float],
        audio: npt.NDArray,
        time: float,
        model_args: DictConfig,
        tokenizer: Tokenizer,
        device: torch.device,
        unsqueeze: bool = False,
):
    """Build the model inputs for the audio window starting at ``time``.

    Args:
        events: Parsed beatmap events, aligned one-to-one with ``event_times``.
        event_times: Time of each event; divided by 1000 before being compared
            with ``time``, i.e. treated as milliseconds.
        audio: Audio samples; ``time * sample_rate`` is used as the start index.
        time: Window start, in the same unit as ``event_time / 1000``
            (seconds, given the sample-rate arithmetic).
        model_args: Config providing ``data.src_seq_len``, ``data.hop_length``,
            ``data.sample_rate`` and ``data.tgt_seq_len``.
        tokenizer: Provides ``pad_id`` and ``encode(event)``.
        device: Device both output tensors are moved to.
        unsqueeze: When true, a leading dimension of size 1 is added to the
            tokens and frames.

    Returns:
        A dict with ``decoder_input_ids`` (padded token ids),
        ``decoder_attention_mask`` (true where the token is not padding) and
        ``frames`` (float32 audio samples of the window).
    """
    data_cfg = model_args.data
    window_samples = (data_cfg.src_seq_len - 1) * data_cfg.hop_length
    sample_rate = data_cfg.sample_rate
    window_duration = window_samples / sample_rate
    target_len = data_cfg.tgt_seq_len

    # Cut the audio window out of the full track. Near the end of the track the
    # slice may be shorter than window_samples; it is not padded here.
    first_sample = int(time * sample_rate)
    window_audio = audio[first_sample:first_sample + window_samples]
    frames = torch.from_numpy(window_audio).float().to(device)

    # Keep the events that fall inside [time, time + window_duration) and, in
    # the same pass, re-express TIME_SHIFT values relative to the window start.
    window_events = []
    for event, event_time in zip(events, event_times):
        if not (time <= event_time / 1000 < time + window_duration):
            continue
        if event.type == EventType.TIME_SHIFT:
            relative_steps = int((event.value - time * 1000) * STEPS_PER_MILLISECOND)
            event = Event(EventType.TIME_SHIFT, relative_steps)
        window_events.append(event)

    # Encode into a fixed-length buffer pre-filled with padding; events beyond
    # target_len are dropped.
    tokens = torch.full((target_len,), tokenizer.pad_id, dtype=torch.long)
    for slot, event in enumerate(window_events[:target_len]):
        tokens[slot] = tokenizer.encode(event)
    tokens = tokens.to(device)

    if unsqueeze:
        tokens, frames = tokens.unsqueeze(0), frames.unsqueeze(0)

    attention_mask = tokens != tokenizer.pad_id
    return {
        "decoder_input_ids": tokens,
        "decoder_attention_mask": attention_mask,
        "frames": frames,
    }


def create_example_from_path(
        beatmap_path: str,
        audio_path: str,
        time: float,
        model_args: DictConfig,
        tokenizer: Tokenizer,
        device: torch.device,
        unsqueeze: bool = False,
):
    """Load a beatmap and its audio from disk and build one example.

    Args:
        beatmap_path: Path to the beatmap file.
        audio_path: Path to the audio file. An empty string means "use the
            audio file named by the beatmap, located next to the beatmap".
        time: Window start passed through to ``create_example``.
        model_args: Config; ``data.sample_rate`` is used for audio loading.
        tokenizer: Tokenizer for parsing and encoding.
        device: Device the example tensors are moved to.
        unsqueeze: Passed through to ``create_example``.

    Returns:
        The dict produced by ``create_example``.
    """
    target_rate = model_args.data.sample_rate

    beatmap_file = Path(beatmap_path)
    beatmap = Beatmap.from_path(beatmap_file)

    # Fall back to the audio file referenced by the beatmap itself.
    audio_source = beatmap_file.parent / beatmap.audio_filename if audio_path == '' else audio_path
    audio = load_audio_file(audio_source, target_rate)

    events, event_times = OsuParser(model_args, tokenizer).parse(beatmap)

    return create_example(
        events,
        event_times,
        audio,
        time,
        model_args,
        tokenizer,
        device,
        unsqueeze,
    )


def get_mapper_names(path: str) -> dict:
    """Load a mapping from user id to display name from a JSON file.

    The file is expected to contain a JSON array of objects, each with a
    ``user_id`` key and a ``username`` key holding an indexable collection of
    names; the first entry is used as the display name.

    Args:
        path: Location of the JSON file.

    Returns:
        A dict keyed by each item's ``user_id``. The value is the first
        element of ``username``, or ``"Unknown"`` when ``username`` is empty
        (or null).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If an item lacks ``user_id`` or ``username``.
    """
    # JSON is UTF-8 by specification; decode explicitly so non-ASCII usernames
    # do not depend on the platform's default encoding.
    with open(Path(path), 'r', encoding='utf-8') as file:
        data = json.load(file)

    mapper_names = {}
    for item in data:
        usernames = item['username']
        mapper_names[item['user_id']] = usernames[0] if usernames else "Unknown"

    return mapper_names


@hydra.main(config_path="configs", config_name="inference", version_base="1.1")
def main(args: DictConfig):
    """Classify one beatmap window and print the highest-scoring mappers.

    Reads ``checkpoint_path``, ``beatmap_path``, ``audio_path``, ``time`` and
    ``mappers_path`` from the Hydra config, runs the model on a single example
    and prints up to 100 mappers ordered by score.

    Args:
        args: Hydra/OmegaConf inference configuration.
    """
    # Inference only: disable autograd globally for this process.
    torch.set_grad_enabled(False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model, model_args, tokenizer = load_ckpt(args.checkpoint_path)
    model.eval().to(device)

    # unsqueeze=True adds the leading dimension of size 1 the model is called with.
    example = create_example_from_path(args.beatmap_path, args.audio_path, args.time, model_args, tokenizer, device, True)
    result: OsuClassifierOutput = model(**example)
    logits = result.logits

    # Print the top 100 mappers with confidences. The printed "confidence" is
    # the value taken directly from result.logits; no softmax is applied here.
    scores = logits[0]
    # topk raises if k exceeds the number of entries, so clamp it for models
    # with fewer than 100 mapper classes.
    top_k = min(100, scores.shape[-1])
    top = scores.topk(top_k)

    # Invert the tokenizer's mapper-id -> index table to look ids up by index.
    mapper_idx_id = {idx: ids for ids, idx in tokenizer.mapper_idx.items()}
    mapper_names = get_mapper_names(args.mappers_path)

    for idx, confidence in zip(top.indices, top.values):
        mapper_id = mapper_idx_id[idx.item()]
        mapper_name = mapper_names.get(mapper_id, "Unknown")
        print(f"Mapper: {mapper_name} ({mapper_id}) with confidence: {confidence.item()}")


if __name__ == "__main__":
    main()