# BeatHeritage-v1 / classifier / classify.py
# fourmansyah: Duplicate from hongminh54/BeatHeritage-v1 (commit 12a8e0f)
import json
from pathlib import Path
import numpy.typing as npt
import hydra
import torch
from omegaconf import DictConfig
from slider import Beatmap
from torch.utils.data import IterableDataset
from classifier.libs.dataset import OsuParser
from classifier.libs.dataset.data_utils import load_audio_file
from classifier.libs.dataset.ors_dataset import STEPS_PER_MILLISECOND
from classifier.libs.model.model import OsuClassifierOutput
from classifier.libs.tokenizer import Tokenizer, Event, EventType
from classifier.libs.utils import load_ckpt
def iterate_examples(
        beatmap: Beatmap,
        audio: npt.NDArray,
        model_args: DictConfig,
        tokenizer: Tokenizer,
        device: torch.device
):
    """Yield one classifier example per consecutive, non-overlapping audio window.

    The audio is cut into windows of ``(src_seq_len - 1) * hop_length`` samples.
    Only complete windows are produced; a trailing partial window is dropped.

    Args:
        beatmap: Beatmap whose parsed events are tokenized for each window.
        audio: Audio samples at ``model_args.data.sample_rate``, indexed along
            the first axis.
        model_args: Model config; reads ``data.src_seq_len``, ``data.hop_length``
            and ``data.sample_rate``.
        tokenizer: Tokenizer passed to ``OsuParser`` and ``create_example``.
        device: Device the example tensors are moved to.

    Yields:
        Example dicts as produced by ``create_example``.
    """
    frame_seq_len = model_args.data.src_seq_len - 1
    frame_size = model_args.data.hop_length
    sample_rate = model_args.data.sample_rate
    samples_per_sequence = frame_seq_len * frame_size

    parser = OsuParser(model_args, tokenizer)
    events, event_times = parser.parse(beatmap)

    # ``+ 1`` so the last window that fits entirely inside the audio is included.
    # Without it, audio whose length is an exact multiple of the window size
    # (including audio exactly one window long) loses its final full window.
    last_start = len(audio) - samples_per_sequence
    for start in range(0, last_start + 1, samples_per_sequence):
        yield create_example(events, event_times, audio, start / sample_rate, model_args, tokenizer, device)
class ExampleDataset(IterableDataset):
    """Iterable dataset that streams classifier examples for a single beatmap.

    Every call to ``__iter__`` delegates to ``iterate_examples`` with the stored
    beatmap, audio, config, tokenizer and device. Each iteration therefore
    starts again from the first window.

    NOTE(review): there is no worker sharding here. With a DataLoader using
    ``num_workers > 1``, each worker would presumably yield every example.
    Confirm this is only used single-process.
    """

    def __init__(self, beatmap, audio, classifier_args, classifier_tokenizer, device):
        # The arguments are stored untouched; they are only consumed in __iter__.
        self.beatmap, self.audio = beatmap, audio
        self.classifier_args = classifier_args
        self.classifier_tokenizer = classifier_tokenizer
        self.device = device

    def __iter__(self):
        # Hand back a fresh generator over the beatmap's windows.
        return iterate_examples(
            beatmap=self.beatmap,
            audio=self.audio,
            model_args=self.classifier_args,
            tokenizer=self.classifier_tokenizer,
            device=self.device,
        )
def create_example(
        events: list[Event],
        event_times: list[float],
        audio: npt.NDArray,
        time: float,
        model_args: DictConfig,
        tokenizer: Tokenizer,
        device: torch.device,
        unsqueeze: bool = False,
):
    """Build a single model input for the audio window starting at ``time``.

    Args:
        events: Parsed beatmap events, aligned element-wise with ``event_times``.
        event_times: Event times in milliseconds.
        audio: Audio samples at ``model_args.data.sample_rate``.
        time: Window start, in seconds.
        model_args: Model config; reads ``data.src_seq_len``, ``data.hop_length``,
            ``data.sample_rate`` and ``data.tgt_seq_len``.
        tokenizer: Used to encode events and supply ``pad_id``.
        device: Device the returned tensors are moved to.
        unsqueeze: If True, prepend a batch dimension of size 1 to both tensors.

    Returns:
        Dict with ``decoder_input_ids`` (token ids padded to ``tgt_seq_len``),
        ``decoder_attention_mask`` (True where the token is not padding) and
        ``frames`` (float32 audio samples of the window).

    NOTE(review): if the window runs past the end of ``audio``, ``frames`` is
    shorter than a full window and is not padded. Confirm the model tolerates
    this.
    """
    frame_seq_len = model_args.data.src_seq_len - 1
    frame_size = model_args.data.hop_length
    sample_rate = model_args.data.sample_rate
    tgt_seq_len = model_args.data.tgt_seq_len
    samples_per_sequence = frame_seq_len * frame_size
    sequence_duration = samples_per_sequence / sample_rate

    # Audio window. Use round() rather than plain int() truncation. Callers such as
    # the window iterator pass ``time = sample_index / sample_rate``, and the float
    # product can land just below the integer (e.g. 47999.999...). Truncation would
    # then start the window one sample early. The outer int() guarantees a Python
    # int for slicing.
    frame_start = int(round(time * sample_rate))
    frames = audio[frame_start:frame_start + samples_per_sequence]
    frames = torch.from_numpy(frames).to(torch.float32).to(device)

    # Keep only the events inside [time, time + sequence_duration); event times are in ms.
    window_events = [
        event for event, event_time in zip(events, event_times)
        if time <= event_time / 1000 < time + sequence_duration
    ]

    # Re-express TIME_SHIFT values relative to the window start, in tokenizer steps.
    time_ms = time * 1000
    window_events = [
        Event(EventType.TIME_SHIFT, int((event.value - time_ms) * STEPS_PER_MILLISECOND))
        if event.type == EventType.TIME_SHIFT else event
        for event in window_events
    ]

    # Fixed-length target sequence. Encode at most tgt_seq_len events (extra events
    # are dropped) and fill the remainder with pad tokens.
    tokens = torch.full((tgt_seq_len,), tokenizer.pad_id, dtype=torch.long)
    for i, event in enumerate(window_events[:tgt_seq_len]):
        tokens[i] = tokenizer.encode(event)
    tokens = tokens.to(device)

    if unsqueeze:
        tokens = tokens.unsqueeze(0)
        frames = frames.unsqueeze(0)

    return {
        "decoder_input_ids": tokens,
        "decoder_attention_mask": tokens != tokenizer.pad_id,
        "frames": frames,
    }
def create_example_from_path(
        beatmap_path: str,
        audio_path: str,
        time: float,
        model_args: DictConfig,
        tokenizer: Tokenizer,
        device: torch.device,
        unsqueeze: bool = False,
):
    """Load a beatmap (and its audio) from disk and build one example at ``time``.

    Args:
        beatmap_path: Path to the beatmap file.
        audio_path: Path to the audio file. An empty string or ``None`` means
            "use the audio file named by the beatmap, next to the beatmap file".
        time: Window start, in seconds.
        model_args: Model config; ``data.sample_rate`` is used to load the audio.
        tokenizer: Tokenizer for parsing and encoding events.
        device: Device the returned tensors are moved to.
        unsqueeze: If True, add a leading batch dimension to the tensors.

    Returns:
        The example dict produced by ``create_example``.
    """
    sample_rate = model_args.data.sample_rate
    beatmap_path = Path(beatmap_path)
    beatmap = Beatmap.from_path(beatmap_path)

    # Fall back to the beatmap's own audio file. ``not audio_path`` also covers
    # None (e.g. a null config value), which an ``== ''`` test let through to
    # load_audio_file.
    if not audio_path:
        audio_path = beatmap_path.parent / beatmap.audio_filename
    audio = load_audio_file(audio_path, sample_rate)

    parser = OsuParser(model_args, tokenizer)
    events, event_times = parser.parse(beatmap)
    return create_example(events, event_times, audio, time, model_args, tokenizer, device, unsqueeze)
def get_mapper_names(path: str):
    """Load a mapping from mapper user id to display name from a JSON file.

    The file must contain a list of objects with ``user_id`` and ``username``
    keys. ``username`` is a sequence whose first entry is used as the name. An
    empty (or null) ``username`` maps to ``"Unknown"``. When a ``user_id``
    repeats, the last entry wins.

    Args:
        path: Path to the JSON file.

    Returns:
        Dict from ``user_id`` (as stored in the JSON) to mapper name.
    """
    # Explicit UTF-8: JSON is UTF-8 by spec, and usernames can be non-ASCII. The
    # platform default encoding (e.g. cp1252 on Windows) would garble or reject them.
    with open(Path(path), 'r', encoding='utf-8') as file:
        data = json.load(file)

    return {
        item['user_id']: item['username'][0] if item['username'] else "Unknown"
        for item in data
    }
@hydra.main(config_path="configs", config_name="inference", version_base="1.1")
def main(args: DictConfig):
    """Classify the mapper of one beatmap window and print the top candidates.

    Reads ``checkpoint_path``, ``beatmap_path``, ``audio_path``, ``time`` and
    ``mappers_path`` from the Hydra config. Prints up to 100 mappers ranked by
    model score.
    """
    torch.set_grad_enabled(False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model, model_args, tokenizer = load_ckpt(args.checkpoint_path)
    model.eval().to(device)

    example = create_example_from_path(args.beatmap_path, args.audio_path, args.time, model_args, tokenizer, device, True)
    result: OsuClassifierOutput = model(**example)
    scores = result.logits[0]

    # topk raises if k exceeds the number of classes, so clamp to the output size.
    # A single topk call yields both the values and the indices.
    top_k = min(100, scores.shape[-1])
    top_k_confidences, top_k_indices = scores.topk(top_k)

    # Invert tokenizer.mapper_idx (mapper id -> class index) to look up ids by index.
    mapper_idx_id = {idx: mapper_id for mapper_id, idx in tokenizer.mapper_idx.items()}
    mapper_names = get_mapper_names(args.mappers_path)

    # NOTE(review): the field is named ``logits``, so the printed "confidence" is
    # presumably an unnormalized score rather than a probability. Apply softmax
    # first if normalized values are wanted.
    for idx, confidence in zip(top_k_indices, top_k_confidences):
        mapper_id = mapper_idx_id[idx.item()]
        mapper_name = mapper_names.get(mapper_id, "Unknown")
        print(f"Mapper: {mapper_name} ({mapper_id}) with confidence: {confidence.item()}")
if __name__ == "__main__":
    # main() is decorated with @hydra.main, which builds its DictConfig argument
    # from the "inference" config plus command-line overrides, so it is called
    # here without arguments.
    main()