import spaces
import gradio as gr
import torch
from pathlib import Path
from audiotools import AudioSignal
from tria.model.tria import TRIA
from tria.pipelines.tokenizer import Tokenizer
from tria.features import rhythm_features
from functools import partial
from pyharp.core import ModelCard, build_endpoint
from pyharp.media.audio import load_audio, save_audio
from pyharp.labels import LabelList

# Global Config
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
N_OUTPUTS = 3
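# N_OUTPUTS must stay in sync with the output components passed to
# build_endpoint at the bottom of this file: generate_audio returns one
# audio path per output slot.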

# Model Zoo
MODEL_ZOO = {
    "small_musdb_moises_2b": {
        "checkpoint": "pretrained/tria/small_musdb_moises_2b/80000/model.pt",
        "model_cfg": {
            "codebook_size": 1024,
            "n_codebooks": 9,
            "n_channels": 512,
            "n_feats": 2,
            "n_heads": 8,
            "n_layers": 12,
            "mult": 4,
            "p_dropout": 0.0,
            "bias": True,
            "max_len": 1000,
            "pos_enc": "rope",
            "qk_norm": True,
            "use_sdpa": True,
            "interp": "nearest",
            "share_emb": True,
        },
        "tokenizer_cfg": {"name": "dac"},
        "feature_cfg": {
            "sample_rate": 16_000,
            "n_bands": 2,
            "n_mels": 40,
            "window_length": 384,
            "hop_length": 192,
            "quantization_levels": 5,
            "slow_ma_ms": 200,
            "post_smooth_ms": 100,
            "legacy_normalize": False,
            "clamp_max": 50.0,
            "normalize_quantile": 0.98,
        },
        "infer_cfg": {
            "top_p": 0.95,
            "top_k": None,
            "temp": 1.0,
            "mask_temp": 10.5,
            "iterations": [8, 8, 8, 8, 4, 4, 4, 4, 4],
            "guidance_scale": 2.0,
            "causal_bias": 1.0,
        },
        "max_duration": 6.0,
    },
}
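# To expose another checkpoint, add an entry with the same structure as the
# one above; the model dropdown below is populated from MODEL_ZOO's keys, so
# no other code changes should be needed (any second entry is hypothetical).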

# Loaded model cache
LOADED = dict(
    name=None,
    model=None,
    tokenizer=None,
    feature_fn=None,
    infer_cfg=None,
    sample_rate=None,
    max_duration=None,
)

# Model loading
def load_model_by_name(name: str):
    """Load a TRIA model by name (cached)."""
    if LOADED["name"] == name and LOADED["model"] is not None:
        return LOADED["model"]
    cfg = MODEL_ZOO[name]
    model = TRIA(**cfg["model_cfg"])
    sd = torch.load(cfg["checkpoint"], map_location="cpu")
    model.load_state_dict(sd, strict=True)
    model.to(DEVICE).eval()
    tokenizer = Tokenizer(**cfg["tokenizer_cfg"]).to(DEVICE)
    feat_fn = partial(rhythm_features, **cfg.get("feature_cfg", {}))
    LOADED.update(
        dict(
            name=name,
            model=model,
            tokenizer=tokenizer,
            feature_fn=feat_fn,
            infer_cfg=cfg["infer_cfg"],
            sample_rate=tokenizer.sample_rate,
            max_duration=cfg["max_duration"],
        )
    )
    return model
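# Optional warm-up (commented out; assumes eager loading at import time is
# acceptable for this Space). Uncommenting would move the checkpoint load off
# the first user request:
# load_model_by_name(next(iter(MODEL_ZOO)))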

# Inference logic
@spaces.GPU  # ZeroGPU decorator (a no-op off Spaces); the reason `import spaces` appears above
def generate_audio(model_name, timbre_path, rhythm_path, cfg_scale, top_p, mask_temperature, seed):
    """Render the rhythm prompt in the timbre prompt's sound; returns N_OUTPUTS audio paths."""
    model = load_model_by_name(model_name)
    tokenizer = LOADED["tokenizer"]
    feat_fn = LOADED["feature_fn"]
    sample_rate = LOADED["sample_rate"]
    infer_cfg = LOADED["infer_cfg"]
    # Load both prompts at the tokenizer's sample rate and peak-normalize.
    timbre_sig = load_audio(timbre_path).resample(sample_rate)
    rhythm_sig = load_audio(rhythm_path).resample(sample_rate)
    timbre_sig.ensure_max_of_audio()
    rhythm_sig.ensure_max_of_audio()
    prefix_dur = int(LOADED["max_duration"] / 3)  # prefix budget in seconds (unused below; the prefix length is taken from the timbre tokens)
    # Tokenize both prompts and concatenate along time: the timbre tokens form
    # a fixed prefix, and the rhythm frames are what the model regenerates.
    timbre_tokens = tokenizer.encode(timbre_sig)
    rhythm_tokens = tokenizer.encode(rhythm_sig)
    tokens = torch.cat([timbre_tokens.tokens, rhythm_tokens.tokens], dim=-1)
    n_batch, n_codebooks, n_frames = tokens.shape
    prefix_frames = timbre_tokens.tokens.shape[-1]
    # Resample rhythm features to the token frame rate; the prefix gets zeros.
    feats = feat_fn(rhythm_sig)
    feats = torch.nn.functional.interpolate(feats, n_frames - prefix_frames, mode=model.interp)
    full_feats = torch.zeros(n_batch, feats.shape[1], n_frames, device=DEVICE)
    full_feats[..., prefix_frames:] = feats
    # buffer_mask marks frames to keep verbatim (the timbre prefix, across all
    # codebooks); feats_mask marks frames conditioned on rhythm features.
    prefix_mask = torch.arange(n_frames, device=DEVICE)[None, :].repeat(n_batch, 1) < prefix_frames
    buffer_mask = prefix_mask[:, None, :].repeat(1, n_codebooks, 1)
    feats_mask = ~prefix_mask
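    # At this point tokens, features, and masks are aligned by construction;
    # an illustrative sanity check (added here, not in the original code):
    assert tokens.shape[-1] == full_feats.shape[-1] == buffer_mask.shape[-1]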
    outputs = []
    for i in range(N_OUTPUTS):
        torch.manual_seed(int(seed) + i)  # distinct but reproducible seed per variation
        gen = model.inference(
            tokens.clone().to(DEVICE),
            full_feats.to(DEVICE),
            buffer_mask.clone().to(DEVICE),
            feats_mask.to(DEVICE),
            top_p=float(top_p),
            mask_temp=float(mask_temperature),
            iterations=infer_cfg["iterations"],
            guidance_scale=float(cfg_scale),
        )[..., prefix_frames:]  # drop the timbre prefix, keep only generated frames
        # Reuse the rhythm token container so decode sees the right metadata.
        rhythm_tokens.tokens = gen
        out_sig = tokenizer.decode(rhythm_tokens)
        out_sig.ensure_max_of_audio()
        output_path = f"tria_out_{i+1}.wav"
        save_audio(out_sig, output_path)
        outputs.append(output_path)
    return tuple(outputs)

# PyHARP Metadata
model_card = ModelCard(
    name="TRIA: The Rhythm In Anything",
    description=(
        "Transform your rhythmic ideas into full drum performances. TRIA takes two short audio prompts:\n"
        "Timbre Prompt: an example recording of the desired sound (e.g. a drum kit)\n"
        "Rhythm Prompt: a sound gesture expressing the desired pattern (e.g. tapping or beatboxing)\n"
        "It generates three drum arrangements that match your groove and chosen timbre."
    ),
    author="Patrick O'Reilly, Julia Barnett, Hugo Flores García, Annie Chu, Nathan Pruyne, Prem Seetharaman, Bryan Pardo",
    tags=["tria", "rhythm-generation", "pyharp"],
)

# Gradio and PyHARP Endpoint
with gr.Blocks(title="TRIA") as demo:
    timbre_in = gr.Audio(type="filepath", label="Timbre Prompt").harp_required(True)
    rhythm_in = gr.Audio(type="filepath", label="Rhythm Prompt").harp_required(True)
    model_names = list(MODEL_ZOO.keys())
    model_dropdown = gr.Dropdown(choices=model_names, value=model_names[0], label="Model")
    with gr.Row():
        cfg_scale = gr.Slider(0.0, 10.0, value=2.0, step=0.1, label="CFG Scale", info=(
            "Controls how strongly the model follows your prompts/conditions.\n"
            "Low values: more creative/random output. High values: strict adherence to the prompts, with less variation."))
        top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top P", info=(
            "Nucleus sampling: limits the pool of possible next tokens to the most likely ones.\n"
            "By default, only tokens making up the top 95% of probability mass are sampled."))
        mask_temperature = gr.Slider(0.0, 20.0, value=10.5, step=0.1, label="Mask Temperature", info=(
            "In masked models, controls the randomness of which tokens are unmasked first.\n"
            "Low temperature: more deterministic, predictable unmasking order. High temperature: more random unmasking, which can add variety."))
        seed = gr.Slider(0, 1000, value=0, step=1, label="Random Seed", info=(
            "The same seed and the same inputs produce identical outputs (reproducible).\n"
            "Set a specific seed to get consistent results, or use different seeds to try variations."))
    out1 = gr.Audio(type="filepath", label="Generated #1")
    out2 = gr.Audio(type="filepath", label="Generated #2")
    out3 = gr.Audio(type="filepath", label="Generated #3")
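    # build_endpoint wires input components positionally onto generate_audio's
    # parameters, so the list order below should mirror the function signature.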
    app = build_endpoint(
        model_card=model_card,
        input_components=[model_dropdown, timbre_in, rhythm_in, cfg_scale, top_p, mask_temperature, seed],
        output_components=[out1, out2, out3],
        process_fn=generate_audio,
    )

demo.queue().launch(share=True, show_error=True)