# osu_mapper / inference.py
# Last commit by legendrecalls: "Update inference.py" (656d1fe, verified); file size 4.24 kB.
from pathlib import Path
import hydra
import torch
from omegaconf import DictConfig
from slider import Beatmap
from argparse import Namespace
from torch.serialization import add_safe_globals
# Allowlist argparse.Namespace for torch.load's restricted (weights_only=True)
# unpickler, so checkpoints that pickle a Namespace can still be loaded safely.
# NOTE(review): presumably needed for pytorch_model.bin, which is loaded without
# an explicit weights_only flag (so the torch version's default applies) —
# confirm. The DiT checkpoint and tokenizer loads pass weights_only=False and
# therefore do full unpickling, which ignores this allowlist entirely.
add_safe_globals([Namespace])
from osudiffusion import DiT_models
from osuT5.inference import Preprocessor, Pipeline, Postprocessor, DiffisionPipeline
from osuT5.tokenizer import Tokenizer
from osuT5.utils import get_model
def _earliest_offset_ms(timing_points) -> float:
    """Return the earliest timing-point offset in milliseconds.

    Each element must expose ``offset`` with a ``total_seconds()`` method
    (e.g. a ``datetime.timedelta``).

    Raises:
        ValueError: if ``timing_points`` is empty; a bare ``min()`` would fail
            with an unhelpful "empty sequence" message instead.
    """
    offsets_ms = [tp.offset.total_seconds() * 1000 for tp in timing_points]
    if not offsets_ms:
        raise ValueError("Beatmap has no timing points; cannot determine the offset.")
    return min(offsets_ms)


def get_args_from_beatmap(args: DictConfig):
    """Fill generation settings in ``args`` from a reference beatmap, in place.

    Does nothing when ``args.beatmap_path`` is unset (``None`` or ``""``).
    Otherwise parses the beatmap and overwrites ``audio_path``, ``output_path``,
    ``bpm``, ``offset`` (ms), ``slider_multiplier``, ``title`` and ``artist``.
    ``beatmap_id``, ``diffusion.style_id`` and ``difficulty`` are taken from the
    beatmap only when the config holds the sentinel ``-1``; an explicit config
    value wins.

    Args:
        args: Hydra config; mutated in place.

    Raises:
        FileNotFoundError: if ``args.beatmap_path`` is not an existing file.
        ValueError: if the beatmap has no timing points.
    """
    if not args.beatmap_path:
        return

    beatmap_path = Path(args.beatmap_path)
    if not beatmap_path.is_file():
        raise FileNotFoundError(f"Beatmap file {beatmap_path} not found.")

    beatmap = Beatmap.from_path(beatmap_path)
    # The audio file is named relative to the beatmap's folder, and output is
    # written into that same folder.
    args.audio_path = beatmap_path.parent / beatmap.audio_filename
    args.output_path = beatmap_path.parent
    args.bpm = beatmap.bpm_max()
    args.offset = _earliest_offset_ms(beatmap.timing_points)
    args.slider_multiplier = beatmap.slider_multiplier
    args.title = beatmap.title
    args.artist = beatmap.artist

    # -1 means "take this value from the reference beatmap".
    if args.beatmap_id == -1:
        args.beatmap_id = beatmap.beatmap_id
    if args.diffusion.style_id == -1:
        args.diffusion.style_id = beatmap.beatmap_id
    if args.difficulty == -1:
        # stars() is only computed when actually needed.
        args.difficulty = float(beatmap.stars())
def find_model(ckpt_path, args: DictConfig, device):
    """Build the DiT diffusion model described by ``args`` and load its weights.

    Args:
        ckpt_path: Path (``str`` or ``Path``) to a ``torch.save``-d checkpoint.
            If it contains an ``"ema"`` entry (as checkpoints from ``train.py``
            do), those weights are used; otherwise the loaded object itself is
            used as the state dict.
        args: Config; reads ``args.diffusion.model`` (a key of ``DiT_models``)
            and ``args.diffusion.num_classes``.
        device: Device the model is moved to.

    Returns:
        The model, in eval mode, on ``device``.

    Raises:
        FileNotFoundError: if ``ckpt_path`` does not exist.
    """
    ckpt_path = Path(ckpt_path)
    # Raise instead of assert: assert statements are stripped under `python -O`.
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Could not find DiT checkpoint at {ckpt_path}")

    # SECURITY: weights_only=False performs full pickle deserialization, which
    # can execute arbitrary code. Only load checkpoints from a trusted source.
    # map_location="cpu" keeps all tensors on the CPU (same as the identity
    # `lambda storage, loc: storage`); the model is moved to `device` below.
    checkpoint = torch.load(ckpt_path, weights_only=False, map_location="cpu")
    if "ema" in checkpoint:
        checkpoint = checkpoint["ema"]

    model = DiT_models[args.diffusion.model](
        num_classes=args.diffusion.num_classes,
        # NOTE(review): hard-coded context width (= 144) that must match the
        # trained checkpoint; the meaning of the 19 / 3 / 128 terms is not
        # visible here — confirm against the training config.
        context_size=19 - 3 + 128,
    ).to(device)
    model.load_state_dict(checkpoint)
    model.eval()  # inference mode (e.g. disables dropout)
    return model
def _spread_difficulties(count, low=3.0, high=7.0):
    """Return ``count`` star ratings evenly spaced from ``low`` to ``high`` inclusive.

    ``count == 1`` yields ``[low]`` instead of dividing by zero; ``count <= 0``
    yields an empty list.
    """
    if count <= 0:
        return []
    if count == 1:
        return [low]
    return [low + i * (high - low) / (count - 1) for i in range(count)]


def _load_generator(args, device):
    """Load the tokenizer and T5 generation model from ``args.model_path``.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode on ``device``.
    """
    ckpt_dir = Path(args.model_path)
    model_state = torch.load(ckpt_dir / "pytorch_model.bin", map_location=device)
    # SECURITY: weights_only=False unpickles arbitrary objects; trusted files only.
    tokenizer_state = torch.load(ckpt_dir / "custom_checkpoint_0.pkl", weights_only=False)

    tokenizer = Tokenizer()
    tokenizer.load_state_dict(tokenizer_state)

    model = get_model(args, tokenizer)
    model.load_state_dict(model_state)
    model.eval()
    model.to(device)
    return model, tokenizer


def _generate_maps(args, model, tokenizer, sequences):
    """Run the T5 pipeline once per requested difficulty; return one event list per map."""
    if not args.full_set:
        return [Pipeline(args, tokenizer).generate(model, sequences)]

    diffs = _spread_difficulties(args.set_difficulties)
    print(diffs)
    generated = []
    for diff in diffs:
        print(f"Generating difficulty {diff}")
        # Presumably Pipeline reads args.difficulty when constructed, hence a
        # fresh pipeline per difficulty — confirm in osuT5.inference.
        args.difficulty = diff
        generated.append(Pipeline(args, tokenizer).generate(model, sequences))
    return generated


def _generate_positions(args, generated_maps, device):
    """Refine each generated map with the diffusion model when enabled; else pass through."""
    if not args.generate_positions:
        return generated_maps

    diffusion_model = find_model(args.diff_ckpt, args, device)
    # Truthiness check also tolerates a null/None refine checkpoint.
    refine_model = find_model(args.diff_refine_ckpt, args, device) if args.diff_refine_ckpt else None
    diffusion_pipeline = DiffisionPipeline(args.diffusion)
    return [
        diffusion_pipeline.generate(diffusion_model, events, refine_model)
        for events in generated_maps
    ]


@hydra.main(config_path="configs", config_name="inference", version_base="1.1")
def main(args: DictConfig):
    """Generate beatmap(s) for an audio file and write them via the Postprocessor.

    Steps: optionally take settings from a reference beatmap, load the T5
    tokenizer/model, segment the audio, generate one map (or an evenly spread
    set of difficulties when ``args.full_set``), optionally generate positions
    with the diffusion model, then hand the results to ``Postprocessor``.

    NOTE(review): with version_base="1.1" Hydra changes the working directory
    to the run directory by default, so relative paths in the config may not
    resolve against the launch directory — confirm.
    """
    get_args_from_beatmap(args)
    torch.set_grad_enabled(False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model, tokenizer = _load_generator(args, device)

    preprocessor = Preprocessor(args)
    audio = preprocessor.load(args.audio_path)
    sequences = preprocessor.segment(audio)
    # NOTE(review): assumes preprocessor.load returns audio sampled at 16 kHz — confirm.
    args.total_duration_ms = len(audio) / 16000 * 1000

    generated_maps = _generate_maps(args, model, tokenizer, sequences)
    # Drop our reference to the T5 model before any diffusion checkpoint is
    # loaded so it can be freed, lowering peak (GPU) memory.
    del model

    generated_positions = _generate_positions(args, generated_maps, device)

    postprocessor = Postprocessor(args)
    postprocessor.generate(generated_positions)
# Script entry point: the hydra.main decorator on main() builds the config from
# the "inference" config in configs/ (plus any command-line overrides) and
# passes it in, so main() is called here without arguments.
if __name__ == "__main__":
    main()