# NOTE(review): Hugging Face Spaces page chrome (status, file size, commit
# hashes, line-number gutter) was captured with this file; it is not Python
# and has been reduced to this comment so the module parses.
import spaces
import gradio as gr
import torch
from pathlib import Path
from audiotools import AudioSignal
from tria.model.tria import TRIA
from tria.pipelines.tokenizer import Tokenizer
from tria.features import rhythm_features
from functools import partial
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import load_audio, save_audio
from pyharp.labels import LabelList
# Global Config
# Prefer GPU when available; every model/tensor placement below uses this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of candidate generations produced per request (one per output widget).
N_OUTPUTS = 3
# Model Zoo
# Registry of available checkpoints. Each entry bundles everything needed to
# instantiate and run one model: weights path, architecture kwargs (TRIA),
# tokenizer kwargs, rhythm-feature extraction kwargs, sampling defaults, and
# the longest supported clip duration.
MODEL_ZOO = {
    "small_musdb_moises_2b": {
        # State-dict path; loaded on CPU then moved to DEVICE.
        "checkpoint": "pretrained/tria/small_musdb_moises_2b/80000/model.pt",
        # Constructor kwargs for TRIA(**model_cfg).
        "model_cfg": {
            "codebook_size": 1024,
            "n_codebooks": 9,
            "n_channels": 512,
            "n_feats": 2,
            "n_heads": 8,
            "n_layers": 12,
            "mult": 4,
            "p_dropout": 0.0,
            "bias": True,
            "max_len": 1000,
            "pos_enc": "rope",
            "qk_norm": True,
            "use_sdpa": True,
            "interp": "nearest",
            "share_emb": True,
        },
        # Constructor kwargs for Tokenizer(**tokenizer_cfg).
        "tokenizer_cfg": {"name": "dac"},
        # Kwargs partially applied to rhythm_features() via functools.partial.
        "feature_cfg": {
            "sample_rate": 16_000,
            "n_bands": 2,
            "n_mels": 40,
            "window_length": 384,
            "hop_length": 192,
            "quantization_levels": 5,
            "slow_ma_ms": 200,
            "post_smooth_ms": 100,
            "legacy_normalize": False,
            "clamp_max": 50.0,
            "normalize_quantile": 0.98,
        },
        # Sampling defaults; only "iterations" is read directly at inference
        # time (the UI sliders override the other values).
        "infer_cfg": {
            "top_p": 0.95,
            "top_k": None,
            "temp": 1.0,
            "mask_temp": 10.5,
            "iterations": [8, 8, 8, 8, 4, 4, 4, 4, 4],
            "guidance_scale": 2.0,
            "causal_bias": 1.0,
        },
        # Longest supported generation, in seconds.
        "max_duration": 6.0,
    },
}
# Loaded model cache
# Single-slot cache holding the most recently loaded model and its companions
# (tokenizer, feature extractor, inference config, audio metadata). Keyed by
# "name" and repopulated wholesale by load_model_by_name().
LOADED = dict(name=None, model=None, tokenizer=None, feature_fn=None, infer_cfg=None, sample_rate=None, max_duration=None)
# Model loading
def load_model_by_name(name: str):
    """Return the TRIA model registered under *name*, loading it on first use.

    The most recently loaded model — together with its tokenizer, rhythm
    feature function, inference config, sample rate and max duration — is
    kept in the module-level ``LOADED`` dict, so repeated calls with the
    same name return immediately.
    """
    if LOADED["model"] is not None and LOADED["name"] == name:
        return LOADED["model"]

    entry = MODEL_ZOO[name]

    # Build the network and restore weights on CPU, then move to DEVICE.
    # NOTE(review): torch.load unpickles arbitrary objects — checkpoints must
    # come from a trusted source.
    net = TRIA(**entry["model_cfg"])
    state = torch.load(entry["checkpoint"], map_location="cpu")
    net.load_state_dict(state, strict=True)
    net.to(DEVICE).eval()

    tok = Tokenizer(**entry["tokenizer_cfg"]).to(DEVICE)
    feature_fn = partial(rhythm_features, **entry.get("feature_cfg", {}))

    LOADED["name"] = name
    LOADED["model"] = net
    LOADED["tokenizer"] = tok
    LOADED["feature_fn"] = feature_fn
    LOADED["infer_cfg"] = entry["infer_cfg"]
    LOADED["sample_rate"] = tok.sample_rate
    LOADED["max_duration"] = entry["max_duration"]
    return net
# Inference logic
@spaces.GPU
@torch.inference_mode()
def generate_audio(model_name, timbre_path, rhythm_path, cfg_scale, top_p, mask_temperature, seed):
    """Generate N_OUTPUTS drum takes from a timbre prompt and a rhythm prompt.

    Args:
        model_name: Key into MODEL_ZOO selecting checkpoint and config.
        timbre_path: Audio file providing the target drum sound.
        rhythm_path: Audio file providing the rhythmic gesture (tapping, etc.).
        cfg_scale: Classifier-free guidance strength.
        top_p: Nucleus-sampling probability threshold.
        mask_temperature: Temperature for the iterative unmasking order.
        seed: Base random seed; output i is generated with seed + i.

    Returns:
        Tuple of N_OUTPUTS file paths to the generated .wav outputs.
    """
    model = load_model_by_name(model_name)
    tokenizer = LOADED["tokenizer"]
    feat_fn = LOADED["feature_fn"]
    sample_rate = LOADED["sample_rate"]
    infer_cfg = LOADED["infer_cfg"]

    # Load both prompts at the tokenizer's sample rate and peak-normalize.
    timbre_sig = load_audio(timbre_path).resample(sample_rate)
    rhythm_sig = load_audio(rhythm_path).resample(sample_rate)
    timbre_sig.ensure_max_of_audio()
    rhythm_sig.ensure_max_of_audio()

    # Tokenize: the timbre prompt is a fixed prefix; the rhythm span is generated.
    # (Removed dead local `prefix_dur` — the prefix length actually comes from
    # the encoded timbre token count below.)
    timbre_tokens = tokenizer.encode(timbre_sig)
    rhythm_tokens = tokenizer.encode(rhythm_sig)
    tokens = torch.cat([timbre_tokens.tokens, rhythm_tokens.tokens], dim=-1)
    n_batch, n_codebooks, n_frames = tokens.shape
    prefix_frames = timbre_tokens.tokens.shape[-1]

    # Rhythm features condition only the generated (non-prefix) frames.
    feats = feat_fn(rhythm_sig)
    feats = torch.nn.functional.interpolate(feats, n_frames - prefix_frames, mode=model.interp)
    full_feats = torch.zeros(n_batch, feats.shape[1], n_frames, device=DEVICE)
    full_feats[..., prefix_frames:] = feats

    # True on prefix frames: keep those tokens fixed and generate the rest.
    prefix_mask = torch.arange(n_frames, device=DEVICE)[None, :].repeat(n_batch, 1) < prefix_frames
    buffer_mask = prefix_mask[:, None, :].repeat(1, n_codebooks, 1)
    feats_mask = ~prefix_mask

    base_seed = int(seed)  # Gradio sliders may deliver floats; torch.manual_seed needs an int.
    outputs = []
    for i in range(N_OUTPUTS):
        torch.manual_seed(base_seed + i)  # distinct but reproducible seed per take
        gen = model.inference(
            tokens.clone().to(DEVICE),
            full_feats.to(DEVICE),
            buffer_mask.clone().to(DEVICE),
            feats_mask.to(DEVICE),
            top_p=float(top_p),
            mask_temp=float(mask_temperature),
            iterations=infer_cfg["iterations"],
            guidance_scale=float(cfg_scale),
        )[..., prefix_frames:]  # drop the timbre prefix from the result
        # Reuse the rhythm token container so decode() sees matching metadata.
        rhythm_tokens.tokens = gen
        out_sig = tokenizer.decode(rhythm_tokens)
        out_sig.ensure_max_of_audio()
        output_path = f"tria_out_{i+1}.wav"
        save_audio(out_sig, output_path)
        outputs.append(output_path)
    return tuple(outputs)
# PyHARP Metadata
# Card shown by PyHARP hosts describing this endpoint; the description text is
# user-facing and rendered as-is.
model_card = ModelCard(
    name="TRIA: The Rhythm In Anything",
    description=(
        "Transform your rhythmic ideas into full drum performances. TRIA takes two short audio prompts: \n "
        "Timbre Prompt: an example recording for the desired sound (e.g. drum sound) \n "
        "Rhythm Prompt: the sound gesture expressing the desired pattern (e.g. tapping or beatboxing) \n"
        "It generates 3 drum arrangements that match your groove and chosen timbre. "
    ),
    author="Patrick O'Reilly, Julia Barnett, Hugo Flores García, Annie Chu, Nathan Pruyne, Prem Seetharaman, Bryan Pardo",
    tags=["tria", "rhythm-generation", "pyharp"],
)
# Gradio and PyHARP Endpoint
# Builds the Gradio UI and registers it as a PyHARP endpoint. Input order must
# match generate_audio's parameter order; output count must match N_OUTPUTS.
with gr.Blocks(title="TRIA") as demo:
    # Both audio prompts are required by HARP hosts.
    timbre_in = gr.Audio(type="filepath", label="Timbre Prompt").harp_required(True)
    rhythm_in = gr.Audio(type="filepath", label="Rhythm Prompt").harp_required(True)
    model_names = list(MODEL_ZOO.keys())
    model_dropdown = gr.Dropdown(choices=model_names, value=model_names[0], label="Model")
    with gr.Row():
        cfg_scale = gr.Slider(0.0, 10.0, value=2.0, step=0.1, label="CFG Scale", info=(
            "Controls how strongly the model follows your prompts/conditions.\n"
            "Low values: Model is more creative/random, High values: Model adheres strictly to prompts, less variation."))
        top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top P", info=(
            "Nucleus Sampling: Limits the pool of possible next tokens to the most likely ones.\n"
            "(Default behav.) only sample from tokens that make up the top 95% of probability"))
        mask_temperature = gr.Slider(0.0, 20.0, value=10.5, step=0.1, label="Mask Temperature", info=(
            "In masked models, controls the randomness of which tokens get unmasked first.\n"
            "Low temperature: More deterministic, predictable unmasking order, High temperature: More random unmasking, can add variety"))
        # Fixed user-facing typo: "fifterent" -> "different".
        seed = gr.Slider(0, 1000, value=0, step=1, label="Random Seed", info=(
            "Same seed and same inputs: Produces identical outputs (reproducible).\n"
            "Set a specific seed to get consistent results, or use different seeds to try variations"))
    # One output widget per generated take (N_OUTPUTS == 3).
    out1 = gr.Audio(type="filepath", label="Generated #1")
    out2 = gr.Audio(type="filepath", label="Generated #2")
    out3 = gr.Audio(type="filepath", label="Generated #3")
    app = build_endpoint(
        model_card=model_card,
        input_components=[model_dropdown, timbre_in, rhythm_in, cfg_scale, top_p, mask_temperature, seed],
        output_components=[out1, out2, out3],
        process_fn=generate_audio,
    )
demo.queue().launch(share=True, show_error=True)