Spaces:
Sleeping
Sleeping
initial commit
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .DS_Store +0 -0
- .gradio/certificate.pem +31 -0
- README.md +1 -2
- app.py +192 -0
- pretrained/.gitignore +0 -0
- pretrained/tokenizer/dac/dac_44.1kHz_7.7kbps.pt +3 -0
- pretrained/tria/small_musdb_moises_2b/80000/extras.pt +3 -0
- pretrained/tria/small_musdb_moises_2b/80000/model.pt +3 -0
- pretrained/tria/small_musdb_moises_2b/best/extras.pt +3 -0
- pretrained/tria/small_musdb_moises_2b/best/model.pt +3 -0
- requirements.txt +11 -0
- tria/__init__.py +6 -0
- tria/__pycache__/__init__.cpython-310.pyc +0 -0
- tria/__pycache__/constants.cpython-310.pyc +0 -0
- tria/__pycache__/features.cpython-310.pyc +0 -0
- tria/__pycache__/util.cpython-310.pyc +0 -0
- tria/constants.py +11 -0
- tria/data/__init__.py +0 -0
- tria/data/dataset.py +280 -0
- tria/data/preprocess.py +124 -0
- tria/features.py +187 -0
- tria/model/__init__.py +1 -0
- tria/model/__pycache__/__init__.cpython-310.pyc +0 -0
- tria/model/__pycache__/mask.cpython-310.pyc +0 -0
- tria/model/__pycache__/sample.cpython-310.pyc +0 -0
- tria/model/__pycache__/tria.cpython-310.pyc +0 -0
- tria/model/mask.py +263 -0
- tria/model/sample.py +168 -0
- tria/model/tria.py +344 -0
- tria/nn/__init__.py +0 -0
- tria/nn/__pycache__/__init__.cpython-310.pyc +0 -0
- tria/nn/__pycache__/attention.cpython-310.pyc +0 -0
- tria/nn/__pycache__/norm.cpython-310.pyc +0 -0
- tria/nn/__pycache__/pos_enc.cpython-310.pyc +0 -0
- tria/nn/__pycache__/transformer.cpython-310.pyc +0 -0
- tria/nn/attention.py +280 -0
- tria/nn/norm.py +53 -0
- tria/nn/pos_enc.py +101 -0
- tria/nn/transformer.py +259 -0
- tria/pipelines/__init__.py +0 -0
- tria/pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
- tria/pipelines/tokenizer/__init__.py +2 -0
- tria/pipelines/tokenizer/__pycache__/__init__.cpython-310.pyc +0 -0
- tria/pipelines/tokenizer/__pycache__/tokenizer.cpython-310.pyc +0 -0
- tria/pipelines/tokenizer/dac/LICENSE +21 -0
- tria/pipelines/tokenizer/dac/__init__.py +1 -0
- tria/pipelines/tokenizer/dac/__pycache__/__init__.cpython-310.pyc +0 -0
- tria/pipelines/tokenizer/dac/__pycache__/dac.cpython-310.pyc +0 -0
- tria/pipelines/tokenizer/dac/__pycache__/modules.cpython-310.pyc +0 -0
- tria/pipelines/tokenizer/dac/dac.py +203 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
README.md
CHANGED
|
@@ -8,6 +8,5 @@ sdk_version: 5.49.1
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
|
|
|
| 11 |
---
|
| 12 |
-
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
+
short_description: Audio Prompted Drums Generation
|
| 12 |
---
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import torch
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from audiotools import AudioSignal
|
| 6 |
+
from tria.model.tria import TRIA
|
| 7 |
+
from tria.pipelines.tokenizer import Tokenizer
|
| 8 |
+
from tria.features import rhythm_features
|
| 9 |
+
from functools import partial
|
| 10 |
+
from pyharp.core import ModelCard, build_endpoint
|
| 11 |
+
from pyharp.media.audio import load_audio, save_audio
|
| 12 |
+
from pyharp.labels import LabelList
|
| 13 |
+
|
| 14 |
+
# Global Config
|
| 15 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
+
N_OUTPUTS = 3
|
| 17 |
+
|
| 18 |
+
# Model Zoo
|
| 19 |
+
MODEL_ZOO = {
|
| 20 |
+
"small_musdb_moises_2b": {
|
| 21 |
+
"checkpoint": "pretrained/tria/small_musdb_moises_2b/80000/model.pt",
|
| 22 |
+
"model_cfg": {
|
| 23 |
+
"codebook_size": 1024,
|
| 24 |
+
"n_codebooks": 9,
|
| 25 |
+
"n_channels": 512,
|
| 26 |
+
"n_feats": 2,
|
| 27 |
+
"n_heads": 8,
|
| 28 |
+
"n_layers": 12,
|
| 29 |
+
"mult": 4,
|
| 30 |
+
"p_dropout": 0.0,
|
| 31 |
+
"bias": True,
|
| 32 |
+
"max_len": 1000,
|
| 33 |
+
"pos_enc": "rope",
|
| 34 |
+
"qk_norm": True,
|
| 35 |
+
"use_sdpa": True,
|
| 36 |
+
"interp": "nearest",
|
| 37 |
+
"share_emb": True,
|
| 38 |
+
},
|
| 39 |
+
"tokenizer_cfg": {"name": "dac"},
|
| 40 |
+
"feature_cfg": {
|
| 41 |
+
"sample_rate": 16_000,
|
| 42 |
+
"n_bands": 2,
|
| 43 |
+
"n_mels": 40,
|
| 44 |
+
"window_length": 384,
|
| 45 |
+
"hop_length": 192,
|
| 46 |
+
"quantization_levels": 5,
|
| 47 |
+
"slow_ma_ms": 200,
|
| 48 |
+
"post_smooth_ms": 100,
|
| 49 |
+
"legacy_normalize": False,
|
| 50 |
+
"clamp_max": 50.0,
|
| 51 |
+
"normalize_quantile": 0.98,
|
| 52 |
+
},
|
| 53 |
+
"infer_cfg": {
|
| 54 |
+
"top_p": 0.95,
|
| 55 |
+
"top_k": None,
|
| 56 |
+
"temp": 1.0,
|
| 57 |
+
"mask_temp": 10.5,
|
| 58 |
+
"iterations": [8, 8, 8, 8, 4, 4, 4, 4, 4],
|
| 59 |
+
"guidance_scale": 2.0,
|
| 60 |
+
"causal_bias": 1.0,
|
| 61 |
+
},
|
| 62 |
+
"max_duration": 6.0,
|
| 63 |
+
},
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
# Loaded model cache
|
| 67 |
+
LOADED = dict(name=None, model=None, tokenizer=None, feature_fn=None, infer_cfg=None, sample_rate=None, max_duration=None)
|
| 68 |
+
|
| 69 |
+
# Model loading
|
| 70 |
+
def load_model_by_name(name: str):
|
| 71 |
+
"""Load a TRIA model by name (cached)."""
|
| 72 |
+
if LOADED["name"] == name and LOADED["model"] is not None:
|
| 73 |
+
return LOADED["model"]
|
| 74 |
+
|
| 75 |
+
cfg = MODEL_ZOO[name]
|
| 76 |
+
model = TRIA(**cfg["model_cfg"])
|
| 77 |
+
sd = torch.load(cfg["checkpoint"], map_location="cpu")
|
| 78 |
+
model.load_state_dict(sd, strict=True)
|
| 79 |
+
model.to(DEVICE).eval()
|
| 80 |
+
|
| 81 |
+
tokenizer = Tokenizer(**cfg["tokenizer_cfg"]).to(DEVICE)
|
| 82 |
+
feat_fn = partial(rhythm_features, **cfg.get("feature_cfg", {}))
|
| 83 |
+
|
| 84 |
+
LOADED.update(
|
| 85 |
+
dict(
|
| 86 |
+
name=name,
|
| 87 |
+
model=model,
|
| 88 |
+
tokenizer=tokenizer,
|
| 89 |
+
feature_fn=feat_fn,
|
| 90 |
+
infer_cfg=cfg["infer_cfg"],
|
| 91 |
+
sample_rate=tokenizer.sample_rate,
|
| 92 |
+
max_duration=cfg["max_duration"],
|
| 93 |
+
)
|
| 94 |
+
)
|
| 95 |
+
return model
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Inference logic
|
| 99 |
+
@spaces.GPU
|
| 100 |
+
@torch.inference_mode()
|
| 101 |
+
def generate_audio(model_name, timbre_path, rhythm_path, cfg_scale, top_p, mask_temperature, seed):
|
| 102 |
+
model = load_model_by_name(model_name)
|
| 103 |
+
tokenizer = LOADED["tokenizer"]
|
| 104 |
+
feat_fn = LOADED["feature_fn"]
|
| 105 |
+
sample_rate = LOADED["sample_rate"]
|
| 106 |
+
infer_cfg = LOADED["infer_cfg"]
|
| 107 |
+
|
| 108 |
+
timbre_sig = load_audio(timbre_path).resample(sample_rate)
|
| 109 |
+
rhythm_sig = load_audio(rhythm_path).resample(sample_rate)
|
| 110 |
+
timbre_sig.ensure_max_of_audio()
|
| 111 |
+
rhythm_sig.ensure_max_of_audio()
|
| 112 |
+
|
| 113 |
+
prefix_dur = int(LOADED["max_duration"] / 3)
|
| 114 |
+
timbre_tokens = tokenizer.encode(timbre_sig)
|
| 115 |
+
rhythm_tokens = tokenizer.encode(rhythm_sig)
|
| 116 |
+
tokens = torch.cat([timbre_tokens.tokens, rhythm_tokens.tokens], dim=-1)
|
| 117 |
+
n_batch, n_codebooks, n_frames = tokens.shape
|
| 118 |
+
prefix_frames = timbre_tokens.tokens.shape[-1]
|
| 119 |
+
|
| 120 |
+
feats = feat_fn(rhythm_sig)
|
| 121 |
+
feats = torch.nn.functional.interpolate(feats, n_frames - prefix_frames, mode=model.interp)
|
| 122 |
+
full_feats = torch.zeros(n_batch, feats.shape[1], n_frames, device=DEVICE)
|
| 123 |
+
full_feats[..., prefix_frames:] = feats
|
| 124 |
+
|
| 125 |
+
prefix_mask = torch.arange(n_frames, device=DEVICE)[None, :].repeat(n_batch, 1) < prefix_frames
|
| 126 |
+
buffer_mask = prefix_mask[:, None, :].repeat(1, n_codebooks, 1)
|
| 127 |
+
feats_mask = ~prefix_mask
|
| 128 |
+
|
| 129 |
+
outputs = []
|
| 130 |
+
for i in range(N_OUTPUTS):
|
| 131 |
+
torch.manual_seed(seed + i)
|
| 132 |
+
gen = model.inference(
|
| 133 |
+
tokens.clone().to(DEVICE),
|
| 134 |
+
full_feats.to(DEVICE),
|
| 135 |
+
buffer_mask.clone().to(DEVICE),
|
| 136 |
+
feats_mask.to(DEVICE),
|
| 137 |
+
top_p=float(top_p),
|
| 138 |
+
mask_temp=float(mask_temperature),
|
| 139 |
+
iterations=infer_cfg["iterations"],
|
| 140 |
+
guidance_scale=float(cfg_scale),
|
| 141 |
+
)[..., prefix_frames:]
|
| 142 |
+
|
| 143 |
+
rhythm_tokens.tokens = gen
|
| 144 |
+
out_sig = tokenizer.decode(rhythm_tokens)
|
| 145 |
+
out_sig.ensure_max_of_audio()
|
| 146 |
+
output_path = f"tria_out_{i+1}.wav"
|
| 147 |
+
save_audio(out_sig, output_path)
|
| 148 |
+
path_i = output_path
|
| 149 |
+
outputs.append(str(path_i))
|
| 150 |
+
return tuple(outputs)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# PyHARP Metadata
|
| 154 |
+
model_card = ModelCard(
|
| 155 |
+
name="TRIA: The Rhythm In Anything",
|
| 156 |
+
description=(
|
| 157 |
+
"Transform your rhythmic ideas into full drum performances. TRIA takes two short audio prompts: \n "
|
| 158 |
+
"Rhythm Prompt (tapping, beatboxing, or percussion gesture) "
|
| 159 |
+
"and a Timbre Prompt (an example drum sound or kit recording) \n "
|
| 160 |
+
"It generates 3 drum arrangements that match your groove and chosen timbre. "
|
| 161 |
+
),
|
| 162 |
+
author="Patrick O'Reilly, Julia Barnett, Hugo Flores García, Annie Chu, Nathan Pruyne, Prem Seetharaman, Bryan Pardo",
|
| 163 |
+
tags=["tria", "rhythm-generation", "pyharp"],
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# Gradio and PyHARP Endpoint
|
| 168 |
+
with gr.Blocks(title="TRIA") as demo:
|
| 169 |
+
timbre_in = gr.Audio(type="filepath", label="Timbre Prompt").harp_required(True)
|
| 170 |
+
rhythm_in = gr.Audio(type="filepath", label="Rhythm Prompt").harp_required(True)
|
| 171 |
+
|
| 172 |
+
model_names = list(MODEL_ZOO.keys())
|
| 173 |
+
model_dropdown = gr.Dropdown(choices=model_names, value=model_names[0], label="Model")
|
| 174 |
+
|
| 175 |
+
with gr.Row():
|
| 176 |
+
cfg_scale = gr.Slider(0.0, 10.0, value=2.0, step=0.1, label="CFG Scale")
|
| 177 |
+
top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top P")
|
| 178 |
+
mask_temperature = gr.Slider(0.0, 20.0, value=10.5, step=0.1, label="Mask Temperature")
|
| 179 |
+
seed = gr.Slider(0, 1000, value=0, step=1, label="Random Seed")
|
| 180 |
+
|
| 181 |
+
out1 = gr.Audio(type="filepath", label="Generated #1")
|
| 182 |
+
out2 = gr.Audio(type="filepath", label="Generated #2")
|
| 183 |
+
out3 = gr.Audio(type="filepath", label="Generated #3")
|
| 184 |
+
|
| 185 |
+
app = build_endpoint(
|
| 186 |
+
model_card=model_card,
|
| 187 |
+
input_components=[model_dropdown, timbre_in, rhythm_in, cfg_scale, top_p, mask_temperature, seed],
|
| 188 |
+
output_components=[out1, out2, out3],
|
| 189 |
+
process_fn=generate_audio,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
demo.queue().launch(share=True, show_error=True)
|
pretrained/.gitignore
ADDED
|
File without changes
|
pretrained/tokenizer/dac/dac_44.1kHz_7.7kbps.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9ffa16e9cd52d67dadef026823403481930942f3fead32f44b75c4b60627246a
|
| 3 |
+
size 306721572
|
pretrained/tria/small_musdb_moises_2b/80000/extras.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1e18d9b8dbf5c5ff0d86aaf04d2af014960d97eeb396f7743e7595692ee31b68
|
| 3 |
+
size 344556763
|
pretrained/tria/small_musdb_moises_2b/80000/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4e20c3850253ba7fb267440573137f4b6099cad1e437fcfd574b84d60138155c
|
| 3 |
+
size 172260091
|
pretrained/tria/small_musdb_moises_2b/best/extras.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1e18d9b8dbf5c5ff0d86aaf04d2af014960d97eeb396f7743e7595692ee31b68
|
| 3 |
+
size 344556763
|
pretrained/tria/small_musdb_moises_2b/best/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4e20c3850253ba7fb267440573137f4b6099cad1e437fcfd574b84d60138155c
|
| 3 |
+
size 172260091
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.9.0
|
| 2 |
+
torchaudio==2.9.0
|
| 3 |
+
numpy
|
| 4 |
+
argbind
|
| 5 |
+
descript-audiotools>=0.9.2
|
| 6 |
+
pyharp>=1.7.8
|
| 7 |
+
gradio>=4.42.0
|
| 8 |
+
librosa
|
| 9 |
+
soundfile
|
| 10 |
+
tqdm
|
| 11 |
+
|
tria/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "0.0.1"
|
| 2 |
+
|
| 3 |
+
from . import constants
|
| 4 |
+
from . import util
|
| 5 |
+
from . import features
|
| 6 |
+
from . import transforms
|
tria/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (294 Bytes). View file
|
|
|
tria/__pycache__/constants.cpython-310.pyc
ADDED
|
Binary file (465 Bytes). View file
|
|
|
tria/__pycache__/features.cpython-310.pyc
ADDED
|
Binary file (4.52 kB). View file
|
|
|
tria/__pycache__/util.cpython-310.pyc
ADDED
|
Binary file (6.11 kB). View file
|
|
|
tria/constants.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
MANIFESTS_DIR = Path(__file__).parent.parent / "manifests"
|
| 4 |
+
DATA_DIR = Path(__file__).parent.parent / "data"
|
| 5 |
+
PRETRAINED_DIR = Path(__file__).parent.parent / "pretrained"
|
| 6 |
+
ASSETS_DIR = Path(__file__).parent.parent / "assets"
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
STEMS = ["drums", "bass", "vocals", "other", "mixture"]
|
| 10 |
+
SAMPLE_RATE = 44_100
|
| 11 |
+
DURATION = 6.0
|
tria/data/__init__.py
ADDED
|
File without changes
|
tria/data/dataset.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Callable
|
| 4 |
+
from typing import Dict
|
| 5 |
+
from typing import List
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import soundfile as sf
|
| 11 |
+
from audiotools import AudioSignal
|
| 12 |
+
from audiotools.core.util import random_state
|
| 13 |
+
from torch.utils.data import Dataset
|
| 14 |
+
|
| 15 |
+
from ..constants import DURATION
|
| 16 |
+
from ..constants import SAMPLE_RATE
|
| 17 |
+
from ..constants import STEMS
|
| 18 |
+
from ..util import collate
|
| 19 |
+
from ..util import get_info
|
| 20 |
+
from ..util import load_audio
|
| 21 |
+
from ..util import rms_salience
|
| 22 |
+
|
| 23 |
+
################################################################################
|
| 24 |
+
# Dataset for loading aligned excerpts across stem classes
|
| 25 |
+
################################################################################
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class StemDataset(Dataset):
|
| 29 |
+
"""
|
| 30 |
+
Load aligned excerpts from specified stem classes given paths in one or more
|
| 31 |
+
CSV manifests. Based on `audiotools.data.datasets.AudioDataset`.
|
| 32 |
+
|
| 33 |
+
Parameters
|
| 34 |
+
----------
|
| 35 |
+
sources : Union[str, Path, List[Union[str, Path]]]
|
| 36 |
+
CSV manifest(s) with columns for each requested stem.
|
| 37 |
+
stems : List[str]
|
| 38 |
+
Column names to load, e.g. ["mixture","drums","bass","vocals"].
|
| 39 |
+
The **first** stem is used for salience unless `salience_on` is set.
|
| 40 |
+
sample_rate : int
|
| 41 |
+
duration : float
|
| 42 |
+
n_examples : int
|
| 43 |
+
num_channels : int
|
| 44 |
+
relative_path : str
|
| 45 |
+
Prepended to relative CSV paths.
|
| 46 |
+
strict : bool
|
| 47 |
+
Drop rows with missing stems (True) vs. fill with silence (False).
|
| 48 |
+
with_replacement : bool
|
| 49 |
+
Sampling strategy for rows.
|
| 50 |
+
shuffle_state : int
|
| 51 |
+
Seed for deterministic per-index RNG.
|
| 52 |
+
loudness_cutoff : Optional[float]
|
| 53 |
+
dB LUFS cutoff; if None, take random excerpt (still shared across stems).
|
| 54 |
+
salience_num_tries : int
|
| 55 |
+
Max tries for salient excerpt search (see `AudioSignal.salient_excerpt`).
|
| 56 |
+
salience_on : Optional[str]
|
| 57 |
+
Which stem to use for salience. Defaults to first of `stems`.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
stems: List[str] = STEMS,
|
| 63 |
+
sample_rate: int = SAMPLE_RATE,
|
| 64 |
+
duration: float = DURATION,
|
| 65 |
+
sources: Union[str, Path, List[Union[str, Path]]] = None,
|
| 66 |
+
source_weights: Optional[List[float]] = None,
|
| 67 |
+
n_examples: int = 1000,
|
| 68 |
+
num_channels: int = 1,
|
| 69 |
+
relative_path: str = "",
|
| 70 |
+
strict: bool = True,
|
| 71 |
+
with_replacement: bool = True,
|
| 72 |
+
shuffle_state: int = 0,
|
| 73 |
+
loudness_cutoff: Optional[float] = -40.0,
|
| 74 |
+
salience_num_tries: int = 8,
|
| 75 |
+
salience_on: Optional[str] = None,
|
| 76 |
+
):
|
| 77 |
+
super().__init__()
|
| 78 |
+
|
| 79 |
+
assert sources is not None
|
| 80 |
+
assert len(stems) >= 1
|
| 81 |
+
|
| 82 |
+
self.stems = list(stems)
|
| 83 |
+
self.sample_rate = int(sample_rate)
|
| 84 |
+
self.duration = float(duration)
|
| 85 |
+
self.num_channels = int(num_channels)
|
| 86 |
+
self.relative_path = Path(relative_path)
|
| 87 |
+
self.strict = strict
|
| 88 |
+
self.with_replacement = with_replacement
|
| 89 |
+
self.length = int(n_examples)
|
| 90 |
+
self.shuffle_state = int(shuffle_state)
|
| 91 |
+
|
| 92 |
+
self.loudness_cutoff = loudness_cutoff
|
| 93 |
+
self.salience_num_tries = int(salience_num_tries)
|
| 94 |
+
self.salience_on = salience_on or self.stems[0]
|
| 95 |
+
if self.salience_on not in self.stems:
|
| 96 |
+
raise ValueError(
|
| 97 |
+
f"`salience_on` ('{self.salience_on}') must be one of {self.stems}"
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# Read manifests
|
| 101 |
+
csv_paths = [sources] if isinstance(sources, (str, Path)) else list(sources)
|
| 102 |
+
self.source_rows: List[List[Dict]] = []
|
| 103 |
+
kept_mask: List[bool] = []
|
| 104 |
+
kept_csvs: List[Path] = []
|
| 105 |
+
|
| 106 |
+
for cpath in csv_paths:
|
| 107 |
+
# Read rows for source
|
| 108 |
+
cpath = Path(cpath)
|
| 109 |
+
raw_rows = []
|
| 110 |
+
with open(cpath, "r") as f:
|
| 111 |
+
reader = csv.DictReader(f)
|
| 112 |
+
for row in reader:
|
| 113 |
+
entry = {"__manifest__": str(cpath)}
|
| 114 |
+
stem_paths = {}
|
| 115 |
+
for s in self.stems:
|
| 116 |
+
raw = (row.get(s) or "").strip()
|
| 117 |
+
stem_paths[s] = str(self._resolve_path(raw)) if raw else ""
|
| 118 |
+
entry["paths"] = stem_paths
|
| 119 |
+
extra = {k: v for k, v in row.items() if k not in self.stems}
|
| 120 |
+
if extra:
|
| 121 |
+
entry["meta"] = extra
|
| 122 |
+
raw_rows.append(entry)
|
| 123 |
+
|
| 124 |
+
# Filter rows for source
|
| 125 |
+
filtered = []
|
| 126 |
+
for r in raw_rows:
|
| 127 |
+
missing = [
|
| 128 |
+
s for s, p in r["paths"].items() if not p or not Path(p).is_file()
|
| 129 |
+
]
|
| 130 |
+
if self.strict and missing:
|
| 131 |
+
continue
|
| 132 |
+
|
| 133 |
+
min_dur = np.inf
|
| 134 |
+
any_valid = False
|
| 135 |
+
for s, p in r["paths"].items():
|
| 136 |
+
if p and Path(p).is_file():
|
| 137 |
+
any_valid = True
|
| 138 |
+
try:
|
| 139 |
+
total_sec = float(sf.info(p).duration)
|
| 140 |
+
min_dur = min(min_dur, float(total_sec))
|
| 141 |
+
except Exception:
|
| 142 |
+
if self.strict:
|
| 143 |
+
min_dur = -np.inf
|
| 144 |
+
break
|
| 145 |
+
if not any_valid or not np.isfinite(min_dur):
|
| 146 |
+
continue
|
| 147 |
+
if min_dur < self.duration and self.strict:
|
| 148 |
+
continue
|
| 149 |
+
|
| 150 |
+
r["min_duration"] = min_dur if np.isfinite(min_dur) else 0.0
|
| 151 |
+
filtered.append(r)
|
| 152 |
+
|
| 153 |
+
if len(filtered) > 0:
|
| 154 |
+
self.source_rows.append(filtered)
|
| 155 |
+
kept_mask.append(True)
|
| 156 |
+
kept_csvs.append(cpath)
|
| 157 |
+
else:
|
| 158 |
+
kept_mask.append(False)
|
| 159 |
+
|
| 160 |
+
if len(self.source_rows) == 0:
|
| 161 |
+
raise RuntimeError(
|
| 162 |
+
"StemDataset: no valid rows after filtering in any source."
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
self.csv_paths = kept_csvs
|
| 166 |
+
|
| 167 |
+
lengths = [len(lst) for lst in self.source_rows]
|
| 168 |
+
self._source_offsets = np.cumsum([0] + lengths[:-1]) # for global idx
|
| 169 |
+
self._n_rows = int(sum(lengths))
|
| 170 |
+
|
| 171 |
+
# Weights over non-empty sources
|
| 172 |
+
if source_weights is None:
|
| 173 |
+
self._weights = None
|
| 174 |
+
else:
|
| 175 |
+
if len(source_weights) != len(csv_paths):
|
| 176 |
+
raise ValueError(
|
| 177 |
+
f"source_weights must match number of sources ({len(csv_paths)}), "
|
| 178 |
+
f"got {len(source_weights)}"
|
| 179 |
+
)
|
| 180 |
+
w = np.asarray(source_weights, dtype=float)
|
| 181 |
+
# Keep only weights for sources that survived filtering
|
| 182 |
+
w = w[np.array(kept_mask, dtype=bool)]
|
| 183 |
+
w = np.clip(w, 0, None)
|
| 184 |
+
if not np.any(w > 0):
|
| 185 |
+
w = np.ones_like(w)
|
| 186 |
+
self._weights = (w / w.sum()).tolist()
|
| 187 |
+
|
| 188 |
+
def _resolve_path(self, p: Union[str, Path]) -> Path:
|
| 189 |
+
p = Path(p).expanduser()
|
| 190 |
+
if not p.is_absolute():
|
| 191 |
+
p = (self.relative_path / p).expanduser()
|
| 192 |
+
return p
|
| 193 |
+
|
| 194 |
+
def _pick_row(self, state: np.random.RandomState):
|
| 195 |
+
# Sample a non-empty source
|
| 196 |
+
sidx = int(state.choice(len(self.source_rows), p=self._weights))
|
| 197 |
+
n_in_source = len(self.source_rows[sidx])
|
| 198 |
+
item_idx = int(state.randint(n_in_source))
|
| 199 |
+
row = self.source_rows[sidx][item_idx]
|
| 200 |
+
|
| 201 |
+
# Map to a global idx for metadata
|
| 202 |
+
ridx_global = int(self._source_offsets[sidx] + item_idx)
|
| 203 |
+
return ridx_global, row
|
| 204 |
+
|
| 205 |
+
def __len__(self):
|
| 206 |
+
return self.length
|
| 207 |
+
|
| 208 |
+
def __getitem__(self, idx: int):
|
| 209 |
+
state = random_state((self.shuffle_state + int(idx)) & 0x7FFFFFFF)
|
| 210 |
+
ridx, row = self._pick_row(state)
|
| 211 |
+
|
| 212 |
+
primary = self.salience_on
|
| 213 |
+
p0 = row["paths"].get(primary, "")
|
| 214 |
+
|
| 215 |
+
offset = 0.0
|
| 216 |
+
primary_sig = None
|
| 217 |
+
if p0 and Path(p0).is_file():
|
| 218 |
+
if self.loudness_cutoff is None or not self.salience_num_tries:
|
| 219 |
+
try:
|
| 220 |
+
total_sec, _sr = get_info(p0)
|
| 221 |
+
except Exception:
|
| 222 |
+
total_sec = 0.0
|
| 223 |
+
max_off = max(0.0, total_sec - self.duration)
|
| 224 |
+
offset = float(state.rand() * max_off) if max_off > 0 else 0.0
|
| 225 |
+
else:
|
| 226 |
+
offset = rms_salience(
|
| 227 |
+
p0,
|
| 228 |
+
duration=self.duration,
|
| 229 |
+
cutoff_db=float(self.loudness_cutoff),
|
| 230 |
+
num_tries=int(self.salience_num_tries),
|
| 231 |
+
state=state,
|
| 232 |
+
)
|
| 233 |
+
primary_sig = load_audio(p0, offset=offset, duration=self.duration)
|
| 234 |
+
else:
|
| 235 |
+
offset = 0.0
|
| 236 |
+
|
| 237 |
+
item: Dict[str, Dict] = {}
|
| 238 |
+
for s in self.stems:
|
| 239 |
+
p = row["paths"][s]
|
| 240 |
+
exists = bool(p) and Path(p).is_file()
|
| 241 |
+
|
| 242 |
+
if s == primary and primary_sig is not None:
|
| 243 |
+
sig = primary_sig.clone() # reuse window we already loaded
|
| 244 |
+
elif exists:
|
| 245 |
+
sig = load_audio(
|
| 246 |
+
p, offset=offset, duration=self.duration
|
| 247 |
+
) # windowed load
|
| 248 |
+
else:
|
| 249 |
+
sig = AudioSignal.zeros(
|
| 250 |
+
self.duration, self.sample_rate, self.num_channels
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# Channel formatting
|
| 254 |
+
if self.num_channels == 1:
|
| 255 |
+
sig = sig.to_mono()
|
| 256 |
+
elif self.num_channels != sig.num_channels:
|
| 257 |
+
assert sig.num_channels == 1
|
| 258 |
+
sig.audio_data = sig.audio_data.repeat(1, self.num_channels, 1)
|
| 259 |
+
|
| 260 |
+
# Resample/pad to target SR and exact duration
|
| 261 |
+
sig = sig.resample(self.sample_rate)
|
| 262 |
+
if sig.duration < self.duration:
|
| 263 |
+
sig = sig.zero_pad_to(int(self.duration * self.sample_rate))
|
| 264 |
+
|
| 265 |
+
# Metadata
|
| 266 |
+
sig.metadata["path"] = p
|
| 267 |
+
sig.metadata["offset"] = offset
|
| 268 |
+
sig.metadata["source_row"] = ridx
|
| 269 |
+
if "meta" in row:
|
| 270 |
+
for k, v in row["meta"].items():
|
| 271 |
+
sig.metadata[k] = v
|
| 272 |
+
|
| 273 |
+
item[s] = {"signal": sig, "path": p}
|
| 274 |
+
|
| 275 |
+
item["idx"] = idx
|
| 276 |
+
return item
|
| 277 |
+
|
| 278 |
+
@staticmethod
|
| 279 |
+
def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
|
| 280 |
+
return collate(list_of_dicts, n_splits=n_splits)
|
tria/data/preprocess.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Callable, Dict, Tuple, Union, Optional, Any
|
| 5 |
+
from rich.progress import track
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from audiotools.core.util import random_state
|
| 10 |
+
from ..util import ensure_dir
|
| 11 |
+
|
| 12 |
+
SplitType = Union[Tuple[float, float, float], Callable[[Path], str]]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def create_manifests(
    data_dir: Union[str, Path],
    ext: str,
    output_dir: Union[str, Path],
    split: SplitType,
    attributes: Dict[str, Callable[[Path], Any]],
    seed: Optional[int] = 0,
) -> Dict[str, Path]:
    """
    Create CSV manifests (train/val/test) for an audio dataset.

    Parameters
    ----------
    data_dir : str
        Dataset root directory to search recursively for files
    ext : str
        Audio file extension
    output_dir : str
        Directory to which to write manifests
    split : SplitType
        Either a 3-tuple containing (train, val, test) proportions summing to 1
        or a Callable that returns "train", "val", or "test" given a filepath
    attributes : dict
        Dictionary mapping column names to Callables for extracting values
        given filepaths; for example {'path': lambda p: str(p)}
    seed : int
        Random seed for the proportional (tuple) split

    Returns
    -------
    Dict[str, Path]
        Mapping from split name ("train"/"val"/"test") to written CSV path
    """
    data_dir = Path(data_dir)
    output_dir = Path(output_dir)
    ensure_dir(output_dir)

    # Case-insensitive sort so manifests are deterministic across platforms
    all_files = sorted(
        [p for p in data_dir.rglob(f"*{ext}") if p.is_file()],
        key=lambda p: str(p).lower(),
    )

    splits = {"train": [], "val": [], "test": []}

    # Callable split: apply given function to file paths to obtain
    # train/val/test assignments
    if callable(split):
        for p in all_files:
            s = split(p)
            if s not in splits:
                raise ValueError(
                    f"Split function must return one of "
                    f"{list(splits.keys())}, got {s!r} for {p}"
                )
            splits[s].append(p)

    # Proportional split: randomly shuffle files and split according to given
    # values
    else:
        if not (isinstance(split, tuple) and len(split) == 3):
            raise ValueError("Split proportions tuple must have length 3")
        p_train, p_val, p_test = split
        total = float(p_train + p_val + p_test)
        if not np.isclose(total, 1.0, atol=1e-6):
            raise ValueError(f"Split proportions must sum to 1.0 (got {total}).")

        rs = random_state(seed)
        idx = np.array(rs.permutation(len(all_files)))
        n = len(idx)
        n_train = int(np.floor(p_train * n))
        n_val = int(np.floor(p_val * n))

        # Any remainder left by flooring goes to the test split
        splits["train"] = [all_files[int(i)] for i in idx[:n_train]]
        splits["val"] = [all_files[int(i)] for i in idx[n_train:n_train + n_val]]
        splits["test"] = [all_files[int(i)] for i in idx[n_train + n_val:]]

    columns = list(attributes.keys())

    # Write CSVs
    out_paths: Dict[str, Path] = {}
    for s in ("train", "val", "test"):
        out_csv = output_dir / f"{s}.csv"
        out_paths[s] = out_csv

        with out_csv.open("w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=columns)
            writer.writeheader()

            for p in track(
                splits[s],
                description=f"Writing {s}.csv",
                total=len(splits[s])
            ):
                # Best-effort: report and skip files whose attributes cannot
                # be extracted rather than aborting the whole manifest
                try:
                    row = {col: fn(p) for col, fn in attributes.items()}
                    writer.writerow(row)
                except Exception as e:
                    print(
                        f"Error at path {p}:\n"
                        f"{e}\n"
                        "Skipping..."
                    )

    return out_paths
|
tria/features.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from audiotools import AudioSignal
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
################################################################################
|
| 6 |
+
# Utilities for extracting rhythm feature representations
|
| 7 |
+
################################################################################
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _moving_average(x: torch.Tensor, window_length: int):
|
| 11 |
+
"""
|
| 12 |
+
Smooth features with moving average over frames.
|
| 13 |
+
|
| 14 |
+
Parameters
|
| 15 |
+
----------
|
| 16 |
+
x : torch.Tensor
|
| 17 |
+
Shape (n_batch, n_feats, n_frames)
|
| 18 |
+
window_length : int
|
| 19 |
+
Smoothing window length
|
| 20 |
+
"""
|
| 21 |
+
if window_length <= 1:
|
| 22 |
+
return x
|
| 23 |
+
n_feats = x.shape[1]
|
| 24 |
+
kernel = torch.ones(
|
| 25 |
+
(n_feats, 1, window_length),
|
| 26 |
+
device=x.device, dtype=x.dtype
|
| 27 |
+
) / window_length
|
| 28 |
+
|
| 29 |
+
pad_left = (window_length - 1) // 2
|
| 30 |
+
pad_right = window_length // 2
|
| 31 |
+
x_pad = torch.nn.functional.pad(x, (pad_left, pad_right), mode="reflect")
|
| 32 |
+
|
| 33 |
+
# Smooth separately over feature channels
|
| 34 |
+
return torch.nn.functional.conv1d(x_pad, kernel, groups=n_feats)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# The 'original' TRIA features can be recovered using:
# * `slow_ma_ms` = None
# * `post_smooth_ms` = None
# * `legacy_normalize` = True
def rhythm_features(
    signal: AudioSignal,
    sample_rate: int = 44_100,
    n_bands: int = 2,
    n_mels: int = 80,
    window_length: int = 1024,
    hop_length: int = 512,
    normalize_quantile: float = 0.98,
    quantization_levels: int = 33,
    clamp_max: float = 50.0,
    eps: float = 1e-8,
    slow_ma_ms: float = 100.0,
    post_smooth_ms: float = 10.0,
    legacy_normalize: bool = False,
):
    """
    Extract multi-band 'rhythm' features from audio by adaptively splitting
    spectrogram along frequency axis and applying normalization, quantization,
    and smoothing / sparsity filtering.

    Parameters
    ----------
    signal : AudioSignal
        Audio from which to extract features
    sample_rate : int
        Sample rate at which to extract features
    n_bands : int
        Number of frequency bands into which to adaptively divide spectrogram
    n_mels : int
        Number of base mel frequency bins in spectrogram
    window_length : int
        Spectrogram window length
    hop_length : int
        Spectrogram hop length
    normalize_quantile : float
        Optionally normalize each band relative to top-p largest magnitude
        rather than absolute max
    quantization_levels : int
        Number of bins into which feature magnitudes are quantized
    clamp_max : float
        Maximum allowed spectrogram magnitude
    eps : float
        For numerical stability
    slow_ma_ms : float
        Smoothing filter length in milliseconds for transient emphasis (smoothed
        features are subtracted)
    post_smooth_ms : float
        Smoothing filter length in milliseconds for transient smoothing
    legacy_normalize : bool
        If `True`, use mean/std and sigmoid normalization as described in
        original TRIA paper

    Returns
    -------
    torch.Tensor
        Quantized per-band features in [0, 1], shape
        (n_batch, n_bands, n_frames)
    """

    assert n_bands >= 1
    assert quantization_levels >= 2

    # Loudness normalization
    signal = signal.clone().to_mono().resample(sample_rate).normalize(-16.)
    signal.ensure_max_of_audio()

    # Clamped mel spectrogram; mean over the (mono) channel axis
    mel = signal.mel_spectrogram(
        n_mels=n_mels,
        hop_length=hop_length,
        window_length=window_length,
    ).mean(1)  # (n_batch, n_mels, n_frames)
    mel = torch.clamp(mel, 0.0, clamp_max)

    n_batch, _, n_frames = mel.shape

    if legacy_normalize:
        # Original normalization: divide by number of mels
        mel = mel / n_mels
    else:
        # Compress logarithmically; dividing by log1p(clamp_max) maps the
        # clamped range [0, clamp_max] onto [0, 1]
        mel = torch.log1p(mel) / torch.log1p(torch.tensor(clamp_max, device=mel.device, dtype=mel.dtype))

    # Split spectrogram into bands adaptively: band edges are chosen so each
    # band carries roughly equal time-averaged energy
    energy_per_bin = mel.mean(dim=-1)  # (n_batch, n_mels)
    cum = energy_per_bin.cumsum(dim=1)  # (n_batch, n_mels)
    total = cum[:, -1:]  # (n_batch, 1)

    if n_bands == 1:
        bands = mel.sum(dim=1, keepdim=True)  # (n_batch, 1, n_frames)
    else:
        # Energy targets at 1/n, 2/n, ..., (n-1)/n of the total energy
        targets = torch.linspace(
            1.0 / n_bands, (n_bands - 1) / n_bands, n_bands - 1,
            device=mel.device, dtype=mel.dtype
        )[None, :] * total  # (n_batch, n_bands-1)

        # Per-batch mel-bin indices where the cumulative energy crosses each
        # target
        edges = torch.searchsorted(cum, targets, right=False)  # (n_batch, n_bands-1)

        # Band boundary indices, including 0 and n_mels
        cuts = torch.cat(
            [
                torch.zeros(n_batch, 1, dtype=torch.long, device=mel.device),
                edges + 1,
                torch.full((n_batch, 1), mel.size(1), dtype=torch.long, device=mel.device),
            ],
            dim=1
        )  # (n_batch, n_bands+1)

        # Zero-padded prefix sums over mel bins allow each band's energy to be
        # computed as a difference of two gathers (no per-band loop)
        prefix = mel.cumsum(dim=1)  # (n_batch, n_mels, n_frames)
        prefix_pad = torch.cat(
            [torch.zeros(n_batch, 1, n_frames, device=mel.device, dtype=mel.dtype), prefix],
            dim=1
        )

        a_idx = cuts[:, :-1].unsqueeze(-1).expand(n_batch, n_bands, n_frames)
        b_idx = cuts[:, 1: ].unsqueeze(-1).expand(n_batch, n_bands, n_frames)
        bands = prefix_pad.gather(1, b_idx) - prefix_pad.gather(1, a_idx)  # (n_batch, n_bands, n_frames)

    # Emphasize transients by subtracting smoothed features
    transient = bands.clone()
    # Convert a duration in milliseconds to a (>= 1) frame count
    to_frames = lambda ms: max(1, int(round((ms / 1000.0) * sample_rate / hop_length)))

    if slow_ma_ms is not None:
        slow_win = to_frames(slow_ma_ms)
        bands_slow = _moving_average(bands, slow_win)  # (n_batch, n_bands, n_frames)
        # ReLU keeps only energy above the slow-moving baseline
        transient = torch.relu(bands - bands_slow)

    # Apply additional smoothing to transients
    if post_smooth_ms is not None:
        ps_win = to_frames(post_smooth_ms)
        if ps_win > 1:
            transient = _moving_average(transient, ps_win)

    # Normalize features across time per band
    if legacy_normalize:
        # Original normalization (mean/std with sigmoid compression)
        mean = transient.mean(dim=-1, keepdim=True)
        std = transient.std(dim=-1, keepdim=True).clamp_min(eps)
        transient = torch.sigmoid((transient - mean) / std)
    else:
        # Quantile-based normalization: scale by the `normalize_quantile`
        # magnitude so outliers do not dominate, then clip to [0, 1]
        q = torch.quantile(
            transient.clamp_min(0.0),
            q=normalize_quantile,
            dim=-1,
            keepdim=True
        ).clamp_min(eps)
        transient = (transient / q).clamp(0.0, 1.0)

    # Quantize feature intensities into bins to ensure a tight information
    # bottleneck
    steps = quantization_levels - 1
    return torch.round(transient * steps) / steps
|
tria/model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .tria import TRIA
|
tria/model/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (182 Bytes). View file
|
|
|
tria/model/__pycache__/mask.cpython-310.pyc
ADDED
|
Binary file (5.87 kB). View file
|
|
|
tria/model/__pycache__/sample.cpython-310.pyc
ADDED
|
Binary file (4.68 kB). View file
|
|
|
tria/model/__pycache__/tria.cpython-310.pyc
ADDED
|
Binary file (7.21 kB). View file
|
|
|
tria/model/mask.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Iterable
|
| 2 |
+
from typing import Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from audiotools.core.util import random_state
|
| 6 |
+
|
| 7 |
+
################################################################################
|
| 8 |
+
# Utilities for masked language modeling
|
| 9 |
+
################################################################################
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def cosine_schedule(t: torch.Tensor) -> torch.Tensor:
    """
    Map timestep in [0, 1] to masking ratio in (0, 1] via cosine schedule
    proposed by Chang et al. in "MaskGIT: Masked generative image
    transformer" (2022).

    Parameters
    ----------
    t : torch.Tensor
        Timestep in [0, 1]

    Returns
    -------
    torch.Tensor
        Mask proportion in (0, 1]
    """
    # cos(pi*t/2) decays from 1 at t=0 to 0 at t=1; the clamp keeps the
    # ratio strictly positive
    ratio = torch.cos(t * torch.pi / 2)
    return torch.clamp(ratio, 1e-10, 1.0)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def format_seed(seed):
    """
    Normalize `seed` (a scalar, tensor, or iterable of seed values) into a
    list of per-item random number generators via `random_state`.
    """
    if isinstance(seed, (int, float)):
        seeds = [seed]
    elif isinstance(seed, torch.Tensor):
        seeds = seed.tolist()
    elif isinstance(seed, Iterable):
        # Already a sequence of seeds
        seeds = seed
    else:
        raise ValueError(f"Invalid random seed of type {type(seed)}")

    states = []
    for s in seeds:
        states.append(random_state(s))
    return states
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_span_mask(
    tokens: torch.Tensor,
    min_prop: float,
    max_prop: float,
    seed: Union[int, Iterable[int]],
) -> torch.Tensor:
    """
    Mask a random span of consecutive frames across all codebooks, varying
    across batch.

    Parameters
    ----------
    tokens : torch.Tensor
        Tokens to be masked, shape (n_batch, n_codebooks, n_frames)
    min_prop : float
        Minimum proportion of frames to mask
    max_prop : float
        Maximum proportion of frames to mask
    seed : Iterable[int]
        One or more random seeds to determine masks

    Returns
    -------
    torch.Tensor
        Boolean mask of shape (n_batch, n_frames); False marks masked frames
    """
    assert min_prop >= 0.0
    assert max_prop <= 1.0

    n_batch, _, n_frames = tokens.shape

    rngs = format_seed(seed)
    assert len(rngs) == n_batch

    # Start fully unmasked (True = keep)
    mask = torch.ones(
        n_batch, n_frames, device=tokens.device, dtype=torch.bool
    )

    for i, rng in enumerate(rngs):
        # Draw this item's masking proportion (degenerate range -> min_prop)
        if min_prop < max_prop:
            prop = rng.uniform(min_prop, max_prop)
        else:
            prop = min_prop

        if prop >= 1.0:
            # Mask the entire sequence
            mask[i] = False
            continue

        # Choose a random start for a contiguous span of `span` frames
        span = int(prop * n_frames)
        start = rng.randint(0, max(n_frames - span, 1))
        mask[i, start : start + span] = False

    return mask
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_current_codebook_mask(
    tokens: torch.Tensor, codebooks: torch.Tensor
) -> torch.Tensor:
    """
    Given tokens and batch of selected codebooks, mask all codebooks "above"
    and "below" selected codebooks (i.e. the result is True only at each
    batch item's selected codebook).

    Parameters
    ----------
    tokens : torch.Tensor
        Tokens to be masked, shape (n_batch, n_codebooks, n_frames)
    codebooks : torch.Tensor
        Selected codebooks "above" which tokens should be masked, shape
        (n_batch,)

    Returns
    -------
    torch.Tensor
        Mask of shape (n_batch, n_codebooks)
    """

    n_batch, n_codebooks, _ = tokens.shape

    assert codebooks.ndim == 1
    assert codebooks.shape[0] in [1, n_batch]
    # Broadcast a single selection across the whole batch if needed
    codebooks = codebooks.repeat(n_batch // codebooks.shape[0])

    levels = torch.arange(
        n_codebooks,
        dtype=codebooks.dtype,
        device=codebooks.device,
    )

    # True exactly where the codebook index equals the selected codebook
    return levels[None, :] == codebooks[:, None]  # (n_batch, n_codebooks)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_next_codebooks_mask(
    tokens: torch.Tensor, codebooks: torch.Tensor
) -> torch.Tensor:
    """
    Given tokens and batch of selected codebooks, mask all codebooks "above"
    selected codebooks (i.e. the result is True at and below each batch
    item's selected codebook).

    Parameters
    ----------
    tokens : torch.Tensor
        Tokens to be masked, shape (n_batch, n_codebooks, n_frames)
    codebooks : torch.Tensor
        Selected codebooks "above" which tokens should be masked, shape
        (n_batch,)

    Returns
    -------
    torch.Tensor
        Mask of shape (n_batch, n_codebooks)
    """

    n_batch, n_codebooks, _ = tokens.shape

    assert codebooks.ndim == 1
    assert codebooks.shape[0] in [1, n_batch]
    # Broadcast a single selection across the whole batch if needed
    codebooks = codebooks.repeat(n_batch // codebooks.shape[0])

    levels = torch.arange(
        n_codebooks,
        dtype=codebooks.dtype,
        device=codebooks.device,
    )

    # True at and below each item's selected codebook
    return levels[None, :] <= codebooks[:, None]  # (n_batch, n_codebooks)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def get_random_mask(
    tokens: torch.Tensor,
    prop: Union[float, Iterable[float]],
    seed: Union[int, Iterable[int]],
) -> torch.Tensor:
    """
    Randomly mask tokens independently at each (codebook, frame) position.

    Parameters
    ----------
    tokens : torch.Tensor
        Tokens to be masked, shape (n_batch, n_codebooks, n_frames)
    prop : float or Iterable[float]
        Proportion of tokens to be masked; either a scalar applied to every
        batch item or one value per batch item
    seed : Iterable[int]
        One or more random seeds to determine masks

    Returns
    -------
    torch.Tensor
        Random boolean mask of shape (n_batch, n_codebooks, n_frames);
        False marks masked positions
    """
    n_batch, n_codebooks, n_frames = tokens.shape

    if isinstance(prop, torch.Tensor):
        prop = prop.tolist()
    elif isinstance(prop, (int, float)):
        # Scalar proportion: broadcast across the batch. (Previously the
        # annotated `Union[float, ...]` scalar case raised a TypeError on
        # `len(prop)`.)
        prop = [prop] * n_batch
    assert len(prop) == n_batch

    states = format_seed(seed)
    assert len(states) == n_batch

    mask = torch.ones(
        n_batch,
        n_codebooks,
        n_frames,
        device=tokens.device,
        dtype=torch.bool,
    )  # (n_batch, n_codebooks, n_frames)

    # A position stays unmasked iff an independent uniform draw exceeds the
    # batch item's masking proportion
    for i, (s, p) in enumerate(zip(states, prop)):
        mask[i] = torch.from_numpy(s.rand(n_codebooks, n_frames)).to(mask.device) > p

    return mask
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def combine_masks(
    mask_span: torch.Tensor,
    mask_current_codebook: torch.Tensor,
    mask_next_codebooks: torch.Tensor,
    mask_random: torch.Tensor,
    leak: bool = False,
) -> torch.Tensor:
    """
    Combine sampled masks to allow for application to token buffer.

    Parameters
    ----------
    mask_span : torch.Tensor
        Shape (n_batch, n_frames); True outside the masked span
    mask_current_codebook : torch.Tensor
        Shape (n_batch, n_codebooks)
    mask_next_codebooks : torch.Tensor
        Shape (n_batch, n_codebooks)
    mask_random : torch.Tensor
        Shape (n_batch, n_codebooks, n_frames)
    leak : bool
        If True, unmasked random positions in "higher" codebooks leak
        through inside the masked span

    Returns
    -------
    torch.Tensor
        Combined mask, shape (n_batch, n_codebooks, n_frames)
    """

    # Positions masked at the currently-predicted codebook level
    current = mask_current_codebook.unsqueeze(-1) & ~mask_random

    if leak:
        # Allow leakage from "higher" codebooks inside masked span
        higher = ~mask_next_codebooks.unsqueeze(-1) & ~mask_random
    else:
        # Strictly mask "higher" codebooks inside masked span
        higher = ~mask_next_codebooks.unsqueeze(-1)

    # Inside span, unmask everything except "higher" codebooks and masked
    # positions in current codebook ...
    combined = ~(higher | current)

    # ... and outside the span, fully unmask
    return combined | mask_span.unsqueeze(1)
|
tria/model/sample.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from typing import Iterable, Union, Optional
|
| 6 |
+
import numpy as np
|
| 7 |
+
from numpy.random import RandomState
|
| 8 |
+
|
| 9 |
+
from .mask import cosine_schedule, format_seed
|
| 10 |
+
|
| 11 |
+
################################################################################
|
| 12 |
+
# Utilities for sampling from trained TRIA model
|
| 13 |
+
################################################################################
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def top_p_top_k(
    logits: torch.Tensor,
    top_p: float = None,
    top_k: int = None,
):
    """
    Filter logits to the nucleus (top-p) and/or top-k candidates by setting
    all other entries to -inf.

    Adapted from `vampnet.modules.transformer.sample_from_logits` by Hugo
    Flores Garcia. See: https://github.com/hugofloresgarcia/vampnet/

    Parameters
    ----------
    logits : torch.Tensor
        Shape (..., n_classes)
    """
    out = logits.clone()
    n_classes = out.shape[-1]

    # Top-k: suppress everything strictly below the k-th largest logit
    if top_k is not None and 0 < top_k < n_classes:
        kth = out.topk(top_k, dim=-1).values[..., -1:]  # (..., 1)
        out = out.masked_fill(out < kth, float("-inf"))

    # Top-p (nucleus): suppress the tail of the descending-sorted
    # distribution whose cumulative probability exceeds `top_p`
    if top_p is not None and 0.0 < top_p < 1.0:
        vals, order = out.sort(dim=-1, descending=True)  # (..., n_classes)
        cdf = F.softmax(vals, dim=-1).cumsum(dim=-1)  # (..., n_classes)

        drop_sorted = cdf > top_p
        # Always keep at least the single most likely logit
        drop_sorted[..., 0] = False

        # Scatter back from sorted order to original positions
        drop = torch.zeros_like(drop_sorted).scatter(-1, order, drop_sorted)
        out = out.masked_fill(drop, float("-inf"))

    return out
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def sample(
    logits: torch.Tensor,
    temp: float,
    argmax: bool = False,
):
    """
    Sample token indices from logits, either greedily or with
    temperature-scaled multinomial sampling.

    Adapted from `vampnet.modules.transformer.sample_from_logits` by Hugo
    Flores Garcia. See: https://github.com/hugofloresgarcia/vampnet/

    Parameters
    ----------
    logits : torch.Tensor
        Shape (..., n_classes)
    temp : float
        Sampling temperature; non-positive values force greedy decoding
    argmax : bool
        If True, decode greedily regardless of temperature

    Returns
    -------
    torch.Tensor
        Sampled tokens, shape of `logits` with trailing `n_classes` dimension
        removed
    torch.Tensor
        Probabilities of sampled tokens, shape of `logits` with trailing
        `n_classes` dimension removed
    """
    # Non-positive temperature degenerates to greedy decoding
    if temp <= 0:
        argmax, temp = True, 1.0

    if argmax:
        tokens = logits.argmax(dim=-1)
        token_probs = (
            F.softmax(logits, dim=-1)
            .take_along_dim(tokens.unsqueeze(-1), dim=-1)
            .squeeze(-1)
        )
        return tokens, token_probs

    dist = F.softmax(logits / temp, dim=-1)
    # torch.multinomial requires a 2D (rows, classes) input
    rows = dist.reshape(-1, dist.shape[-1])
    tokens = torch.multinomial(rows, 1).squeeze(-1).view(*dist.shape[:-1])
    token_probs = dist.take_along_dim(tokens.unsqueeze(-1), dim=-1).squeeze(-1)
    return tokens, token_probs
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def mask_by_confidence(
    probs: torch.Tensor,
    n: torch.Tensor,
    temp: float,
    causal_bias: float,
    state: Iterable[RandomState],
    eligible: Optional[torch.Tensor] = None,
):
    """
    Re-mask predicted tokens in a single codebook such that `n` previously-
    masked tokens are left unmasked, using confidence (probability assigned to
    tokens during sampling) to select which tokens remain. This confidence can
    be mediated by random noise and a bias to unmask early (leftward) positions
    first.

    Parameters
    ----------
    probs : torch.Tensor
        Probabilities assigned to sampled tokens, shape (n_batch, n_frames)
    n : torch.Tensor
        Target number of unmasked tokens, shape (n_batch,)
    temp : float
        Mask temperature, corresponding to randomness in unmasking process
    causal_bias : float
        Bias towards unmasking early (leftward) token positions first; typically
        in (0, 1]. Note that large values of `temp` can effectively "wash out"
        this causal bias
    state : Iterable[RandomState]
        Random seeds for reproducibility; one RandomState per batch item
    eligible : torch.Tensor
        Optional indicator for positions eligible for unmasking, shape (n_batch, n_frames)

    Returns
    -------
    torch.Tensor
        Boolean mask of shape (n_batch, n_frames); True = unmasked/keep,
        False = re-masked
    """

    n_batch, n_frames = probs.shape
    device = probs.device

    if eligible is None:
        # Default: any position with a finite, positive probability can be
        # unmasked
        eligible = torch.isfinite(probs) & (probs > 0)
    else:
        eligible = eligible.to(torch.bool)

    # Masked token count and target: unmask enough positions that `n`
    # remain masked (clamped so we never unmask a negative count)
    n_masked = eligible.long().sum(dim=-1)
    n_unmask = (n_masked - n).clamp_min(0)

    # Gumbel noise to introduce randomness into unmasking; drawn per batch
    # item from the provided RandomStates for reproducibility
    u = torch.stack([
        torch.from_numpy(s.uniform(1e-6, 1 - 1e-6, n_frames)) for s in state
    ], dim=0).to(probs)
    gumbel = -torch.log(-torch.log(u))

    # Log-confidences + noise (clamp avoids log(0))
    s = probs.clamp_min(1e-12)
    confs = torch.log(s) + temp * gumbel

    # Optional causal bias in log-domain: earlier frames get a larger
    # additive boost, decaying linearly to 0 at the final frame
    if causal_bias > 0:
        frame_relpos = (1 - (torch.arange(n_frames, device=device, dtype=confs.dtype) + 1) / n_frames).view(1, -1)
        confs = confs + causal_bias * frame_relpos

    # Only eligible positions can be chosen: pick the `n_unmask` highest-
    # confidence eligible positions per batch item
    confs_masked = confs.masked_fill(~eligible, float("-inf"))
    sorted_vals, sorted_idx = confs_masked.sort(dim=-1, descending=True)
    rank = torch.arange(n_frames, device=device).view(1, n_frames).expand_as(confs_masked)
    k = n_unmask.view(n_batch, 1)
    pick_sorted = rank < k
    # Scatter the top-k selection back from sorted order to frame order
    pick = torch.zeros_like(pick_sorted, dtype=torch.bool).scatter(-1, sorted_idx, pick_sorted)

    # Return tokens_mask semantics (True = unmasked/keep)
    mask = ~(eligible & (~pick))
    return mask
|
| 168 |
+
|
tria/model/tria.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Union, Iterable
|
| 4 |
+
|
| 5 |
+
from ..nn.transformer import Transformer
|
| 6 |
+
from .mask import cosine_schedule, format_seed
|
| 7 |
+
from .sample import mask_by_confidence, top_p_top_k, sample
|
| 8 |
+
|
| 9 |
+
################################################################################
|
| 10 |
+
# TRIA masked language model
|
| 11 |
+
################################################################################
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TRIA(torch.nn.Module):
    """
    Masked-token model over multi-codebook acoustic tokens, conditioned on
    aligned per-frame features.

    Training (`forward`) predicts logits for one target codebook per batch
    element given partially-masked tokens. Inference (`inference`) performs
    iterative confidence-based unmasking, one codebook at a time, optionally
    with classifier-free guidance on the conditioning features.
    """

    def __init__(
        self,
        codebook_size: int = 1024,
        n_codebooks: int = 9,
        n_feats: int = 2,
        n_channels: int = 512,
        n_heads: int = 8,
        n_layers: int = 12,
        mult: int = 4,
        p_dropout: float = 0.0,
        p_token_dropout: float = 0.0,
        bias: bool = False,
        max_len: int = 8192,
        pos_enc: Optional[str] = "rope",
        qk_norm: bool = True,
        use_sdpa: bool = True,
        interp: str = "nearest",
        share_emb: bool = True,
    ):
        """
        Parameters
        ----------
        codebook_size : int
            Number of entries per codebook
        n_codebooks : int
            Number of residual codebooks
        n_feats : int
            Number of conditioning feature channels
        n_channels : int
            Transformer model width
        n_heads : int
            Attention heads
        n_layers : int
            Transformer layers
        mult : int
            Feed-forward expansion multiplier
        p_dropout : float
            Dropout probability inside the transformer
        p_token_dropout : float
            Probability of zeroing token embeddings during training (see
            `forward`)
        bias : bool
            Whether the adapter/input projections use bias
        max_len : int
            Maximum sequence length supported by positional encoding
        pos_enc : Optional[str]
            Positional-encoding scheme passed to the transformer ("rope", ...)
        qk_norm : bool
            Enable query/key normalization in attention
        use_sdpa : bool
            Use PyTorch scaled_dot_product_attention when available
        interp : str
            Feature-interpolation mode; "nearest" or "linear"
        share_emb : bool
            Tie the token embedding and output head weights
        """
        super().__init__()

        assert interp in ["nearest", "linear"]

        # Project conditioning features to model width; fuse with token
        # embeddings via a 2*n_channels -> n_channels projection
        self.adapter = torch.nn.Linear(n_feats, n_channels, bias=bias)
        self.in_proj = torch.nn.Linear(2 * n_channels, n_channels, bias=bias)

        self.backbone = Transformer(
            n_channels=n_channels,
            n_heads=n_heads,
            n_layers=n_layers,
            mult=mult,
            p_dropout=p_dropout,
            bias=False,
            max_len=max_len,
            pos_enc_self_attn=pos_enc,
            qk_norm=qk_norm,
            use_sdpa=use_sdpa,
        )

        # One flat embedding/projection across all codebooks; per-codebook
        # entries are addressed via offsets of `codebook_size`
        self.tokens_emb = torch.nn.Embedding(codebook_size * n_codebooks, n_channels)
        self.head = torch.nn.Linear(n_channels, codebook_size * n_codebooks, bias=False)  # No bias on head, to allow weight-sharing
        if share_emb:
            self.tokens_emb.weight = self.head.weight

        # Masked token embedding
        self.tokens_mask_emb = torch.nn.Parameter(torch.zeros(n_channels))

        # Attributes
        self.p_token_dropout = p_token_dropout
        self.codebook_size = codebook_size
        self.n_codebooks = n_codebooks
        self.n_feats = n_feats
        self.n_channels = n_channels
        self.n_layers = n_layers
        self.interp = interp

    def forward(
        self,
        tokens: torch.Tensor,
        feats: torch.Tensor,
        codebook: torch.Tensor,
        tokens_mask: torch.Tensor,
        feats_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict token logits for all codebooks given partially-masked tokens.

        Parameters
        ----------
        tokens : torch.Tensor
            Acoustic tokens, fully or partially masked; shape
            (n_batch, n_codebooks, n_frames)
        feats : torch.Tensor
            Aligned features to guide generation; shape (n_batch, n_feats, n_frames')
        codebook : torch.Tensor
            Codebook in which to predict masked tokens; shape (n_batch,)
        tokens_mask : torch.Tensor
            Boolean tensor indicating unmasked token positions (True where
            unmasked, False where masked); shape (n_batch, n_codebooks, n_frames)
        feats_mask : torch.Tensor
            Per-frame indicator of valid/conditioning feature frames; shape
            (n_batch, n_frames'). Frames where it is 0/False have their
            features zeroed out below.

        Returns
        -------
        torch.Tensor
            Logits, shape (n_batch, n_codebooks, n_frames, codebook_size)
        """

        assert tokens.ndim == 3  # (n_batch, n_codebooks, n_frames)
        assert feats.ndim == 3  # (n_batch, n_feats, n_frames')
        assert tokens_mask.ndim == 3  # (n_batch, n_codebooks, n_frames)
        assert feats_mask.ndim == 2  # (n_batch, n_frames')
        assert tokens.shape[1] == self.n_codebooks

        n_batch, n_codebooks, n_frames = tokens.shape

        # Interpolate features and mask to tokens resolution
        feats = torch.nn.functional.interpolate(feats, n_frames, mode=self.interp)
        feats_mask = torch.nn.functional.interpolate(
            feats_mask[:, None, :].float(), n_frames, mode="nearest")

        # Adapt features
        feats = self.adapter(feats.transpose(1, 2))  # (n_batch, n_frames, n_channels)

        # Embed tokens: shift each codebook's ids into its own slice of the
        # shared embedding table
        codebook_offsets = torch.arange(
            n_codebooks, dtype=tokens.dtype, device=tokens.device
        ).reshape(1, -1, 1) * self.codebook_size  # (1, n_codebooks, 1)
        tokens = tokens + codebook_offsets  # (n_batch, n_codebooks, n_frames)
        tokens_emb = self.tokens_emb(tokens)  # (n_batch, n_codebooks, n_frames, n_channels)

        # Zero masked token embeddings
        tokens_emb = tokens_emb * tokens_mask.unsqueeze(-1).float()

        # Apply learned embedding to masked token positions in current codebook
        mask_pos = torch.arange(
            n_codebooks, dtype=tokens.dtype, device=tokens.device
        )[None, :] == codebook[:, None]  # (n_batch, n_codebooks)
        mask_pos = torch.logical_and(mask_pos.unsqueeze(-1), ~tokens_mask)  # (n_batch, n_codebooks, n_frames)

        tokens_emb = tokens_emb + (
            mask_pos.unsqueeze(-1).float()
        ) * self.tokens_mask_emb.reshape(1, 1, 1, -1)  # (n_batch, n_codebooks, n_frames, n_channels)

        # Token dropout (encourage attention to unmasked frames)
        if self.training and self.p_token_dropout > 0.0:

            # Apply dropout within masked frames and "below" current codebook
            below = torch.arange(
                n_codebooks, device=tokens.device
            )[None, :, None] < codebook[:, None, None]  # (n_batch, n_codebooks, 1)
            eligible = below & feats_mask.bool()  # (n_batch, n_codebooks, n_frames)
            # Drop whole frames (same random draw across codebooks/channels)
            drop = (
                torch.rand(
                    n_batch, 1, n_frames, 1, device=tokens.device
                ) < self.p_token_dropout) & eligible[..., None]
            tokens_emb = tokens_emb.masked_fill(drop, 0.0)

        # Zero "ignored" features
        feats = feats * feats_mask.transpose(1, 2)

        # Sum embedded tokens across codebooks
        tokens_emb = tokens_emb.sum(dim=1)  # (n_batch, n_frames, n_channels)

        # Concatenate adapted features with summed token embeddings, then
        # project back to model width
        x = torch.cat([feats, tokens_emb], dim=-1)  # (n_batch, n_frames, 2 * n_channels)
        x = self.in_proj(x)  # (n_batch, n_frames, n_channels)

        # Process with transformer
        x = self.backbone(x=x)  # (n_batch, n_frames, n_channels)

        # Predict token logits
        logits = self.head(x)  # (n_batch, n_frames, n_codebooks * codebook_size)
        logits = logits.reshape(
            n_batch, n_frames, n_codebooks, self.codebook_size
        ).permute(0, 2, 1, 3)  # (n_batch, n_codebooks, n_frames, codebook_size)

        return logits

    @torch.inference_mode()
    def inference(
        self,
        tokens: torch.Tensor,
        feats: torch.Tensor,
        tokens_mask: torch.Tensor,
        feats_mask: torch.Tensor,
        top_p: Union[float, Iterable[float]] = 1.0,
        top_k: Union[int, Iterable[int]] = None,
        temp: Union[float, Iterable[float]] = 1.0,
        mask_temp: Union[float, Iterable[float]] = 10.5,
        iterations: Union[int, Iterable[int]] = 8,
        guidance_scale: Union[float, Iterable[float]] = None,
        causal_bias: Union[float, Iterable[float]] = None,
        seed: Union[int, Iterable[int]] = None,
    ):
        """
        Iteratively fill in masked token positions, one codebook at a time.

        Each sampling parameter may be a scalar (shared by all codebooks), a
        length-1 iterable, or a length-`n_codebooks` iterable (per-codebook).

        Parameters
        ----------
        tokens : torch.Tensor
            Initial (partially masked) tokens; shape (n_batch, n_codebooks, n_frames)
        feats : torch.Tensor
            Conditioning features; shape (n_batch, n_feats, n_frames')
        tokens_mask : torch.Tensor
            True where tokens are given/unmasked; shape (n_batch, n_codebooks, n_frames)
        feats_mask : torch.Tensor
            Indicator of frames to generate/condition on; shape (n_batch, n_frames')
        top_p, top_k, temp
            Nucleus/top-k truncation and softmax temperature for sampling
        mask_temp
            Gumbel-noise temperature for confidence-based unmasking
            (annealed to 0 over iterations)
        iterations
            Number of unmasking iterations per codebook
        guidance_scale
            Classifier-free guidance scale; falsy disables CFG
        causal_bias
            Non-negative left-to-right bias for unmasking order
        seed : Union[int, Iterable[int]]
            Per-batch random seeds; scalar or length-n_batch

        Returns
        -------
        torch.Tensor
            Completed tokens; shape (n_batch, n_codebooks, n_frames)
        """

        assert not self.training
        device = next(iter(self.parameters())).device

        # Avoid overwriting caller's tensors
        tokens = tokens.clone().to(device)
        tokens_mask = tokens_mask.clone().to(device)

        assert tokens.ndim == 3
        n_batch, n_codebooks, n_frames = tokens.shape

        assert feats.ndim == 3
        _, n_feats, _ = feats.shape

        assert n_codebooks == self.n_codebooks
        assert n_feats == self.n_feats

        # Interpolate features to token resolution
        feats = torch.nn.functional.interpolate(
            feats.to(device), n_frames, mode=self.interp,
        )
        feats_mask = torch.nn.functional.interpolate(
            feats_mask.unsqueeze(1).float().to(device), n_frames, mode="nearest",
        ).squeeze(1).to(feats_mask.dtype)

        # Normalize a scalar / length-1 / length-n_codebooks argument into a
        # per-codebook list
        def _to_codebooks(v):
            if isinstance(v, torch.Tensor):
                v = v.tolist()
            elif isinstance(v, Iterable):
                pass
            else:
                v = [v]

            if len(v) == n_codebooks:
                return v
            elif len(v) == 1:
                return v * n_codebooks
            else:
                raise ValueError(
                    f"Sampling parameters must be scalars, "
                    f"length-1 iterable, or length-n_codebooks ({n_codebooks})"
                )

        # Construct `n_codebooks` state lists of length `n_batch` each
        # (10007 is an arbitrary prime used to decorrelate codebooks)
        seed = seed or 0
        if not isinstance(seed, Iterable):
            seed = [seed]
        assert len(seed) in [1, n_batch]
        seed = seed * (n_batch // len(seed))
        state = [format_seed([s + 10007 * cb for s in seed]) for cb in range(n_codebooks)]

        top_p, top_k = _to_codebooks(top_p), _to_codebooks(top_k)
        temp, mask_temp = _to_codebooks(temp), _to_codebooks(mask_temp)
        iterations = _to_codebooks(iterations)
        guidance_scale = _to_codebooks(guidance_scale)
        causal_bias = _to_codebooks(causal_bias)

        # Track initial masked token counts
        n_masked_init = (~tokens_mask).long().sum(dim=-1)  # (n_batch, n_codebooks)

        # Generate one codebook at a time
        for codebook_idx, (
            _state, _top_p, _top_k, _temp, _mask_temp,
            _iterations, _guidance_scale, _causal_bias,
        ) in enumerate(zip(
            state, top_p, top_k, temp, mask_temp,
            iterations, guidance_scale, causal_bias,
        )):
            _causal_bias = _causal_bias or 0.
            assert 0. <= _causal_bias

            _temp = _temp or 1.0
            assert 0. < _temp

            _mask_temp = _mask_temp or 0.0
            assert 0. <= _mask_temp

            _iterations = max(_iterations or 1, 1)

            for _iter in range(_iterations):

                # CFG on features by masking: unconditional branch sees an
                # all-zero feature mask
                if _guidance_scale:
                    tokens_cfg = torch.cat([tokens, tokens], dim=0)
                    tokens_mask_cfg = torch.cat([tokens_mask, tokens_mask], dim=0)

                    feats_cfg = torch.cat([feats, feats], dim=0)
                    feats_mask_cfg = torch.cat([feats_mask, torch.zeros_like(feats_mask)], dim=0)

                    logits_cond, logits_uncond = self.forward(
                        tokens_cfg,
                        feats_cfg,
                        torch.full(
                            (tokens_cfg.shape[0],),
                            codebook_idx,
                            dtype=torch.long,
                            device=device,
                        ),
                        tokens_mask_cfg,
                        feats_mask_cfg,
                    ).chunk(2, dim=0)  # (n_batch, n_codebooks, n_frames, codebook_size) x2

                    logits = logits_uncond + _guidance_scale * (logits_cond - logits_uncond)  # (n_batch, n_codebooks, n_frames, codebook_size)

                else:
                    logits = self.forward(
                        tokens,
                        feats,
                        torch.full(
                            (tokens.shape[0],),
                            codebook_idx,
                            dtype=torch.long,
                            device=device,
                        ),
                        tokens_mask,
                        feats_mask,
                    )  # (n_batch, n_codebooks, n_frames, codebook_size)

                # Truncate logits and sample tokens at masked positions
                logits = top_p_top_k(
                    logits[:, codebook_idx:codebook_idx+1, ...], _top_p, _top_k
                )  # (n_batch, 1, n_frames, codebook_size)
                sampled, probs = sample(
                    logits, _temp, argmax=(_iter==_iterations-1),
                )  # (n_batch, 1, n_frames) x2
                write_idx = ~(tokens_mask[:, codebook_idx, :])  # (n_batch, n_frames)
                tokens[:, codebook_idx, :][write_idx] = sampled[:, 0, :][write_idx]

                # Compute implied generation timestep and corresponding target mask
                # ratio
                t = (_iter + 1) / _iterations
                tgt_p_mask = cosine_schedule(torch.tensor([t]*n_batch, device=device))  # (n_batch,)

                # Compute target and actual number of masked positions in current
                # codebook
                tgt_n_masked = torch.floor(tgt_p_mask * n_masked_init[:, codebook_idx]).long()  # (n_batch,)
                n_masked = write_idx.long().sum(dim=-1)  # (n_batch,)

                # Do not complete unmasking until final iteration, i.e. always leave at
                # least one token unmasked
                if _iter < _iterations - 1:
                    tgt_n_masked = torch.minimum(n_masked - 1, tgt_n_masked).clamp_min(1)

                # Select which tokens to unmask via confidence (assigned probability),
                # mediated by causal bias and random noise. Already-unmasked
                # positions get infinite confidence so they are never re-masked.
                _probs = torch.full_like(probs[:, 0, :], torch.inf)  # (n_batch, n_frames)
                _probs[write_idx] = probs[:, 0, :][write_idx]
                tokens_mask[:, codebook_idx, :] = mask_by_confidence(
                    probs=_probs,
                    n=tgt_n_masked,
                    temp=_mask_temp * (1 - t),  # Mask temperature annealing
                    causal_bias=_causal_bias or 0.0,
                    state=_state,
                    eligible=write_idx,
                )

            # Re-apply span and codebook masks: positions outside the feature
            # span stay unmasked, and all completed codebooks are marked unmasked
            tokens_mask = ~torch.logical_and(~tokens_mask, feats_mask.unsqueeze(1))
            tokens_mask[:, :codebook_idx, :] = True

        return tokens
|
tria/nn/__init__.py
ADDED
|
File without changes
|
tria/nn/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (146 Bytes). View file
|
|
|
tria/nn/__pycache__/attention.cpython-310.pyc
ADDED
|
Binary file (6.67 kB). View file
|
|
|
tria/nn/__pycache__/norm.cpython-310.pyc
ADDED
|
Binary file (2.25 kB). View file
|
|
|
tria/nn/__pycache__/pos_enc.cpython-310.pyc
ADDED
|
Binary file (2.87 kB). View file
|
|
|
tria/nn/__pycache__/transformer.cpython-310.pyc
ADDED
|
Binary file (6.7 kB). View file
|
|
|
tria/nn/attention.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from .norm import QKNorm
|
| 10 |
+
from .pos_enc import apply_rope
|
| 11 |
+
from .pos_enc import apply_sinusoidal
|
| 12 |
+
from .pos_enc import build_rope_cache
|
| 13 |
+
from .pos_enc import build_sinusoidal_cache
|
| 14 |
+
|
| 15 |
+
################################################################################
|
| 16 |
+
# Multihead attention operation
|
| 17 |
+
################################################################################
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def ensure_masks(
    n_batch: int,
    seq_len_q: int,
    seq_len_k: int,
    device,
    mask_q: Optional[torch.Tensor],
    mask_k: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Materialize default (all-valid) boolean sequence masks where none were
    provided; masks that were supplied are returned unchanged.

    Parameters
    ----------
    n_batch : int
    seq_len_q : int
    seq_len_k : int
    mask_q : torch.Tensor
        Shape (n_batch, seq_len_q)
    mask_k : torch.Tensor
        Shape (n_batch, seq_len_k)
    """
    def _default(mask, seq_len):
        # All positions valid when no mask is supplied
        if mask is None:
            mask = torch.ones(n_batch, seq_len, dtype=torch.bool, device=device)
        return mask

    return _default(mask_q, seq_len_q), _default(mask_k, seq_len_k)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def make_attn_mask(
    mask_q: torch.Tensor,
    mask_k: torch.Tensor,
    dtype,
) -> torch.Tensor:
    """
    Build an additive attention mask from a key padding mask.

    Only the key mask contributes entries; using the "key padding mask"
    convention avoids fully-masked rows in the attention score matrix (and
    thus softmax issues).

    Parameters
    ----------
    mask_q : torch.Tensor
        Query sequence mask, shape (n_batch, seq_len_q)
    mask_k : torch.Tensor
        Key sequence mask, shape (n_batch, seq_len_k)

    Returns
    -------
    torch.Tensor
        Additive attention mask (0 where attended, -inf where excluded) for
        scaled_dot_product_attention, shape (n_batch, 1, seq_len_q, seq_len_k)
    """
    n_batch, seq_len_q = mask_q.shape
    seq_len_k = mask_k.shape[1]

    # Broadcast the invalid-key indicator over every query row
    invalid = (~mask_k)[:, None, None, :].expand(
        n_batch, 1, seq_len_q, seq_len_k
    )  # (n_batch, 1, seq_len_q, seq_len_k)

    # 0 for valid keys, -inf for excluded keys
    additive = torch.zeros_like(invalid, dtype=dtype)
    additive = additive.masked_fill(invalid, float("-inf"))

    return additive  # (n_batch, 1, seq_len_q, seq_len_k)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def sdpa_with_fallback(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: Optional[torch.Tensor],
    p_dropout: float,
    training: bool,
    use_sdpa: bool = True,
) -> torch.Tensor:
    """
    Scaled dot-product attention, using PyTorch's fused SDPA kernel on CUDA
    (which can pick efficient implementations, e.g. flash attention) and a
    plain einsum-based implementation otherwise.

    Parameters
    ----------
    q : torch.Tensor
        Query, shape (n_batch, n_heads, seq_len_q, head_channels)
    k : torch.Tensor
        Key, shape (n_batch, n_heads, seq_len_k, head_channels)
    v : torch.Tensor
        Value, shape (n_batch, n_heads, seq_len_k, head_channels)
    attn_mask : torch.Tensor
        Additive attention mask (0 or -inf), shape (n_batch, 1, seq_len_q, seq_len_k)
    p_dropout : float
        Attention dropout probability (applied only when `training`)
    training : bool
        Whether dropout is active
    use_sdpa : bool
        Allow the fused SDPA path

    Returns
    -------
    torch.Tensor
        Shape (n_batch, n_heads, seq_len_q, head_channels)
    """

    n_batch, n_heads, seq_len_q, head_channels = q.shape
    seq_len_k = k.shape[2]

    if use_sdpa and q.is_cuda:
        # Drop a vacuous mask (all-True boolean, or additive with no nonzero
        # entries) so SDPA is free to choose faster kernels
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                vacuous = bool(attn_mask.all())
            else:
                vacuous = not bool(attn_mask.ne(0).any())
            if vacuous:
                attn_mask = None

        return F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=p_dropout if training else 0.0,
            is_causal=False,
        )

    # Manual fallback (CPU, or SDPA disabled)
    scale = 1.0 / math.sqrt(head_channels)
    scores = torch.einsum("bhtd,bhsd->bhts", q, k) * scale
    if attn_mask is not None:
        # Additive mask: -inf entries zero out after softmax
        scores = scores + attn_mask
    weights = scores.softmax(dim=-1)
    if training and p_dropout > 0.0:
        weights = F.dropout(weights, p=p_dropout)
    return torch.einsum("bhts,bhsd->bhtd", weights, v)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class MultiheadAttention(nn.Module):
    """
    Multi-head attention with optional rotary ("rope") or absolute sinusoidal
    positional encoding, optional QK normalization, and a fused-SDPA fast path
    with a manual fallback.
    """

    def __init__(
        self,
        n_channels: int,
        n_heads: int,
        p_dropout: float = 0.0,
        bias: bool = True,
        max_len: int = 8192,
        pos_enc: Optional[str] = "rope",
        qk_norm: bool = True,
        use_sdpa: bool = True,
    ):
        """
        Parameters
        ----------
        n_channels : int
            Model width; must be divisible by `n_heads`
        n_heads : int
            Number of attention heads
        p_dropout : float
            Dropout applied to attention weights and output projection
        bias : bool
            Whether Q/K/V/output projections use bias
        max_len : int
            Maximum sequence length for the positional-encoding cache
        pos_enc : Optional[str]
            One of "rope", "absolute", "none", or None
        qk_norm : bool
            Enable per-head RMS normalization of queries and keys
        use_sdpa : bool
            Allow the fused scaled_dot_product_attention path
        """
        super().__init__()
        assert n_channels % n_heads == 0, "`n_channels` must be divisible by `n_heads`"
        assert pos_enc in ("rope", "absolute", "none", None)

        self.n_channels = n_channels
        self.n_heads = n_heads
        self.head_channels = n_channels // n_heads
        self.p_dropout = p_dropout
        self.pos_enc = pos_enc
        self.max_len = max_len
        self.use_sdpa = use_sdpa

        self.q_proj = nn.Linear(n_channels, n_channels, bias=bias)
        self.k_proj = nn.Linear(n_channels, n_channels, bias=bias)
        self.v_proj = nn.Linear(n_channels, n_channels, bias=bias)
        self.o_proj = nn.Linear(n_channels, n_channels, bias=bias)

        self.o_dropout = nn.Dropout(p_dropout)

        self.qk_norm = QKNorm(self.head_channels) if qk_norm else None
        # Positional-encoding cache, built lazily on first forward
        self.pos_cache = None

    def _maybe_build_pos_cache(self, device, dtype):
        # Build the positional-encoding cache once; no-op if positional
        # encoding is disabled or the cache already exists.
        # NOTE(review): cache is built in float32 and for the device of the
        # first forward call; it is not rebuilt if the module later moves
        # device — confirm this is intended.
        if self.pos_enc in [None, "none"] or self.pos_cache is not None:
            return
        if self.pos_enc == "absolute":
            self.pos_cache = build_sinusoidal_cache(
                self.max_len, self.head_channels, device, dtype=torch.float32
            )
        elif self.pos_enc == "rope":
            cos, sin = build_rope_cache(
                self.max_len, self.head_channels, device, dtype=torch.float32
            )
            self.pos_cache = (cos, sin)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask_q: Optional[torch.Tensor] = None,
        mask_k: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute multi-head attention.

        Parameters
        ----------
        q : torch.Tensor
            Query, shape (n_batch, seq_len_q, n_channels)
        k : torch.Tensor
            Key, shape (n_batch, seq_len_k, n_channels)
        v : torch.Tensor
            Value, shape (n_batch, seq_len_k, n_channels)
        mask_q : torch.Tensor
            Boolean mask, `True` for valid positions; shape (n_batch, seq_len_q)
        mask_k : torch.Tensor
            Boolean mask, `True` for valid positions; shape (n_batch, seq_len_k)
        attn_mask : torch.tensor
            Additive (0, -inf) mask; shape (n_batch, 1, seq_len_q, seq_len_k)

        Returns
        -------
        torch.Tensor
            Shape (n_batch, seq_len_q, n_channels); rows at invalid query
            positions are zeroed
        """

        n_batch, seq_len_q, _ = q.shape
        seq_len_k = k.shape[1]
        device, dtype = q.device, q.dtype

        # Projections (n_batch, seq_len, n_channels) -> (n_batch, n_heads, seq_len, head_channels)
        q = (
            self.q_proj(q)
            .view(n_batch, seq_len_q, self.n_heads, self.head_channels)
            .transpose(1, 2)
        )
        k = (
            self.k_proj(k)
            .view(n_batch, seq_len_k, self.n_heads, self.head_channels)
            .transpose(1, 2)
        )
        v = (
            self.v_proj(v)
            .view(n_batch, seq_len_k, self.n_heads, self.head_channels)
            .transpose(1, 2)
        )

        # Positional encoding (applied to Q and K only)
        self._maybe_build_pos_cache(device=device, dtype=dtype)
        if self.pos_enc == "absolute":
            cache = self.pos_cache  # (max_seq_len, head_channels)
            q = apply_sinusoidal(q, cache)
            k = apply_sinusoidal(k, cache)
        elif self.pos_enc == "rope":
            cos, sin = self.pos_cache  # (max_seq_len, head_channels/2)
            q = apply_rope(q, cos, sin)
            k = apply_rope(k, cos, sin)

        # QK-Norm (after positional encoding)
        if self.qk_norm is not None:
            q, k = self.qk_norm(q, k)

        # Masks: materialize defaults, then convert to an additive mask
        mask_q, mask_k = ensure_masks(
            n_batch, seq_len_q, seq_len_k, device, mask_q, mask_k
        )
        pad_mask = make_attn_mask(
            mask_q, mask_k, dtype
        )  # (n_batch, 1, seq_len_q, seq_len_k)

        # Combine caller-provided additive mask with the padding mask
        if attn_mask is not None:
            pad_mask = pad_mask + attn_mask

        # Attention
        y = sdpa_with_fallback(
            q,
            k,
            v,
            attn_mask=pad_mask,
            p_dropout=self.p_dropout,
            training=self.training,
            use_sdpa=self.use_sdpa,
        )  # (n_batch, n_heads, seq_len_q, head_channels)

        # Merge heads and project
        y = y.transpose(1, 2).contiguous().view(n_batch, seq_len_q, self.n_channels)
        y = self.o_proj(y)  # (n_batch, seq_len_q, n_channels)
        y = self.o_dropout(y)

        # Mask outputs: zero rows at invalid query positions
        if mask_q is not None:
            with torch.no_grad():
                y.masked_fill_(~mask_q[:, :, None], 0.0)
        return y
|
tria/nn/norm.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
################################################################################
|
| 7 |
+
# Normalization layers
|
| 8 |
+
################################################################################
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RMSNorm(nn.Module):
    """
    Root-mean-square layer normalization over the final dimension, with a
    learnable per-channel gain (initialized to 1).
    """

    def __init__(self, n_channels: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(n_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Scale `x` by the reciprocal RMS of its final dimension, then apply
        the learnable gain.
        """
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = (mean_sq + self.eps).rsqrt()  # eps guards against division by zero
        return self.weight * x * inv_rms  # Broadcast targets final dimension
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class QKNorm(nn.Module):
    """
    RMS-normalize query and key across the channel dimension with a learnable
    gain. Applied per-head, per-position.
    """

    def __init__(self, head_channels: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.g_q = nn.Parameter(torch.ones(head_channels))
        self.g_k = nn.Parameter(torch.ones(head_channels))

    def _normalize(self, x: torch.Tensor, gain: torch.Tensor) -> torch.Tensor:
        # RMS over the channel (last) dimension, then per-channel gain
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * gain  # Broadcast targets final dimension

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parameters
        ----------
        q : torch.Tensor
            Query, shape (n_batch, n_heads, seq_len_q, head_channels)
        k : torch.Tensor
            Key, shape (n_batch, n_heads, seq_len_k, head_channels)
        """
        return self._normalize(q, self.g_q), self._normalize(k, self.g_k)
|
tria/nn/pos_enc.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
################################################################################
|
| 4 |
+
# Utilities for positional encoding
|
| 5 |
+
################################################################################
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def build_sinusoidal_cache(seq_len: int, n_channels: int, device, dtype):
    """
    Precompute a table of sinusoidal positional encodings.

    Returns
    -------
    torch.Tensor
        Cache of shape (seq_len, n_channels); the first half of the channel
        axis holds sines, the second half cosines.
    """
    assert n_channels % 2 == 0
    half = n_channels // 2
    positions = torch.arange(seq_len, device=device, dtype=dtype)[:, None]  # (seq_len, 1)
    freq_idx = torch.arange(half, device=device, dtype=dtype)[None, :]  # (1, half)
    # Geometric frequency ladder from 1 down toward 1/10000
    inv_freq = 1.0 / (10000 ** (freq_idx / half))
    angles = positions * inv_freq  # (seq_len, half)
    return torch.cat((angles.sin(), angles.cos()), dim=1)  # (seq_len, n_channels)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def apply_sinusoidal(x: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    """
    Add a slice of a sinusoidal positional-encoding cache to the input.

    Parameters
    ----------
    x : torch.Tensor
        Shape (n_batch, n_heads, seq_len, head_channels) or (n_batch, seq_len, n_channels)
    cache: torch.Tensor
        Shape (seq_len, n_channels)

    Returns
    -------
    torch.Tensor
        Same shape as `x`, with positional encodings added.
    """
    table = cache.to(x.dtype)
    if x.ndim == 4:
        _, _, seq_len, head_channels = x.shape
        return x + table[None, None, :seq_len, :head_channels]
    if x.ndim == 3:
        _, seq_len, n_channels = x.shape
        return x + table[None, :seq_len, :n_channels]
    raise ValueError(
        f"Invalid input shape {tuple(x.shape)}; "
        f"expected (n_batch, [n_heads], seq_len, n_channels)"
    )
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def build_rope_cache(
    seq_len: int, n_channels: int, device, dtype, base: float = 10000.0
):
    """
    Precompute rotary positional embedding (RoPE) cosine/sine tables.

    Returns
    -------
    torch.Tensor, torch.Tensor
        Cosine and sine caches, each of shape (seq_len, n_channels/2)
    """
    assert n_channels % 2 == 0
    # Per-pair inverse frequencies: base^(-2i/n_channels)
    exponents = torch.arange(0, n_channels, 2, device=device, dtype=dtype) / n_channels
    theta = 1.0 / (base ** exponents)  # (n_channels/2,)
    positions = torch.arange(seq_len, device=device, dtype=dtype)
    freqs = positions[:, None] * theta[None, :]  # (seq_len, n_channels/2)
    return freqs.cos(), freqs.sin()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def apply_rope(
    q_or_k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """
    Rotate consecutive channel pairs of a query or key tensor by
    position-dependent angles (rotary positional embedding).

    Parameters
    ----------
    q_or_k : torch.Tensor
        Shape (n_batch, n_heads, seq_len, head_channels) where head_channels even
    cos : torch.Tensor
        Shape (seq_len, head_channels/2)
    sin : torch.Tensor
        Shape (seq_len, head_channels/2)

    Returns
    -------
    torch.Tensor
        Shape (n_batch, n_heads, seq_len, head_channels)
    """
    n_batch, n_heads, seq_len, head_channels = q_or_k.shape
    # View channels as (head_channels/2) interleaved (even, odd) pairs
    pairs = q_or_k.reshape(n_batch, n_heads, seq_len, head_channels // 2, 2)
    even, odd = pairs.unbind(dim=-1)  # each (n_batch, n_heads, seq_len, head_channels/2)
    c = cos[:seq_len].to(q_or_k.dtype)[None, None]
    s = sin[:seq_len].to(q_or_k.dtype)[None, None]
    # 2-D rotation of each pair: (a, b) -> (a*cos - b*sin, a*sin + b*cos)
    rotated_even = even * c - odd * s
    rotated_odd = even * s + odd * c
    out = torch.stack((rotated_even, rotated_odd), dim=-1)
    return out.reshape(n_batch, n_heads, seq_len, head_channels)
|
tria/nn/transformer.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from .attention import MultiheadAttention
|
| 8 |
+
from .norm import RMSNorm
|
| 9 |
+
|
| 10 |
+
################################################################################
|
| 11 |
+
# Transformer
|
| 12 |
+
################################################################################
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def lengths_to_mask(
    lengths: torch.Tensor, max_len: Optional[int] = None
) -> torch.Tensor:
    """
    Convert per-sequence lengths into a boolean validity mask.

    Parameters
    ----------
    lengths : torch.Tensor
        Shape (n_batch,)
    max_len : int
        Width of the mask; defaults to the largest length in the batch.

    Returns
    -------
    torch.Tensor
        Boolean mask of shape (n_batch, max_len); True marks valid positions.
    """
    width = int(lengths.amax()) if max_len is None else max_len
    positions = torch.arange(width, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)  # (n_batch, width)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class MLP(nn.Module):
    """
    Position-wise feed-forward network: linear expansion by `mult`, GELU,
    linear projection back to `n_channels`, then dropout.
    """

    def __init__(
        self, n_channels: int, mult: int = 4, p_dropout: float = 0.1, bias: bool = True
    ):
        super().__init__()

        # BUG FIX: `bias` was accepted but never forwarded to the linear
        # layers, so constructing with bias=False silently had no effect.
        self.mlp = nn.Sequential(
            nn.Linear(n_channels, n_channels * mult, bias=bias),
            nn.GELU(),
            nn.Linear(n_channels * mult, n_channels, bias=bias),
            nn.Dropout(p_dropout),
        )

    def forward(self, x: torch.Tensor):
        """
        Parameters
        ----------
        x : torch.Tensor
            Shape (n_batch, seq_len, n_channels)

        Returns
        -------
        torch.Tensor
            Shape (n_batch, seq_len, n_channels)
        """
        assert x.ndim == 3  # (n_batch, seq_len, n_channels)
        return self.mlp(x)  # (n_batch, seq_len, n_channels)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class TransformerBlock(nn.Module):
    """
    Pre-norm transformer block: self-attention, optional cross-attention,
    and a position-wise MLP, each wrapped in a residual connection. Invalid
    (padded) output positions are zeroed when a mask is supplied.
    """

    def __init__(
        self,
        n_channels: int,
        n_heads: int,
        mult: int = 4,
        p_dropout: float = 0.0,
        bias: bool = True,
        max_len: int = 8192,
        pos_enc_self_attn: Optional[str] = "rope",
        pos_enc_cross_attn: Optional[str] = "absolute",
        qk_norm: bool = True,
        use_sdpa: bool = True,
        cross_attn: bool = False,
        norm: str = "layer",
    ):
        super().__init__()

        assert norm in ["layer", "rms", "none", None]
        # "none"/None fall through to Identity (no normalization)
        norm_cls = {"rms": RMSNorm, "layer": nn.LayerNorm}.get(norm, nn.Identity)

        self.norm_1 = norm_cls(n_channels)
        self.self_attn = MultiheadAttention(
            n_channels=n_channels,
            n_heads=n_heads,
            p_dropout=p_dropout,
            bias=bias,
            max_len=max_len,
            pos_enc=pos_enc_self_attn,
            qk_norm=qk_norm,
            use_sdpa=use_sdpa,
        )

        self.cross_attn = cross_attn
        if cross_attn:
            self.norm_x = norm_cls(n_channels)
            self.norm_c = norm_cls(n_channels)
            self.cross = MultiheadAttention(
                n_channels=n_channels,
                n_heads=n_heads,
                p_dropout=p_dropout,
                bias=bias,
                max_len=max_len,
                pos_enc=pos_enc_cross_attn,
                qk_norm=qk_norm,
                use_sdpa=use_sdpa,
            )

        self.norm_2 = norm_cls(n_channels)
        self.mlp = MLP(n_channels=n_channels, mult=mult, p_dropout=p_dropout, bias=bias)

    def forward(
        self,
        x: torch.Tensor,
        c: Optional[torch.Tensor] = None,
        mask_x: Optional[torch.Tensor] = None,
        mask_c: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            Input sequence, shape (n_batch, seq_len_x, n_channels)
        c : torch.Tensor
            Conditioning sequence, shape (n_batch, seq_len_c, n_channels)
        mask_x : torch.Tensor
            Boolean mask indicating valid positions in input sequence, shape
            (n_batch, seq_len_x)
        mask_c : torch.Tensor
            Boolean mask indicating valid positions in conditioning sequence,
            shape (n_batch, seq_len_c)
        """

        if self.cross_attn:
            assert c is not None

        # Self-attention sub-block (pre-norm, residual)
        h = self.norm_1(x)
        x = x + self.self_attn(h, h, h, mask_q=mask_x, mask_k=mask_x)

        # Cross-attention sub-block; keys and values share the normalized
        # conditioning sequence
        if self.cross_attn and c is not None:
            queries = self.norm_x(x)
            keys_values = self.norm_c(c)
            x = x + self.cross(
                queries, keys_values, keys_values, mask_q=mask_x, mask_k=mask_c
            )

        # Feed-forward sub-block
        x = x + self.mlp(self.norm_2(x))

        # Zero out padded positions (in-place, outside autograd)
        if mask_x is not None:
            with torch.no_grad():
                x.masked_fill_(~mask_x[:, :, None], 0.0)

        return x
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class Transformer(nn.Module):
    """
    Stack of `n_layers` pre-norm transformer blocks with optional
    cross-attention conditioning and length/mask handling.
    """

    def __init__(
        self,
        n_channels: int,
        n_heads: int,
        n_layers: int,
        mult: int,
        p_dropout: float = 0.0,
        bias: bool = True,
        max_len: int = 8192,
        pos_enc_self_attn: Optional[str] = "rope",
        pos_enc_cross_attn: Optional[str] = "absolute",
        qk_norm: bool = True,
        use_sdpa: bool = True,
        cross_attn: bool = False,
    ):
        super().__init__()
        blocks = []
        for _ in range(n_layers):
            blocks.append(
                TransformerBlock(
                    n_channels=n_channels,
                    n_heads=n_heads,
                    mult=mult,
                    p_dropout=p_dropout,
                    bias=bias,
                    max_len=max_len,
                    pos_enc_self_attn=pos_enc_self_attn,
                    pos_enc_cross_attn=pos_enc_cross_attn,
                    qk_norm=qk_norm,
                    use_sdpa=use_sdpa,
                    cross_attn=cross_attn,
                )
            )
        self.layers = nn.ModuleList(blocks)
        self.n_channels = n_channels
        self.max_len = max_len
        self.pos_enc_self_attn = pos_enc_self_attn
        self.pos_enc_cross_attn = pos_enc_cross_attn

    @torch.no_grad()
    def _masks_from_lengths(
        self,
        mask_x: Optional[torch.Tensor],
        mask_c: Optional[torch.Tensor],
        lengths_x: Optional[torch.Tensor],
        lengths_c: Optional[torch.Tensor],
        seq_len_x: int,
        seq_len_c: Optional[int],
        device,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Derive boolean masks from lengths when explicit masks are absent
        if mask_x is None and lengths_x is not None:
            mask_x = lengths_to_mask(lengths_x.to(device), seq_len_x)
        if mask_c is None and lengths_c is not None:
            assert seq_len_c is not None
            mask_c = lengths_to_mask(lengths_c.to(device), seq_len_c)
        # Normalize mask dtype to bool
        mask_x = mask_x.bool() if mask_x is not None else None
        mask_c = mask_c.bool() if mask_c is not None else None
        return mask_x, mask_c

    def forward(
        self,
        x: torch.Tensor,
        c: Optional[torch.Tensor] = None,
        mask_x: Optional[torch.Tensor] = None,
        mask_c: Optional[torch.Tensor] = None,
        lengths_x: Optional[torch.Tensor] = None,
        lengths_c: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            Input sequence, shape (n_batch, seq_len_x, n_channels)
        c : torch.Tensor
            Conditioning sequence, shape (n_batch, seq_len_c, n_channels)
        mask_x : torch.Tensor
            Boolean mask indicating valid positions in input sequence, shape
            (n_batch, seq_len_x)
        mask_c : torch.Tensor
            Boolean mask indicating valid positions in conditioning sequence,
            shape (n_batch, seq_len_c)
        lengths_x : torch.Tensor
            Valid lengths of input sequences, shape (n_batch,)
        lengths_c : torch.Tensor
            Valid lengths of conditioning sequences, shape (n_batch,)
        """

        assert x.ndim == 3
        n_batch, seq_len_x, n_channels = x.shape
        assert n_channels == self.n_channels
        seq_len_c = None if c is None else c.shape[1]

        # Create valid masks from lengths if necessary
        mask_x, mask_c = self._masks_from_lengths(
            mask_x, mask_c, lengths_x, lengths_c, seq_len_x, seq_len_c, x.device
        )

        for block in self.layers:
            x = block(x=x, c=c, mask_x=mask_x, mask_c=mask_c)

        return x
|
tria/pipelines/__init__.py
ADDED
|
File without changes
|
tria/pipelines/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (153 Bytes). View file
|
|
|
tria/pipelines/tokenizer/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .tokenizer import Tokenizer
|
| 2 |
+
from .tokenizer import TokenSequence
|
tria/pipelines/tokenizer/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (242 Bytes). View file
|
|
|
tria/pipelines/tokenizer/__pycache__/tokenizer.cpython-310.pyc
ADDED
|
Binary file (4.87 kB). View file
|
|
|
tria/pipelines/tokenizer/dac/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023-present, Descript
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
tria/pipelines/tokenizer/dac/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .dac import DAC
|
tria/pipelines/tokenizer/dac/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
tria/pipelines/tokenizer/dac/__pycache__/dac.cpython-310.pyc
ADDED
|
Binary file (5.77 kB). View file
|
|
|
tria/pipelines/tokenizer/dac/__pycache__/modules.cpython-310.pyc
ADDED
|
Binary file (4.04 kB). View file
|
|
|
tria/pipelines/tokenizer/dac/dac.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
from .modules import Decoder
|
| 10 |
+
from .modules import Encoder
|
| 11 |
+
from .modules import init_weights
|
| 12 |
+
from .nn.quantize import ResidualVectorQuantize
|
| 13 |
+
|
| 14 |
+
################################################################################
|
| 15 |
+
# Descript Audio Codec (DAC)
|
| 16 |
+
################################################################################
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DAC(torch.nn.Module):
    """
    Descript Audio Codec (DAC) proposed by Kumar et al. in "High-Fidelity Audio
    Compression with Improved RVQGAN" (2023). Code adapted from:
    https://github.com/descriptinc/descript-audio-codec

    Pipeline: convolutional Encoder -> ResidualVectorQuantize -> Decoder.
    """

    def __init__(
        self,
        sample_rate: int = 44_100,
        encoder_dim: int = 64,
        encoder_rates: List[int] = (2, 4, 8, 8),
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = (8, 8, 4, 2),
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate

        # Default latent width: channel count doubles at each encoder stage
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))
        self.latent_dim = latent_dim

        # Samples per latent frame = product of all encoder stride rates
        self.hop_length = np.prod(encoder_rates)

        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)

        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer = ResidualVectorQuantize(
            input_dim=latent_dim,
            n_codebooks=n_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
        )
        self.apply(init_weights)

        # Must run before the padding toggle so it reflects the built layers
        self.delay = self.get_delay()

        # As long as we don't run chunked/segmented encoding and decoding,
        # we can keep padding on
        self.padding = True

    @property
    def padding(self):
        # Lazy default: treat padding as enabled until explicitly set
        if not hasattr(self, "_padding"):
            self._padding = True
        return self._padding

    @padding.setter
    def padding(self, value: bool):
        # Toggle conv padding on/off across the whole model. Disabling stashes
        # each layer's configured padding in `original_padding` and zeroes it;
        # re-enabling restores the stashed value.
        assert isinstance(value, bool)

        layers = [
            l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
        ]

        for layer in layers:
            if value:
                if hasattr(layer, "original_padding"):
                    layer.padding = layer.original_padding
            else:
                layer.original_padding = layer.padding
                layer.padding = tuple(0 for _ in range(len(layer.padding)))

        self._padding = value

    def get_delay(self):
        """
        Return the model's processing delay in samples, computed by walking
        the conv stack backwards from the output length to the required
        input length and halving the difference.
        """
        # Any number works here, delay is invariant to input length
        l_out = self.get_output_length(0)
        L = l_out

        layers = []
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                layers.append(layer)

        # Invert each layer's length formula, last layer first
        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            # Inverse of the forward formulas in get_output_length: the
            # ConvTranspose1d/Conv1d roles are swapped on the backward pass
            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        # Half the total receptive-field overhead (symmetric delay)
        return (l_in - l_out) // 2

    def get_output_length(self, input_length: int):
        """
        Return the output length produced by the (unpadded) conv stack for a
        given input length, applying each layer's length formula in order.
        """
        L = input_length
        # Calculate output length
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                # Standard conv / transposed-conv length formulas (no padding)
                if isinstance(layer, nn.Conv1d):
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L

    def encode(
        self,
        audio_data: torch.Tensor,
    ):
        """
        Encode given audio data and return quantized latent codes.

        Parameters
        ----------
        audio_data : torch.Tensor
            Audio data to encode, shape (batch_size, 1, n_samples)

        Returns
        -------
        codes:
            Codebook indices across all quantizer levels, shape
            (n_batch, n_quantizers, n_frames)
        z_O: torch.Tensor
            Quantized output obtained by summing projected quantized residuals
            (z_o) over all quantizer levels, shape (n_batch, latent_dim, n_frames)
        z_i: torch.Tensor
            Continuous representation of inputs projected into codebook space,
            shape (n_batch, n_quantizers, codebook_dim, n_frames). Note that
            each quantizer level represents a predicted residual.
        z_q: torch.Tensor
            Quantized representation of input in codebook space, shape
            (n_batch, n_quantizers, codebook_dim, n_frames). Note that each
            quantizer level represents a quantized predicted residual.
        z_o: torch.Tensor
            Continuous representation of quantized input, projected back into
            latent space, shape (n_batch, n_quantizers, latent_dim, n_frames).
            Note that each quantizer level represents a projected quantized
            predicted residual.
        """
        # Predict continuous latents
        z = self.encoder(audio_data)  # (n_batch, latent_dim, n_frames)
        # NOTE(review): the raw encoder latent z is appended as the final
        # element of the returned tuple; the ordering of the preceding
        # elements follows ResidualVectorQuantize's return — verify there.
        return *self.quantizer(z, n_quantizers=None), z

    def decode(
        self,
        codes: torch.Tensor,
    ):
        """
        Decode given quantized latent codes and return audio data

        Parameters
        ----------
        codes : torch.Tensor
            Quantized latent codes, shape (n_batch, n_quantizers, n_frames)

        Returns
        -------
        torch.Tensor
            Decoded audio data, shape (n_batch, 1, n_samples)
        """
        # Look up and sum codebook vectors back into the latent space
        z_O = self.quantizer.from_codes(codes)  # (n_batch, latent_dim, n_frames)
        recons = self.decoder(z_O)  # (n_batch, 1, n_samples)
        return recons
|