File size: 4,244 Bytes
656d1fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from pathlib import Path
import hydra
import torch
from omegaconf import DictConfig
from slider import Beatmap
from argparse import Namespace
from torch.serialization import add_safe_globals

# Allowlist argparse.Namespace for torch's restricted ``weights_only`` unpickler, so
# torch.load calls that do not pass weights_only=False (such as the model-weights
# load in main) can deserialize checkpoints that contain a pickled Namespace.
# NOTE(review): presumably the checkpoint stores its training args as a Namespace — confirm.
add_safe_globals([Namespace])

from osudiffusion import DiT_models
from osuT5.inference import Preprocessor, Pipeline, Postprocessor, DiffisionPipeline
from osuT5.tokenizer import Tokenizer
from osuT5.utils import get_model


def get_args_from_beatmap(args: "DictConfig"):
    """Populate generation settings in ``args`` from a reference ``.osu`` beatmap.

    Does nothing when ``args.beatmap_path`` is unset (``None`` or ``""``).
    Otherwise the beatmap is parsed and ``args`` is updated in place with:

    - ``audio_path``: the beatmap's audio file, resolved in the beatmap's folder.
    - ``output_path``: the beatmap's folder.
    - ``bpm``: the maximum BPM.
    - ``offset``: the earliest timing-point offset, in milliseconds.
    - ``slider_multiplier``, ``title`` and ``artist``.

    ``beatmap_id``, ``diffusion.style_id`` and ``difficulty`` are only filled
    in when they still hold the ``-1`` sentinel. Values given explicitly in the
    config therefore take precedence.

    Args:
        args: Hydra/OmegaConf config; mutated in place.

    Raises:
        FileNotFoundError: ``args.beatmap_path`` is set but is not a file.
        ValueError: the beatmap has no timing points, so no offset can be derived.
    """
    if not args.beatmap_path:
        return

    beatmap_path = Path(args.beatmap_path)
    if not beatmap_path.is_file():
        raise FileNotFoundError(f"Beatmap file {beatmap_path} not found.")

    beatmap = Beatmap.from_path(beatmap_path)

    # Validate before touching args so a malformed beatmap leaves the config
    # unchanged, and fail with a clear message instead of ``min()`` on empty input.
    offsets_ms = [tp.offset.total_seconds() * 1000 for tp in beatmap.timing_points]
    if not offsets_ms:
        raise ValueError(f"Beatmap file {beatmap_path} has no timing points.")

    args.audio_path = beatmap_path.parent / beatmap.audio_filename
    args.output_path = beatmap_path.parent
    args.bpm = beatmap.bpm_max()
    args.offset = min(offsets_ms)
    args.slider_multiplier = beatmap.slider_multiplier
    args.title = beatmap.title
    args.artist = beatmap.artist

    # -1 is the "not set" sentinel: only then fall back to the reference beatmap.
    if args.beatmap_id == -1:
        args.beatmap_id = beatmap.beatmap_id
    if args.diffusion.style_id == -1:
        args.diffusion.style_id = beatmap.beatmap_id
    if args.difficulty == -1:
        args.difficulty = float(beatmap.stars())


def find_model(ckpt_path, args: "DictConfig", device):
    """Load a DiT diffusion model from a checkpoint file, ready for inference.

    If the checkpoint contains an ``"ema"`` entry (the layout produced by
    ``train.py``), those EMA weights are loaded. Otherwise the checkpoint itself
    is used as the state dict.

    Args:
        ckpt_path: Path (``str`` or ``Path``) to the checkpoint file.
        args: Config; ``args.diffusion.model`` selects the architecture from
            ``DiT_models`` and ``args.diffusion.num_classes`` is passed to it.
        device: Device the model is moved to.

    Returns:
        The DiT model with weights loaded, on ``device``, in eval mode.

    Raises:
        FileNotFoundError: nothing exists at ``ckpt_path``.
    """
    ckpt_path = Path(ckpt_path)
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Could not find DiT checkpoint at {ckpt_path}")

    # SECURITY: weights_only=False performs full pickle deserialization, which can
    # execute arbitrary code. Only point this at checkpoints from a trusted source.
    # Tensors are materialized on the CPU first; the model is moved to ``device`` below.
    checkpoint = torch.load(ckpt_path, weights_only=False, map_location="cpu")
    if "ema" in checkpoint:  # supports checkpoints from train.py
        checkpoint = checkpoint["ema"]

    model = DiT_models[args.diffusion.model](
        num_classes=args.diffusion.num_classes,
        # NOTE(review): hard-coded context width; it must match the value used when
        # the checkpoint was trained — confirm where 19 - 3 + 128 comes from.
        context_size=19 - 3 + 128,
    ).to(device)
    model.load_state_dict(checkpoint)
    model.eval()  # important: disables training-only behavior such as dropout
    return model


def _difficulty_spread(count, low=3.0, high=7.0):
    """Return ``count`` evenly spaced star ratings from ``low`` to ``high`` inclusive.

    Follows ``numpy.linspace`` semantics for small counts. ``count <= 0`` gives
    an empty list and ``count == 1`` gives ``[low]``. The previous inline
    formula divided by ``count - 1`` and crashed on a single difficulty.
    """
    if count <= 0:
        return []
    if count == 1:
        return [low]
    return [low + i * (high - low) / (count - 1) for i in range(count)]


@hydra.main(config_path="configs", config_name="inference", version_base="1.1")
def main(args: DictConfig):
    """Generate beatmap(s) for an audio file and hand them to the postprocessor.

    Steps:

    1. Optionally copy settings from a reference beatmap.
    2. Load the T5 model and tokenizer from ``args.model_path``.
    3. Load and segment the audio.
    4. Generate one event sequence, or one per difficulty when ``args.full_set``
       is true.
    5. Optionally run the DiT diffusion pipeline over each sequence when
       ``args.generate_positions`` is true.
    6. Pass the results to ``Postprocessor.generate``.

    Args:
        args: Hydra config composed from ``configs/inference`` plus CLI
            overrides. It is mutated in place (e.g. ``total_duration_ms`` and,
            for full sets, ``difficulty``).
    """
    get_args_from_beatmap(args)

    # Inference only: disable autograd globally for the whole run.
    torch.set_grad_enabled(False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt_path = Path(args.model_path)

    # Model weights use torch.load's default weights_only setting (see the
    # add_safe_globals allowlist at import time). The tokenizer state needs full
    # unpickling, so model_path must point at a trusted checkpoint directory.
    model_state = torch.load(ckpt_path / "pytorch_model.bin", map_location=device)
    tokenizer_state = torch.load(ckpt_path / "custom_checkpoint_0.pkl", weights_only=False)

    tokenizer = Tokenizer()
    tokenizer.load_state_dict(tokenizer_state)

    model = get_model(args, tokenizer)
    model.load_state_dict(model_state)
    model.eval()
    model.to(device)

    preprocessor = Preprocessor(args)
    audio = preprocessor.load(args.audio_path)
    sequences = preprocessor.segment(audio)
    # NOTE(review): assumes Preprocessor.load returns audio at 16 kHz — confirm.
    args.total_duration_ms = len(audio) / 16000 * 1000

    if args.full_set:
        diffs = _difficulty_spread(args.set_difficulties)
        print(diffs)
        generated_maps = []
        for diff in diffs:
            print(f"Generating difficulty {diff}")
            # Presumably Pipeline reads args.difficulty when constructed,
            # hence a fresh Pipeline per target difficulty.
            args.difficulty = diff
            pipeline = Pipeline(args, tokenizer)
            generated_maps.append(pipeline.generate(model, sequences))
    else:
        pipeline = Pipeline(args, tokenizer)
        generated_maps = [pipeline.generate(model, sequences)]

    if args.generate_positions:
        # The T5 model is not used past this point. Drop this function's
        # reference before the DiT models are loaded so its memory can be reclaimed.
        del model
        diffusion_model = find_model(args.diff_ckpt, args, device)
        # An unset (None) or empty diff_refine_ckpt disables the refinement model.
        refine_model = find_model(args.diff_refine_ckpt, args, device) if args.diff_refine_ckpt else None
        diffusion_pipeline = DiffisionPipeline(args.diffusion)
        generated_positions = [
            diffusion_pipeline.generate(diffusion_model, events, refine_model)
            for events in generated_maps
        ]
    else:
        generated_positions = generated_maps

    postprocessor = Postprocessor(args)
    postprocessor.generate(generated_positions)


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main builds the
    # config from configs/inference plus command-line overrides and passes it in.
    main()