File size: 2,910 Bytes
8337fa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import random

import datasets
import numpy as np
import torch
from datasets import DatasetDict
from transformers import AutoConfig

from dataset import MusicDataset
from modelling_qwen3 import MAGEL


def seed_everything(seed: int) -> None:
    """Seed every random number generator this module relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global RNG and
    PyTorch; when CUDA is available, all CUDA devices are seeded as well.

    Args:
        seed: Integer seed handed to each generator.
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def resolve_device(device_arg: str) -> torch.device:
    """Turn a device string into a ``torch.device``.

    Any value other than ``"auto"`` is passed to ``torch.device`` unchanged.
    ``"auto"`` picks the first available backend in the order CUDA, MPS, CPU.

    Args:
        device_arg: An explicit device string (e.g. ``"cpu"``, ``"cuda:1"``)
            or the literal ``"auto"``.

    Returns:
        The resolved ``torch.device``.
    """
    if device_arg != "auto":
        return torch.device(device_arg)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # ``torch.backends.mps`` does not exist on older PyTorch builds, so probe
    # for it instead of assuming the attribute is present.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def move_batch_to_device(
    batch: dict[str, torch.Tensor], device: torch.device
) -> dict[str, torch.Tensor]:
    """Return a new dict with every tensor value moved to ``device``.

    Values that are not tensors are carried over untouched. The input
    mapping itself is not modified.

    Args:
        batch: Mapping from field name to value.
        device: Target device for tensor values.

    Returns:
        A fresh dict with the same keys in the same order.
    """
    moved = {}
    for name, item in batch.items():
        if torch.is_tensor(item):
            moved[name] = item.to(device)
        else:
            moved[name] = item
    return moved

def load_music_dataset(
    dataset_path: str,
    split: str,
    tokenizer_path: str,
    num_audio_token: int = 16384,
    fps: int = 25,
    use_fast: bool = True,
) -> MusicDataset:
    """Load a dataset saved on disk and wrap one split in a ``MusicDataset``.

    ``datasets.load_from_disk`` may return either a ``DatasetDict`` or a
    single dataset. A ``DatasetDict`` is passed through as-is once the
    requested split is confirmed to exist; a single dataset is wrapped in a
    one-entry dict keyed by ``split``.

    Args:
        dataset_path: Path given to ``datasets.load_from_disk``.
        split: Name of the split to use.
        tokenizer_path: Forwarded to ``MusicDataset``.
        num_audio_token: Forwarded to ``MusicDataset``.
        fps: Forwarded to ``MusicDataset``.
        use_fast: Forwarded to ``MusicDataset``.

    Returns:
        The constructed ``MusicDataset``.

    Raises:
        KeyError: If a ``DatasetDict`` was loaded and it has no ``split`` entry.
    """
    loaded = datasets.load_from_disk(dataset_path)
    if not isinstance(loaded, DatasetDict):
        # A bare dataset has no split names; expose it under the requested one.
        split_map = {split: loaded}
    elif split in loaded:
        split_map = loaded
    else:
        raise KeyError(f"Split not found: {split}")

    dataset_kwargs = {
        "datasets": split_map,
        "split": split,
        "tokenizer_path": tokenizer_path,
        "num_audio_token": num_audio_token,
        "fps": fps,
        "use_fast": use_fast,
    }
    return MusicDataset(**dataset_kwargs)


def load_magel_checkpoint(
    checkpoint_path: str,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    attn_implementation: str = "sdpa",
) -> MAGEL:
    """Load a ``MAGEL`` model from a checkpoint and prepare it for inference.

    The config is read with ``AutoConfig.from_pretrained`` and passed
    explicitly to ``MAGEL.from_pretrained``; both calls use
    ``local_files_only=True``. The model is then moved to ``device`` and
    switched to eval mode.

    Args:
        checkpoint_path: Checkpoint location passed to both loaders.
        device: Device the model is moved to.
        dtype: Passed to ``MAGEL.from_pretrained`` as ``torch_dtype``.
        attn_implementation: Passed to ``MAGEL.from_pretrained`` unchanged.

    Returns:
        The loaded model, on ``device`` and in eval mode.
    """
    magel_config = AutoConfig.from_pretrained(checkpoint_path, local_files_only=True)
    load_kwargs = {
        "config": magel_config,
        "torch_dtype": dtype,
        "attn_implementation": attn_implementation,
        "local_files_only": True,
    }
    magel = MAGEL.from_pretrained(checkpoint_path, **load_kwargs)
    magel.to(device=device)
    magel.eval()
    return magel


def maybe_compile_model(
    model,
    enabled: bool = False,
    mode: str = "reduce-overhead",
):
    """Optionally wrap ``model`` with ``torch.compile`` and tag the result.

    The returned object always carries a ``_magel_is_compiled`` attribute:
    ``False`` on the original model when compilation is disabled, ``True``
    on the compiled wrapper otherwise.

    Args:
        model: The model to (possibly) compile.
        enabled: When false, ``model`` is returned as-is (apart from the tag).
        mode: Passed to ``torch.compile`` as ``mode``.

    Returns:
        ``model`` itself when disabled, else the object returned by
        ``torch.compile``.

    Raises:
        RuntimeError: If compilation is requested but ``torch.compile`` does
            not exist in the installed PyTorch.
    """
    if enabled:
        if not hasattr(torch, "compile"):
            raise RuntimeError("torch.compile is not available in this PyTorch build.")
        wrapped = torch.compile(model, mode=mode)
        wrapped._magel_is_compiled = True
        return wrapped

    model._magel_is_compiled = False
    return model


def maybe_mark_compile_step_begin(model) -> None:
    """Call ``torch.compiler.cudagraph_mark_step_begin`` for compiled models.

    Does nothing unless ``model`` has a truthy ``_magel_is_compiled``
    attribute. Also does nothing when the installed PyTorch has no
    ``torch.compiler`` namespace or that namespace has no
    ``cudagraph_mark_step_begin`` function.

    Args:
        model: Object whose ``_magel_is_compiled`` attribute is inspected.
    """
    if getattr(model, "_magel_is_compiled", False):
        # Each lookup falls back to None, so a missing namespace or a missing
        # function both end in a silent no-op.
        compiler_ns = getattr(torch, "compiler", None)
        step_hook = getattr(compiler_ns, "cudagraph_mark_step_begin", None)
        if step_hook is not None:
            step_hook()