File size: 7,618 Bytes
c9f87fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c51d47d
 
c9f87fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c51d47d
 
 
 
 
 
 
 
 
 
 
 
c9f87fa
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import spaces
import gradio as gr
import torch
from pathlib import Path
from audiotools import AudioSignal
from tria.model.tria import TRIA
from tria.pipelines.tokenizer import Tokenizer
from tria.features import rhythm_features
from functools import partial
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import load_audio, save_audio
from pyharp.labels import LabelList

# Global Config
# Run on the GPU when one is available; otherwise fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Number of candidate generations produced per request.
N_OUTPUTS = 3

# Model Zoo
# Registry of available TRIA checkpoints. Each entry bundles everything
# needed to instantiate and run one model:
#   checkpoint    - path to the serialized state dict loaded by torch.load
#   model_cfg     - kwargs forwarded to the TRIA constructor
#   tokenizer_cfg - kwargs for the Tokenizer used to encode/decode audio
#   feature_cfg   - kwargs bound onto rhythm_features via functools.partial
#   infer_cfg     - default sampling parameters passed to model.inference
#   max_duration  - per-request duration budget in seconds (see generate_audio)
MODEL_ZOO = {
    "small_musdb_moises_2b": {
        "checkpoint": "pretrained/tria/small_musdb_moises_2b/80000/model.pt",
        # Transformer architecture hyperparameters.
        "model_cfg": {
            "codebook_size": 1024,
            "n_codebooks": 9,
            "n_channels": 512,
            "n_feats": 2,
            "n_heads": 8,
            "n_layers": 12,
            "mult": 4,
            "p_dropout": 0.0,
            "bias": True,
            "max_len": 1000,
            "pos_enc": "rope",
            "qk_norm": True,
            "use_sdpa": True,
            "interp": "nearest",  # interpolation mode reused by generate_audio for features
            "share_emb": True,
        },
        "tokenizer_cfg": {"name": "dac"},
        # Rhythm feature extraction settings (windowing, smoothing, normalization).
        "feature_cfg": {
            "sample_rate": 16_000,
            "n_bands": 2,
            "n_mels": 40,
            "window_length": 384,
            "hop_length": 192,
            "quantization_levels": 5,
            "slow_ma_ms": 200,
            "post_smooth_ms": 100,
            "legacy_normalize": False,
            "clamp_max": 50.0,
            "normalize_quantile": 0.98,
        },
        # Default sampling configuration; iterations is a per-codebook schedule.
        "infer_cfg": {
            "top_p": 0.95,
            "top_k": None,
            "temp": 1.0,
            "mask_temp": 10.5,
            "iterations": [8, 8, 8, 8, 4, 4, 4, 4, 4],
            "guidance_scale": 2.0,
            "causal_bias": 1.0,
        },
        "max_duration": 6.0,
    },
}

# Loaded model cache
# Single-slot cache for the most recently loaded model and its companion
# objects, so repeated requests for the same model skip disk loading.
LOADED = {
    "name": None,
    "model": None,
    "tokenizer": None,
    "feature_fn": None,
    "infer_cfg": None,
    "sample_rate": None,
    "max_duration": None,
}

# Model loading
def load_model_by_name(name: str):
    """Load a TRIA model by name, memoizing the result in ``LOADED``.

    On a cache hit (same ``name`` already resident) the cached model is
    returned immediately. Otherwise the checkpoint, tokenizer, and rhythm
    feature extractor are constructed from ``MODEL_ZOO[name]`` and the
    cache is refreshed.
    """
    # Fast path: requested model is already loaded.
    if LOADED["name"] == name and LOADED["model"] is not None:
        return LOADED["model"]

    entry = MODEL_ZOO[name]

    # Build the network, restore weights on CPU, then move to DEVICE in
    # eval mode.
    # NOTE(review): torch.load without weights_only=True unpickles
    # arbitrary objects; acceptable only because checkpoints are local.
    net = TRIA(**entry["model_cfg"])
    state = torch.load(entry["checkpoint"], map_location="cpu")
    net.load_state_dict(state, strict=True)
    net.to(DEVICE).eval()

    codec = Tokenizer(**entry["tokenizer_cfg"]).to(DEVICE)
    extractor = partial(rhythm_features, **entry.get("feature_cfg", {}))

    # Refresh the single-slot cache with everything inference needs.
    LOADED["name"] = name
    LOADED["model"] = net
    LOADED["tokenizer"] = codec
    LOADED["feature_fn"] = extractor
    LOADED["infer_cfg"] = entry["infer_cfg"]
    LOADED["sample_rate"] = codec.sample_rate
    LOADED["max_duration"] = entry["max_duration"]
    return net


# Inference logic
@spaces.GPU
@torch.inference_mode()
def generate_audio(model_name, timbre_path, rhythm_path, cfg_scale, top_p, mask_temperature, seed):
    """Generate N_OUTPUTS drum takes from a timbre prompt and a rhythm prompt.

    Args:
        model_name: key into MODEL_ZOO selecting the checkpoint.
        timbre_path: audio file providing the target drum sound.
        rhythm_path: audio file providing the rhythmic gesture.
        cfg_scale: classifier-free guidance scale.
        top_p: nucleus sampling threshold.
        mask_temperature: unmasking-order temperature.
        seed: base random seed; take i uses seed + i.

    Returns:
        Tuple of N_OUTPUTS paths to generated .wav files.
    """
    model = load_model_by_name(model_name)
    tokenizer = LOADED["tokenizer"]
    feat_fn = LOADED["feature_fn"]
    sample_rate = LOADED["sample_rate"]
    infer_cfg = LOADED["infer_cfg"]

    # Load both prompts at the tokenizer's sample rate and peak-normalize
    # in place to avoid clipping.
    timbre_sig = load_audio(timbre_path).resample(sample_rate)
    rhythm_sig = load_audio(rhythm_path).resample(sample_rate)
    timbre_sig.ensure_max_of_audio()
    rhythm_sig.ensure_max_of_audio()

    # Tokenize and concatenate along the time axis: the timbre prompt is a
    # fixed prefix; the rhythm span is what gets (re)generated.
    timbre_tokens = tokenizer.encode(timbre_sig)
    rhythm_tokens = tokenizer.encode(rhythm_sig)
    tokens = torch.cat([timbre_tokens.tokens, rhythm_tokens.tokens], dim=-1)
    n_batch, n_codebooks, n_frames = tokens.shape
    prefix_frames = timbre_tokens.tokens.shape[-1]

    # Rhythm conditioning features, interpolated to the token frame rate
    # and zero-filled over the (unconditioned) timbre prefix.
    feats = feat_fn(rhythm_sig)
    feats = torch.nn.functional.interpolate(feats, n_frames - prefix_frames, mode=model.interp)
    full_feats = torch.zeros(n_batch, feats.shape[1], n_frames, device=DEVICE)
    full_feats[..., prefix_frames:] = feats

    # Masks: buffer_mask keeps the timbre prefix frozen across codebooks;
    # feats_mask enables rhythm conditioning only outside the prefix.
    prefix_mask = torch.arange(n_frames, device=DEVICE)[None, :].repeat(n_batch, 1) < prefix_frames
    buffer_mask = prefix_mask[:, None, :].repeat(1, n_codebooks, 1)
    feats_mask = ~prefix_mask

    outputs = []
    for i in range(N_OUTPUTS):
        # Distinct but reproducible seed per take; cast in case the Gradio
        # slider delivers a float.
        torch.manual_seed(int(seed) + i)
        gen = model.inference(
            tokens.clone().to(DEVICE),
            full_feats.to(DEVICE),
            buffer_mask.clone().to(DEVICE),
            feats_mask.to(DEVICE),
            top_p=float(top_p),
            mask_temp=float(mask_temperature),
            iterations=infer_cfg["iterations"],
            guidance_scale=float(cfg_scale),
        )[..., prefix_frames:]

        # Reuse the rhythm token container so decode sees matching metadata.
        rhythm_tokens.tokens = gen
        out_sig = tokenizer.decode(rhythm_tokens)
        out_sig.ensure_max_of_audio()
        output_path = f"tria_out_{i+1}.wav"
        save_audio(out_sig, output_path)
        outputs.append(output_path)
    return tuple(outputs)


# PyHARP Metadata
# Card shown to HARP hosts describing the model and its authors.
_DESCRIPTION = (
    "Transform your rhythmic ideas into full drum performances. TRIA takes two short audio prompts: \n "
    "Timbre Prompt: an example recording for the desired sound (e.g. drum sound) \n "
    "Rhythm Prompt: the sound gesture expressing the desired pattern (e.g. tapping or beatboxing) \n"
    "It generates 3 drum arrangements that match your groove and chosen timbre. "
)
model_card = ModelCard(
    name="TRIA: The Rhythm In Anything",
    description=_DESCRIPTION,
    author="Patrick O'Reilly, Julia Barnett, Hugo Flores García, Annie Chu, Nathan Pruyne, Prem Seetharaman, Bryan Pardo",
    tags=["tria", "rhythm-generation", "pyharp"],
)


# Gradio and PyHARP Endpoint
with gr.Blocks(title="TRIA") as demo:
    # Both audio prompts are mandatory for HARP clients.
    timbre_in = gr.Audio(type="filepath", label="Timbre Prompt").harp_required(True)
    rhythm_in = gr.Audio(type="filepath", label="Rhythm Prompt").harp_required(True)

    model_names = list(MODEL_ZOO.keys())
    model_dropdown = gr.Dropdown(choices=model_names, value=model_names[0], label="Model")

    # Sampling controls; defaults mirror the default model's infer_cfg.
    with gr.Row():
        cfg_scale = gr.Slider(0.0, 10.0, value=2.0, step=0.1, label="CFG Scale", info=(
        "Controls how strongly the model follows your prompts/conditions.\n"
        "Low values: Model is more creative/random, High values: Model adheres strictly to prompts, less variation."))
        top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top P", info=(
        "Nucleus Sampling: Limits the pool of possible next tokens to the most likely ones.\n"
        "(Default behav.) only sample from tokens that make up the top 95% of probability"))
        mask_temperature = gr.Slider(0.0, 20.0, value=10.5, step=0.1, label="Mask Temperature", info=(
        "In masked models, controls the randomness of which tokens get unmasked first.\n"
        "Low temperature: More deterministic, predictable unmasking order, High temperature: More random unmasking, can add variety"))
        seed = gr.Slider(0, 1000, value=0, step=1, label="Random Seed", info=(
        "Same seed and same inputs: Produces identical outputs (reproducible).\n"
        "Set a specific seed to get consistent results, or use different seeds to try variations"))

    out1 = gr.Audio(type="filepath", label="Generated #1")
    out2 = gr.Audio(type="filepath", label="Generated #2")
    out3 = gr.Audio(type="filepath", label="Generated #3")

    # Input component order must match generate_audio's parameter order.
    app = build_endpoint(
        model_card=model_card,
        input_components=[model_dropdown, timbre_in, rhythm_in, cfg_scale, top_p, mask_temperature, seed],
        output_components=[out1, out2, out3],
        process_fn=generate_audio,
    )

demo.queue().launch(share=True, show_error=True)