"""Gradio + PyHARP demo for TRIA (The Rhythm In Anything).

Given a timbre prompt (example of the desired drum sound) and a rhythm
prompt (tapping/beatboxing of the desired pattern), the app generates
``N_OUTPUTS`` drum arrangements matching both prompts.
"""

import spaces
import gradio as gr
import torch
from pathlib import Path
from audiotools import AudioSignal
from tria.model.tria import TRIA
from tria.pipelines.tokenizer import Tokenizer
from tria.features import rhythm_features
from functools import partial
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import load_audio, save_audio
from pyharp.labels import LabelList

# Global Config
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
N_OUTPUTS = 3  # number of alternative generations returned per request

# Model Zoo: per-model checkpoint path plus the config dicts needed to
# rebuild the network, its tokenizer, the rhythm-feature extractor, and
# the default inference hyperparameters.
MODEL_ZOO = {
    "small_musdb_moises_2b": {
        "checkpoint": "pretrained/tria/small_musdb_moises_2b/80000/model.pt",
        "model_cfg": {
            "codebook_size": 1024,
            "n_codebooks": 9,
            "n_channels": 512,
            "n_feats": 2,
            "n_heads": 8,
            "n_layers": 12,
            "mult": 4,
            "p_dropout": 0.0,
            "bias": True,
            "max_len": 1000,
            "pos_enc": "rope",
            "qk_norm": True,
            "use_sdpa": True,
            "interp": "nearest",
            "share_emb": True,
        },
        "tokenizer_cfg": {"name": "dac"},
        "feature_cfg": {
            "sample_rate": 16_000,
            "n_bands": 2,
            "n_mels": 40,
            "window_length": 384,
            "hop_length": 192,
            "quantization_levels": 5,
            "slow_ma_ms": 200,
            "post_smooth_ms": 100,
            "legacy_normalize": False,
            "clamp_max": 50.0,
            "normalize_quantile": 0.98,
        },
        "infer_cfg": {
            "top_p": 0.95,
            "top_k": None,
            "temp": 1.0,
            "mask_temp": 10.5,
            "iterations": [8, 8, 8, 8, 4, 4, 4, 4, 4],
            "guidance_scale": 2.0,
            "causal_bias": 1.0,
        },
        "max_duration": 6.0,
    },
}

# Loaded-model cache: holds the single most recently loaded model and its
# companions so repeated requests with the same model skip reloading.
LOADED = dict(
    name=None,
    model=None,
    tokenizer=None,
    feature_fn=None,
    infer_cfg=None,
    sample_rate=None,
    max_duration=None,
)


# Model loading
def load_model_by_name(name: str):
    """Load a TRIA model by name (cached).

    Instantiates the model, tokenizer, and rhythm-feature extractor
    described by ``MODEL_ZOO[name]``, moves them to ``DEVICE``, and
    stores everything in the module-level ``LOADED`` cache.

    Args:
        name: Key into ``MODEL_ZOO``.

    Returns:
        The loaded ``TRIA`` model in eval mode on ``DEVICE``.
    """
    # Cache hit: same model already resident — reuse it.
    if LOADED["name"] == name and LOADED["model"] is not None:
        return LOADED["model"]

    cfg = MODEL_ZOO[name]

    model = TRIA(**cfg["model_cfg"])
    sd = torch.load(cfg["checkpoint"], map_location="cpu")
    model.load_state_dict(sd, strict=True)
    model.to(DEVICE).eval()

    tokenizer = Tokenizer(**cfg["tokenizer_cfg"]).to(DEVICE)
    # Bind the feature-extraction hyperparameters once so callers only
    # need to pass the audio signal.
    feat_fn = partial(rhythm_features, **cfg.get("feature_cfg", {}))

    LOADED.update(
        dict(
            name=name,
            model=model,
            tokenizer=tokenizer,
            feature_fn=feat_fn,
            infer_cfg=cfg["infer_cfg"],
            sample_rate=tokenizer.sample_rate,
            max_duration=cfg["max_duration"],
        )
    )
    return model


# Inference logic
@spaces.GPU
@torch.inference_mode()
def generate_audio(model_name, timbre_path, rhythm_path, cfg_scale, top_p, mask_temperature, seed):
    """Generate ``N_OUTPUTS`` drum takes from a timbre and a rhythm prompt.

    Args:
        model_name: Key into ``MODEL_ZOO``.
        timbre_path: Filepath of the timbre prompt audio.
        rhythm_path: Filepath of the rhythm prompt audio.
        cfg_scale: Classifier-free guidance scale.
        top_p: Nucleus-sampling threshold.
        mask_temperature: Temperature for the unmasking order.
        seed: Base random seed; output ``i`` uses ``seed + i``.

    Returns:
        Tuple of ``N_OUTPUTS`` filepaths to the generated WAV files.
    """
    model = load_model_by_name(model_name)
    tokenizer = LOADED["tokenizer"]
    feat_fn = LOADED["feature_fn"]
    sample_rate = LOADED["sample_rate"]
    infer_cfg = LOADED["infer_cfg"]

    # Load both prompts and resample to the tokenizer's rate; peak-normalize
    # in place to avoid clipping.
    timbre_sig = load_audio(timbre_path).resample(sample_rate)
    rhythm_sig = load_audio(rhythm_path).resample(sample_rate)
    timbre_sig.ensure_max_of_audio()
    rhythm_sig.ensure_max_of_audio()

    # Tokenize and concatenate along time: the timbre prompt becomes a fixed
    # prefix, the rhythm region is what gets (re)generated.
    timbre_tokens = tokenizer.encode(timbre_sig)
    rhythm_tokens = tokenizer.encode(rhythm_sig)
    tokens = torch.cat([timbre_tokens.tokens, rhythm_tokens.tokens], dim=-1)
    n_batch, n_codebooks, n_frames = tokens.shape
    prefix_frames = timbre_tokens.tokens.shape[-1]

    # Rhythm features conditioned only on the non-prefix region, resampled to
    # the token frame rate; the prefix region is zero-filled.
    feats = feat_fn(rhythm_sig)
    feats = torch.nn.functional.interpolate(feats, n_frames - prefix_frames, mode=model.interp)
    full_feats = torch.zeros(n_batch, feats.shape[1], n_frames, device=DEVICE)
    full_feats[..., prefix_frames:] = feats

    # Masks: the prefix frames are kept verbatim (buffer_mask True) while
    # feature conditioning applies only outside the prefix.
    prefix_mask = torch.arange(n_frames, device=DEVICE)[None, :].repeat(n_batch, 1) < prefix_frames
    buffer_mask = prefix_mask[:, None, :].repeat(1, n_codebooks, 1)
    feats_mask = ~prefix_mask

    outputs = []
    for i in range(N_OUTPUTS):
        # Distinct, reproducible seed per take.
        torch.manual_seed(seed + i)
        gen = model.inference(
            tokens.clone().to(DEVICE),
            full_feats.to(DEVICE),
            buffer_mask.clone().to(DEVICE),
            feats_mask.to(DEVICE),
            top_p=float(top_p),
            mask_temp=float(mask_temperature),
            iterations=infer_cfg["iterations"],
            guidance_scale=float(cfg_scale),
        )[..., prefix_frames:]  # drop the timbre prefix from the output

        # Reuse the rhythm token container so decode sees matching metadata.
        rhythm_tokens.tokens = gen
        out_sig = tokenizer.decode(rhythm_tokens)
        out_sig.ensure_max_of_audio()

        output_path = f"tria_out_{i+1}.wav"
        save_audio(out_sig, output_path)
        outputs.append(str(output_path))

    return tuple(outputs)


# PyHARP Metadata
model_card = ModelCard(
    name="TRIA: The Rhythm In Anything",
    description=(
        "Transform your rhythmic ideas into full drum performances. TRIA takes two short audio prompts: \n "
        "Timbre Prompt: an example recording for the desired sound (e.g. drum sound) \n "
        "Rhythm Prompt: the sound gesture expressing the desired pattern (e.g. tapping or beatboxing) \n"
        "It generates 3 drum arrangements that match your groove and chosen timbre. "
    ),
    author="Patrick O'Reilly, Julia Barnett, Hugo Flores GarcĂ­a, Annie Chu, Nathan Pruyne, Prem Seetharaman, Bryan Pardo",
    tags=["tria", "rhythm-generation", "pyharp"],
)

# Gradio and PyHARP Endpoint
with gr.Blocks(title="TRIA") as demo:
    timbre_in = gr.Audio(type="filepath", label="Timbre Prompt").harp_required(True)
    rhythm_in = gr.Audio(type="filepath", label="Rhythm Prompt").harp_required(True)

    model_names = list(MODEL_ZOO.keys())
    model_dropdown = gr.Dropdown(choices=model_names, value=model_names[0], label="Model")

    with gr.Row():
        cfg_scale = gr.Slider(
            0.0, 10.0, value=2.0, step=0.1, label="CFG Scale",
            info=(
                "Controls how strongly the model follows your prompts/conditions.\n"
                "Low values: Model is more creative/random, High values: Model adheres strictly to prompts, less variation."
            ),
        )
        top_p = gr.Slider(
            0.0, 1.0, value=0.95, step=0.01, label="Top P",
            info=(
                "Nucleus Sampling: Limits the pool of possible next tokens to the most likely ones.\n"
                "(Default behav.) only sample from tokens that make up the top 95% of probability"
            ),
        )
        mask_temperature = gr.Slider(
            0.0, 20.0, value=10.5, step=0.1, label="Mask Temperature",
            info=(
                "In masked models, controls the randomness of which tokens get unmasked first.\n"
                "Low temperature: More deterministic, predictable unmasking order, High temperature: More random unmasking, can add variety"
            ),
        )
        seed = gr.Slider(
            0, 1000, value=0, step=1, label="Random Seed",
            info=(
                "Same seed and same inputs: Produces identical outputs (reproducible).\n"
                "Set a specific seed to get consistent results, or use different seeds to try variations"
            ),
        )

    out1 = gr.Audio(type="filepath", label="Generated #1")
    out2 = gr.Audio(type="filepath", label="Generated #2")
    out3 = gr.Audio(type="filepath", label="Generated #3")

    app = build_endpoint(
        model_card=model_card,
        input_components=[model_dropdown, timbre_in, rhythm_in, cfg_scale, top_p, mask_temperature, seed],
        output_components=[out1, out2, out3],
        process_fn=generate_audio,
    )

demo.queue().launch(share=True, show_error=True)