# TRIA / app.py
# (Hugging Face Space page residue, preserved as comments:
#  author: saumya-pailwan — "description and hints update" — commit c51d47d, verified)
import spaces
import gradio as gr
import torch
from pathlib import Path
from audiotools import AudioSignal
from tria.model.tria import TRIA
from tria.pipelines.tokenizer import Tokenizer
from tria.features import rhythm_features
from functools import partial
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import load_audio, save_audio
from pyharp.labels import LabelList
# Global Config
# Use the GPU when one is visible; otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of generated variations produced (and returned) per request.
N_OUTPUTS = 3
# Model Zoo
# Registry of available checkpoints. Each entry bundles everything needed to
# run one model: checkpoint path, architecture kwargs for TRIA, tokenizer
# kwargs, rhythm-feature extraction kwargs, and inference hyperparameters.
MODEL_ZOO = {
    "small_musdb_moises_2b": {
        "checkpoint": "pretrained/tria/small_musdb_moises_2b/80000/model.pt",
        # Constructor kwargs for the TRIA transformer.
        "model_cfg": {
            "codebook_size": 1024,
            "n_codebooks": 9,
            "n_channels": 512,
            "n_feats": 2,
            "n_heads": 8,
            "n_layers": 12,
            "mult": 4,
            "p_dropout": 0.0,
            "bias": True,
            "max_len": 1000,
            "pos_enc": "rope",
            "qk_norm": True,
            "use_sdpa": True,
            "interp": "nearest",
            "share_emb": True,
        },
        # Audio tokenizer selection (Descript Audio Codec).
        "tokenizer_cfg": {"name": "dac"},
        # Kwargs passed to rhythm_features via functools.partial.
        "feature_cfg": {
            "sample_rate": 16_000,
            "n_bands": 2,
            "n_mels": 40,
            "window_length": 384,
            "hop_length": 192,
            "quantization_levels": 5,
            "slow_ma_ms": 200,
            "post_smooth_ms": 100,
            "legacy_normalize": False,
            "clamp_max": 50.0,
            "normalize_quantile": 0.98,
        },
        # Default sampling hyperparameters for model.inference().
        "infer_cfg": {
            "top_p": 0.95,
            "top_k": None,
            "temp": 1.0,
            "mask_temp": 10.5,
            "iterations": [8, 8, 8, 8, 4, 4, 4, 4, 4],
            "guidance_scale": 2.0,
            "causal_bias": 1.0,
        },
        # Maximum prompt duration in seconds handled by this model.
        "max_duration": 6.0,
    },
}
# Loaded model cache
# Single-slot cache: holds the most recently loaded model and its companions
# (tokenizer, feature fn, inference config). Populated by load_model_by_name.
LOADED = dict(name=None, model=None, tokenizer=None, feature_fn=None, infer_cfg=None, sample_rate=None, max_duration=None)
# Model loading
def load_model_by_name(name: str):
    """Return the TRIA model registered under *name*, loading it on first use.

    The loaded model (plus its tokenizer, feature function, and inference
    config) is kept in the module-level ``LOADED`` single-slot cache, so
    repeated calls with the same name are free.
    """
    # Fast path: the requested model is already resident.
    if LOADED["name"] == name and LOADED["model"] is not None:
        return LOADED["model"]

    entry = MODEL_ZOO[name]

    # Instantiate the network and restore weights on CPU, then move to DEVICE.
    model = TRIA(**entry["model_cfg"])
    state_dict = torch.load(entry["checkpoint"], map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    model.to(DEVICE).eval()

    tokenizer = Tokenizer(**entry["tokenizer_cfg"]).to(DEVICE)
    # Pre-bind the feature-extraction kwargs so callers just pass a signal.
    feature_fn = partial(rhythm_features, **entry.get("feature_cfg", {}))

    # Refresh the cache slot so subsequent calls reuse everything loaded here.
    LOADED["name"] = name
    LOADED["model"] = model
    LOADED["tokenizer"] = tokenizer
    LOADED["feature_fn"] = feature_fn
    LOADED["infer_cfg"] = entry["infer_cfg"]
    LOADED["sample_rate"] = tokenizer.sample_rate
    LOADED["max_duration"] = entry["max_duration"]
    return model
# Inference logic
@spaces.GPU
@torch.inference_mode()
def generate_audio(model_name, timbre_path, rhythm_path, cfg_scale, top_p, mask_temperature, seed):
    """Generate N_OUTPUTS drum renditions from a timbre prompt and a rhythm prompt.

    Args:
        model_name: Key into MODEL_ZOO selecting the checkpoint/config to use.
        timbre_path: Path to the audio file supplying the target timbre.
        rhythm_path: Path to the audio file supplying the rhythmic gesture.
        cfg_scale: Classifier-free guidance scale (higher = stricter adherence).
        top_p: Nucleus-sampling probability mass.
        mask_temperature: Temperature controlling the token-unmasking order.
        seed: Base random seed; variation i is generated with seed + i.

    Returns:
        Tuple of N_OUTPUTS file paths to the generated .wav files.
    """
    model = load_model_by_name(model_name)
    tokenizer = LOADED["tokenizer"]
    feat_fn = LOADED["feature_fn"]
    sample_rate = LOADED["sample_rate"]
    infer_cfg = LOADED["infer_cfg"]

    # Load both prompts, resample to the tokenizer's rate, and peak-normalize.
    timbre_sig = load_audio(timbre_path).resample(sample_rate)
    rhythm_sig = load_audio(rhythm_path).resample(sample_rate)
    timbre_sig.ensure_max_of_audio()
    rhythm_sig.ensure_max_of_audio()

    # Tokenize both prompts; the timbre tokens form a frozen prefix that the
    # model conditions on, followed by the rhythm-length region to generate.
    timbre_tokens = tokenizer.encode(timbre_sig)
    rhythm_tokens = tokenizer.encode(rhythm_sig)
    tokens = torch.cat([timbre_tokens.tokens, rhythm_tokens.tokens], dim=-1)
    n_batch, n_codebooks, n_frames = tokens.shape
    prefix_frames = timbre_tokens.tokens.shape[-1]

    # Rhythm features are resampled to align with the non-prefix token frames;
    # the prefix region carries zero features.
    feats = feat_fn(rhythm_sig)
    feats = torch.nn.functional.interpolate(feats, n_frames - prefix_frames, mode=model.interp)
    full_feats = torch.zeros(n_batch, feats.shape[1], n_frames, device=DEVICE)
    full_feats[..., prefix_frames:] = feats

    # Masks: prefix frames are kept fixed (buffer_mask); features apply only
    # to the frames being generated (feats_mask is the complement).
    prefix_mask = torch.arange(n_frames, device=DEVICE)[None, :].repeat(n_batch, 1) < prefix_frames
    buffer_mask = prefix_mask[:, None, :].repeat(1, n_codebooks, 1)
    feats_mask = ~prefix_mask

    outputs = []
    for i in range(N_OUTPUTS):
        # Distinct but reproducible seed for each variation.
        torch.manual_seed(seed + i)
        gen = model.inference(
            tokens.clone().to(DEVICE),
            full_feats.to(DEVICE),
            buffer_mask.clone().to(DEVICE),
            feats_mask.to(DEVICE),
            top_p=float(top_p),
            mask_temp=float(mask_temperature),
            iterations=infer_cfg["iterations"],
            guidance_scale=float(cfg_scale),
        )[..., prefix_frames:]
        # Reuse the rhythm token container so decode() sees matching metadata.
        rhythm_tokens.tokens = gen
        out_sig = tokenizer.decode(rhythm_tokens)
        out_sig.ensure_max_of_audio()
        output_path = f"tria_out_{i+1}.wav"
        save_audio(out_sig, output_path)
        outputs.append(output_path)
    return tuple(outputs)
# PyHARP Metadata
# Card shown to HARP clients (and in the Gradio UI) describing this endpoint.
model_card = ModelCard(
    name="TRIA: The Rhythm In Anything",
    description=(
        "Transform your rhythmic ideas into full drum performances. TRIA takes two short audio prompts: \n "
        "Timbre Prompt: an example recording for the desired sound (e.g. drum sound) \n "
        "Rhythm Prompt: the sound gesture expressing the desired pattern (e.g. tapping or beatboxing) \n"
        "It generates 3 drum arrangements that match your groove and chosen timbre. "
    ),
    author="Patrick O'Reilly, Julia Barnett, Hugo Flores García, Annie Chu, Nathan Pruyne, Prem Seetharaman, Bryan Pardo",
    tags=["tria", "rhythm-generation", "pyharp"],
)
# Gradio and PyHARP Endpoint
# Builds the Gradio UI and wires it to generate_audio via PyHARP's endpoint.
# Fix: corrected user-facing typo "fifterent" -> "different" in the seed hint.
with gr.Blocks(title="TRIA") as demo:
    # Audio inputs — both prompts are required by HARP clients.
    timbre_in = gr.Audio(type="filepath", label="Timbre Prompt").harp_required(True)
    rhythm_in = gr.Audio(type="filepath", label="Rhythm Prompt").harp_required(True)
    model_names = list(MODEL_ZOO.keys())
    model_dropdown = gr.Dropdown(choices=model_names, value=model_names[0], label="Model")
    # Sampling controls; defaults mirror MODEL_ZOO's infer_cfg.
    with gr.Row():
        cfg_scale = gr.Slider(0.0, 10.0, value=2.0, step=0.1, label="CFG Scale", info=(
            "Controls how strongly the model follows your prompts/conditions.\n"
            "Low values: Model is more creative/random, High values: Model adheres strictly to prompts, less variation."))
        top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top P", info=(
            "Nucleus Sampling: Limits the pool of possible next tokens to the most likely ones.\n"
            "(Default behav.) only sample from tokens that make up the top 95% of probability"))
        mask_temperature = gr.Slider(0.0, 20.0, value=10.5, step=0.1, label="Mask Temperature", info=(
            "In masked models, controls the randomness of which tokens get unmasked first.\n"
            "Low temperature: More deterministic, predictable unmasking order, High temperature: More random unmasking, can add variety"))
        seed = gr.Slider(0, 1000, value=0, step=1, label="Random Seed", info=(
            "Same seed and same inputs: Produces identical outputs (reproducible).\n"
            "Set a specific seed to get consistent results, or use different seeds to try variations"))
    # One output slot per generated variation (N_OUTPUTS = 3).
    out1 = gr.Audio(type="filepath", label="Generated #1")
    out2 = gr.Audio(type="filepath", label="Generated #2")
    out3 = gr.Audio(type="filepath", label="Generated #3")
    # Register the processing function and components as a HARP endpoint.
    app = build_endpoint(
        model_card=model_card,
        input_components=[model_dropdown, timbre_in, rhythm_in, cfg_scale, top_p, mask_temperature, seed],
        output_components=[out1, out2, out3],
        process_fn=generate_audio,
    )
demo.queue().launch(share=True, show_error=True)