saumyap29 committed on
Commit c9f87fa · 1 Parent(s): ecf18ad

initial commit

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .DS_Store +0 -0
  2. .gradio/certificate.pem +31 -0
  3. README.md +1 -2
  4. app.py +192 -0
  5. pretrained/.gitignore +0 -0
  6. pretrained/tokenizer/dac/dac_44.1kHz_7.7kbps.pt +3 -0
  7. pretrained/tria/small_musdb_moises_2b/80000/extras.pt +3 -0
  8. pretrained/tria/small_musdb_moises_2b/80000/model.pt +3 -0
  9. pretrained/tria/small_musdb_moises_2b/best/extras.pt +3 -0
  10. pretrained/tria/small_musdb_moises_2b/best/model.pt +3 -0
  11. requirements.txt +11 -0
  12. tria/__init__.py +6 -0
  13. tria/__pycache__/__init__.cpython-310.pyc +0 -0
  14. tria/__pycache__/constants.cpython-310.pyc +0 -0
  15. tria/__pycache__/features.cpython-310.pyc +0 -0
  16. tria/__pycache__/util.cpython-310.pyc +0 -0
  17. tria/constants.py +11 -0
  18. tria/data/__init__.py +0 -0
  19. tria/data/dataset.py +280 -0
  20. tria/data/preprocess.py +124 -0
  21. tria/features.py +187 -0
  22. tria/model/__init__.py +1 -0
  23. tria/model/__pycache__/__init__.cpython-310.pyc +0 -0
  24. tria/model/__pycache__/mask.cpython-310.pyc +0 -0
  25. tria/model/__pycache__/sample.cpython-310.pyc +0 -0
  26. tria/model/__pycache__/tria.cpython-310.pyc +0 -0
  27. tria/model/mask.py +263 -0
  28. tria/model/sample.py +168 -0
  29. tria/model/tria.py +344 -0
  30. tria/nn/__init__.py +0 -0
  31. tria/nn/__pycache__/__init__.cpython-310.pyc +0 -0
  32. tria/nn/__pycache__/attention.cpython-310.pyc +0 -0
  33. tria/nn/__pycache__/norm.cpython-310.pyc +0 -0
  34. tria/nn/__pycache__/pos_enc.cpython-310.pyc +0 -0
  35. tria/nn/__pycache__/transformer.cpython-310.pyc +0 -0
  36. tria/nn/attention.py +280 -0
  37. tria/nn/norm.py +53 -0
  38. tria/nn/pos_enc.py +101 -0
  39. tria/nn/transformer.py +259 -0
  40. tria/pipelines/__init__.py +0 -0
  41. tria/pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
  42. tria/pipelines/tokenizer/__init__.py +2 -0
  43. tria/pipelines/tokenizer/__pycache__/__init__.cpython-310.pyc +0 -0
  44. tria/pipelines/tokenizer/__pycache__/tokenizer.cpython-310.pyc +0 -0
  45. tria/pipelines/tokenizer/dac/LICENSE +21 -0
  46. tria/pipelines/tokenizer/dac/__init__.py +1 -0
  47. tria/pipelines/tokenizer/dac/__pycache__/__init__.cpython-310.pyc +0 -0
  48. tria/pipelines/tokenizer/dac/__pycache__/dac.cpython-310.pyc +0 -0
  49. tria/pipelines/tokenizer/dac/__pycache__/modules.cpython-310.pyc +0 -0
  50. tria/pipelines/tokenizer/dac/dac.py +203 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
README.md CHANGED
@@ -8,6 +8,5 @@ sdk_version: 5.49.1
 app_file: app.py
 pinned: false
 license: mit
+short_description: Audio Prompted Drums Generation
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

app.py ADDED
@@ -0,0 +1,192 @@
+import spaces
+import gradio as gr
+import torch
+from pathlib import Path
+from audiotools import AudioSignal
+from tria.model.tria import TRIA
+from tria.pipelines.tokenizer import Tokenizer
+from tria.features import rhythm_features
+from functools import partial
+from pyharp.core import ModelCard, build_endpoint
+from pyharp.media.audio import load_audio, save_audio
+from pyharp.labels import LabelList
+
+# Global Config
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+N_OUTPUTS = 3
+
+# Model Zoo
+MODEL_ZOO = {
+    "small_musdb_moises_2b": {
+        "checkpoint": "pretrained/tria/small_musdb_moises_2b/80000/model.pt",
+        "model_cfg": {
+            "codebook_size": 1024,
+            "n_codebooks": 9,
+            "n_channels": 512,
+            "n_feats": 2,
+            "n_heads": 8,
+            "n_layers": 12,
+            "mult": 4,
+            "p_dropout": 0.0,
+            "bias": True,
+            "max_len": 1000,
+            "pos_enc": "rope",
+            "qk_norm": True,
+            "use_sdpa": True,
+            "interp": "nearest",
+            "share_emb": True,
+        },
+        "tokenizer_cfg": {"name": "dac"},
+        "feature_cfg": {
+            "sample_rate": 16_000,
+            "n_bands": 2,
+            "n_mels": 40,
+            "window_length": 384,
+            "hop_length": 192,
+            "quantization_levels": 5,
+            "slow_ma_ms": 200,
+            "post_smooth_ms": 100,
+            "legacy_normalize": False,
+            "clamp_max": 50.0,
+            "normalize_quantile": 0.98,
+        },
+        "infer_cfg": {
+            "top_p": 0.95,
+            "top_k": None,
+            "temp": 1.0,
+            "mask_temp": 10.5,
+            "iterations": [8, 8, 8, 8, 4, 4, 4, 4, 4],
+            "guidance_scale": 2.0,
+            "causal_bias": 1.0,
+        },
+        "max_duration": 6.0,
+    },
+}
+
+# Loaded model cache
+LOADED = dict(name=None, model=None, tokenizer=None, feature_fn=None, infer_cfg=None, sample_rate=None, max_duration=None)
+
+# Model loading
+def load_model_by_name(name: str):
+    """Load a TRIA model by name (cached)."""
+    if LOADED["name"] == name and LOADED["model"] is not None:
+        return LOADED["model"]
+
+    cfg = MODEL_ZOO[name]
+    model = TRIA(**cfg["model_cfg"])
+    sd = torch.load(cfg["checkpoint"], map_location="cpu")
+    model.load_state_dict(sd, strict=True)
+    model.to(DEVICE).eval()
+
+    tokenizer = Tokenizer(**cfg["tokenizer_cfg"]).to(DEVICE)
+    feat_fn = partial(rhythm_features, **cfg.get("feature_cfg", {}))
+
+    LOADED.update(
+        dict(
+            name=name,
+            model=model,
+            tokenizer=tokenizer,
+            feature_fn=feat_fn,
+            infer_cfg=cfg["infer_cfg"],
+            sample_rate=tokenizer.sample_rate,
+            max_duration=cfg["max_duration"],
+        )
+    )
+    return model
+
+
+# Inference logic
+@spaces.GPU
+@torch.inference_mode()
+def generate_audio(model_name, timbre_path, rhythm_path, cfg_scale, top_p, mask_temperature, seed):
+    model = load_model_by_name(model_name)
+    tokenizer = LOADED["tokenizer"]
+    feat_fn = LOADED["feature_fn"]
+    sample_rate = LOADED["sample_rate"]
+    infer_cfg = LOADED["infer_cfg"]
+
+    timbre_sig = load_audio(timbre_path).resample(sample_rate)
+    rhythm_sig = load_audio(rhythm_path).resample(sample_rate)
+    timbre_sig.ensure_max_of_audio()
+    rhythm_sig.ensure_max_of_audio()
+
+    prefix_dur = int(LOADED["max_duration"] / 3)
+    timbre_tokens = tokenizer.encode(timbre_sig)
+    rhythm_tokens = tokenizer.encode(rhythm_sig)
+    tokens = torch.cat([timbre_tokens.tokens, rhythm_tokens.tokens], dim=-1)
+    n_batch, n_codebooks, n_frames = tokens.shape
+    prefix_frames = timbre_tokens.tokens.shape[-1]
+
+    feats = feat_fn(rhythm_sig)
+    feats = torch.nn.functional.interpolate(feats, n_frames - prefix_frames, mode=model.interp)
+    full_feats = torch.zeros(n_batch, feats.shape[1], n_frames, device=DEVICE)
+    full_feats[..., prefix_frames:] = feats
+
+    prefix_mask = torch.arange(n_frames, device=DEVICE)[None, :].repeat(n_batch, 1) < prefix_frames
+    buffer_mask = prefix_mask[:, None, :].repeat(1, n_codebooks, 1)
+    feats_mask = ~prefix_mask
+
+    outputs = []
+    for i in range(N_OUTPUTS):
+        torch.manual_seed(seed + i)
+        gen = model.inference(
+            tokens.clone().to(DEVICE),
+            full_feats.to(DEVICE),
+            buffer_mask.clone().to(DEVICE),
+            feats_mask.to(DEVICE),
+            top_p=float(top_p),
+            mask_temp=float(mask_temperature),
+            iterations=infer_cfg["iterations"],
+            guidance_scale=float(cfg_scale),
+        )[..., prefix_frames:]
+
+        rhythm_tokens.tokens = gen
+        out_sig = tokenizer.decode(rhythm_tokens)
+        out_sig.ensure_max_of_audio()
+        output_path = f"tria_out_{i+1}.wav"
+        save_audio(out_sig, output_path)
+        path_i = output_path
+        outputs.append(str(path_i))
+    return tuple(outputs)
+
+
+# PyHARP Metadata
+model_card = ModelCard(
+    name="TRIA: The Rhythm In Anything",
+    description=(
+        "Transform your rhythmic ideas into full drum performances. TRIA takes two short audio prompts: \n "
+        "a Rhythm Prompt (tapping, beatboxing, or a percussion gesture) "
+        "and a Timbre Prompt (an example drum sound or kit recording). \n "
+        "It generates 3 drum arrangements that match your groove and chosen timbre."
+    ),
+    author="Patrick O'Reilly, Julia Barnett, Hugo Flores García, Annie Chu, Nathan Pruyne, Prem Seetharaman, Bryan Pardo",
+    tags=["tria", "rhythm-generation", "pyharp"],
+)
+
+
+# Gradio and PyHARP Endpoint
+with gr.Blocks(title="TRIA") as demo:
+    timbre_in = gr.Audio(type="filepath", label="Timbre Prompt").harp_required(True)
+    rhythm_in = gr.Audio(type="filepath", label="Rhythm Prompt").harp_required(True)
+
+    model_names = list(MODEL_ZOO.keys())
+    model_dropdown = gr.Dropdown(choices=model_names, value=model_names[0], label="Model")
+
+    with gr.Row():
+        cfg_scale = gr.Slider(0.0, 10.0, value=2.0, step=0.1, label="CFG Scale")
+        top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top P")
+        mask_temperature = gr.Slider(0.0, 20.0, value=10.5, step=0.1, label="Mask Temperature")
+        seed = gr.Slider(0, 1000, value=0, step=1, label="Random Seed")
+
+    out1 = gr.Audio(type="filepath", label="Generated #1")
+    out2 = gr.Audio(type="filepath", label="Generated #2")
+    out3 = gr.Audio(type="filepath", label="Generated #3")
+
+    app = build_endpoint(
+        model_card=model_card,
+        input_components=[model_dropdown, timbre_in, rhythm_in, cfg_scale, top_p, mask_temperature, seed],
+        output_components=[out1, out2, out3],
+        process_fn=generate_audio,
+    )
+
+demo.queue().launch(share=True, show_error=True)
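
Editorial note: importing app.py runs the Gradio launch at the bottom of the module, so for a quick local check it is easier to repeat the loading steps directly. A minimal sketch, using only the MODEL_ZOO configuration and checkpoint path shown above:

import torch
from tria.model.tria import TRIA

# Instantiate with the "small_musdb_moises_2b" model_cfg from MODEL_ZOO
model = TRIA(
    codebook_size=1024, n_codebooks=9, n_channels=512, n_feats=2,
    n_heads=8, n_layers=12, mult=4, p_dropout=0.0, bias=True,
    max_len=1000, pos_enc="rope", qk_norm=True, use_sdpa=True,
    interp="nearest", share_emb=True,
)
sd = torch.load("pretrained/tria/small_musdb_moises_2b/80000/model.pt", map_location="cpu")
model.load_state_dict(sd, strict=True)
model.eval()  # TRIA.inference() asserts the model is not in training mode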
pretrained/.gitignore ADDED
File without changes
pretrained/tokenizer/dac/dac_44.1kHz_7.7kbps.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ffa16e9cd52d67dadef026823403481930942f3fead32f44b75c4b60627246a
+size 306721572
pretrained/tria/small_musdb_moises_2b/80000/extras.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e18d9b8dbf5c5ff0d86aaf04d2af014960d97eeb396f7743e7595692ee31b68
+size 344556763
pretrained/tria/small_musdb_moises_2b/80000/model.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e20c3850253ba7fb267440573137f4b6099cad1e437fcfd574b84d60138155c
+size 172260091
pretrained/tria/small_musdb_moises_2b/best/extras.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e18d9b8dbf5c5ff0d86aaf04d2af014960d97eeb396f7743e7595692ee31b68
+size 344556763
pretrained/tria/small_musdb_moises_2b/best/model.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e20c3850253ba7fb267440573137f4b6099cad1e437fcfd574b84d60138155c
+size 172260091
requirements.txt ADDED
@@ -0,0 +1,11 @@
+torch==2.9.0
+torchaudio==2.9.0
+numpy
+argbind
+descript-audiotools>=0.9.2
+pyharp>=1.7.8
+gradio>=4.42.0
+librosa
+soundfile
+tqdm
+
tria/__init__.py ADDED
@@ -0,0 +1,6 @@
+__version__ = "0.0.1"
+
+from . import constants
+from . import util
+from . import features
+from . import transforms
tria/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (294 Bytes).
 
tria/__pycache__/constants.cpython-310.pyc ADDED
Binary file (465 Bytes).
 
tria/__pycache__/features.cpython-310.pyc ADDED
Binary file (4.52 kB).
 
tria/__pycache__/util.cpython-310.pyc ADDED
Binary file (6.11 kB).
 
tria/constants.py ADDED
@@ -0,0 +1,11 @@
+from pathlib import Path
+
+MANIFESTS_DIR = Path(__file__).parent.parent / "manifests"
+DATA_DIR = Path(__file__).parent.parent / "data"
+PRETRAINED_DIR = Path(__file__).parent.parent / "pretrained"
+ASSETS_DIR = Path(__file__).parent.parent / "assets"
+
+
+STEMS = ["drums", "bass", "vocals", "other", "mixture"]
+SAMPLE_RATE = 44_100
+DURATION = 6.0
tria/data/__init__.py ADDED
File without changes
tria/data/dataset.py ADDED
@@ -0,0 +1,280 @@
+import csv
+from pathlib import Path
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import numpy as np
+import soundfile as sf
+from audiotools import AudioSignal
+from audiotools.core.util import random_state
+from torch.utils.data import Dataset
+
+from ..constants import DURATION
+from ..constants import SAMPLE_RATE
+from ..constants import STEMS
+from ..util import collate
+from ..util import get_info
+from ..util import load_audio
+from ..util import rms_salience
+
+################################################################################
+# Dataset for loading aligned excerpts across stem classes
+################################################################################
+
+
+class StemDataset(Dataset):
+    """
+    Load aligned excerpts from specified stem classes given paths in one or more
+    CSV manifests. Based on `audiotools.data.datasets.AudioDataset`.
+
+    Parameters
+    ----------
+    sources : Union[str, Path, List[Union[str, Path]]]
+        CSV manifest(s) with columns for each requested stem.
+    stems : List[str]
+        Column names to load, e.g. ["mixture", "drums", "bass", "vocals"].
+        The **first** stem is used for salience unless `salience_on` is set.
+    sample_rate : int
+    duration : float
+    n_examples : int
+    num_channels : int
+    relative_path : str
+        Prepended to relative CSV paths.
+    strict : bool
+        Drop rows with missing stems (True) vs. fill with silence (False).
+    with_replacement : bool
+        Sampling strategy for rows.
+    shuffle_state : int
+        Seed for deterministic per-index RNG.
+    loudness_cutoff : Optional[float]
+        dB LUFS cutoff; if None, take random excerpt (still shared across stems).
+    salience_num_tries : int
+        Max tries for salient excerpt search (see `AudioSignal.salient_excerpt`).
+    salience_on : Optional[str]
+        Which stem to use for salience. Defaults to first of `stems`.
+    """
+
+    def __init__(
+        self,
+        stems: List[str] = STEMS,
+        sample_rate: int = SAMPLE_RATE,
+        duration: float = DURATION,
+        sources: Union[str, Path, List[Union[str, Path]]] = None,
+        source_weights: Optional[List[float]] = None,
+        n_examples: int = 1000,
+        num_channels: int = 1,
+        relative_path: str = "",
+        strict: bool = True,
+        with_replacement: bool = True,
+        shuffle_state: int = 0,
+        loudness_cutoff: Optional[float] = -40.0,
+        salience_num_tries: int = 8,
+        salience_on: Optional[str] = None,
+    ):
+        super().__init__()
+
+        assert sources is not None
+        assert len(stems) >= 1
+
+        self.stems = list(stems)
+        self.sample_rate = int(sample_rate)
+        self.duration = float(duration)
+        self.num_channels = int(num_channels)
+        self.relative_path = Path(relative_path)
+        self.strict = strict
+        self.with_replacement = with_replacement
+        self.length = int(n_examples)
+        self.shuffle_state = int(shuffle_state)
+
+        self.loudness_cutoff = loudness_cutoff
+        self.salience_num_tries = int(salience_num_tries)
+        self.salience_on = salience_on or self.stems[0]
+        if self.salience_on not in self.stems:
+            raise ValueError(
+                f"`salience_on` ('{self.salience_on}') must be one of {self.stems}"
+            )
+
+        # Read manifests
+        csv_paths = [sources] if isinstance(sources, (str, Path)) else list(sources)
+        self.source_rows: List[List[Dict]] = []
+        kept_mask: List[bool] = []
+        kept_csvs: List[Path] = []
+
+        for cpath in csv_paths:
+            # Read rows for source
+            cpath = Path(cpath)
+            raw_rows = []
+            with open(cpath, "r") as f:
+                reader = csv.DictReader(f)
+                for row in reader:
+                    entry = {"__manifest__": str(cpath)}
+                    stem_paths = {}
+                    for s in self.stems:
+                        raw = (row.get(s) or "").strip()
+                        stem_paths[s] = str(self._resolve_path(raw)) if raw else ""
+                    entry["paths"] = stem_paths
+                    extra = {k: v for k, v in row.items() if k not in self.stems}
+                    if extra:
+                        entry["meta"] = extra
+                    raw_rows.append(entry)
+
+            # Filter rows for source
+            filtered = []
+            for r in raw_rows:
+                missing = [
+                    s for s, p in r["paths"].items() if not p or not Path(p).is_file()
+                ]
+                if self.strict and missing:
+                    continue
+
+                min_dur = np.inf
+                any_valid = False
+                for s, p in r["paths"].items():
+                    if p and Path(p).is_file():
+                        any_valid = True
+                        try:
+                            total_sec = float(sf.info(p).duration)
+                            min_dur = min(min_dur, float(total_sec))
+                        except Exception:
+                            if self.strict:
+                                min_dur = -np.inf
+                                break
+                if not any_valid or not np.isfinite(min_dur):
+                    continue
+                if min_dur < self.duration and self.strict:
+                    continue
+
+                r["min_duration"] = min_dur if np.isfinite(min_dur) else 0.0
+                filtered.append(r)
+
+            if len(filtered) > 0:
+                self.source_rows.append(filtered)
+                kept_mask.append(True)
+                kept_csvs.append(cpath)
+            else:
+                kept_mask.append(False)
+
+        if len(self.source_rows) == 0:
+            raise RuntimeError(
+                "StemDataset: no valid rows after filtering in any source."
+            )
+
+        self.csv_paths = kept_csvs
+
+        lengths = [len(lst) for lst in self.source_rows]
+        self._source_offsets = np.cumsum([0] + lengths[:-1])  # for global idx
+        self._n_rows = int(sum(lengths))
+
+        # Weights over non-empty sources
+        if source_weights is None:
+            self._weights = None
+        else:
+            if len(source_weights) != len(csv_paths):
+                raise ValueError(
+                    f"source_weights must match number of sources ({len(csv_paths)}), "
+                    f"got {len(source_weights)}"
+                )
+            w = np.asarray(source_weights, dtype=float)
+            # Keep only weights for sources that survived filtering
+            w = w[np.array(kept_mask, dtype=bool)]
+            w = np.clip(w, 0, None)
+            if not np.any(w > 0):
+                w = np.ones_like(w)
+            self._weights = (w / w.sum()).tolist()
+
+    def _resolve_path(self, p: Union[str, Path]) -> Path:
+        p = Path(p).expanduser()
+        if not p.is_absolute():
+            p = (self.relative_path / p).expanduser()
+        return p
+
+    def _pick_row(self, state: np.random.RandomState):
+        # Sample a non-empty source
+        sidx = int(state.choice(len(self.source_rows), p=self._weights))
+        n_in_source = len(self.source_rows[sidx])
+        item_idx = int(state.randint(n_in_source))
+        row = self.source_rows[sidx][item_idx]
+
+        # Map to a global idx for metadata
+        ridx_global = int(self._source_offsets[sidx] + item_idx)
+        return ridx_global, row
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx: int):
+        state = random_state((self.shuffle_state + int(idx)) & 0x7FFFFFFF)
+        ridx, row = self._pick_row(state)
+
+        primary = self.salience_on
+        p0 = row["paths"].get(primary, "")
+
+        offset = 0.0
+        primary_sig = None
+        if p0 and Path(p0).is_file():
+            if self.loudness_cutoff is None or not self.salience_num_tries:
+                try:
+                    total_sec, _sr = get_info(p0)
+                except Exception:
+                    total_sec = 0.0
+                max_off = max(0.0, total_sec - self.duration)
+                offset = float(state.rand() * max_off) if max_off > 0 else 0.0
+            else:
+                offset = rms_salience(
+                    p0,
+                    duration=self.duration,
+                    cutoff_db=float(self.loudness_cutoff),
+                    num_tries=int(self.salience_num_tries),
+                    state=state,
+                )
+            primary_sig = load_audio(p0, offset=offset, duration=self.duration)
+        else:
+            offset = 0.0
+
+        item: Dict[str, Dict] = {}
+        for s in self.stems:
+            p = row["paths"][s]
+            exists = bool(p) and Path(p).is_file()
+
+            if s == primary and primary_sig is not None:
+                sig = primary_sig.clone()  # reuse window we already loaded
+            elif exists:
+                sig = load_audio(
+                    p, offset=offset, duration=self.duration
+                )  # windowed load
+            else:
+                sig = AudioSignal.zeros(
+                    self.duration, self.sample_rate, self.num_channels
+                )
+
+            # Channel formatting
+            if self.num_channels == 1:
+                sig = sig.to_mono()
+            elif self.num_channels != sig.num_channels:
+                assert sig.num_channels == 1
+                sig.audio_data = sig.audio_data.repeat(1, self.num_channels, 1)
+
+            # Resample/pad to target SR and exact duration
+            sig = sig.resample(self.sample_rate)
+            if sig.duration < self.duration:
+                sig = sig.zero_pad_to(int(self.duration * self.sample_rate))
+
+            # Metadata
+            sig.metadata["path"] = p
+            sig.metadata["offset"] = offset
+            sig.metadata["source_row"] = ridx
+            if "meta" in row:
+                for k, v in row["meta"].items():
+                    sig.metadata[k] = v
+
+            item[s] = {"signal": sig, "path": p}
+
+        item["idx"] = idx
+        return item
+
+    @staticmethod
+    def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
+        return collate(list_of_dicts, n_splits=n_splits)
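
Editorial note: a minimal sketch of driving StemDataset from a manifest. The CSV path and chosen stem columns are illustrative assumptions, and batching relies on the collate helper the class re-exports:

from torch.utils.data import DataLoader
from tria.data.dataset import StemDataset

dataset = StemDataset(
    stems=["mixture", "drums"],       # salience defaults to the first stem
    sources=["manifests/train.csv"],  # hypothetical manifest with these columns
    duration=6.0,
    n_examples=1000,
    loudness_cutoff=-40.0,
)
loader = DataLoader(dataset, batch_size=4, collate_fn=StemDataset.collate)
item = dataset[0]
print(item["drums"]["signal"])  # 6-second AudioSignal excerpt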
tria/data/preprocess.py ADDED
@@ -0,0 +1,124 @@
+import csv
+import os
+from pathlib import Path
+from typing import Callable, Dict, Tuple, Union, Optional, Any
+from rich.progress import track
+
+import numpy as np
+
+from audiotools.core.util import random_state
+from ..util import ensure_dir
+
+SplitType = Union[Tuple[float, float, float], Callable[[Path], str]]
+
+
+def create_manifests(
+    data_dir: Union[str, Path],
+    ext: str,
+    output_dir: Union[str, Path],
+    split: SplitType,
+    attributes: Dict[str, Callable[[Path], Any]],
+    seed: Optional[int] = 0,
+) -> Dict[str, Path]:
+    """
+    Create CSV manifests for an audio dataset.
+
+    Parameters
+    ----------
+    data_dir : str
+        Dataset root directory to search recursively for files
+    ext : str
+        Audio file extension
+    output_dir : str
+        Directory to which to write manifests
+    split : SplitType
+        Either a 3-tuple containing (train, val, test) proportions summing to 1
+        or a Callable that returns "train", "val", or "test" given a filepath
+    attributes : dict
+        Dictionary mapping column names to Callables for extracting values
+        given filepaths; for example {'path': lambda p: str(p)}
+    seed : int
+        Random seed
+    """
+    data_dir = Path(data_dir)
+    output_dir = Path(output_dir)
+    ensure_dir(output_dir)
+
+    all_files = sorted(
+        [p for p in data_dir.rglob(f"*{ext}") if p.is_file()],
+        key=lambda p: str(p).lower(),
+    )
+
+    splits = {"train": [], "val": [], "test": []}
+
+    # Callable split: apply given function to file paths to obtain train/val/test
+    # assignments
+    if callable(split):
+        for p in all_files:
+            s = split(p)
+            if s not in splits:
+                raise ValueError(
+                    f"Split function must return one of "
+                    f"{list(splits.keys())}, got {s!r} for {p}"
+                )
+            splits[s].append(p)
+
+    # Proportional split: randomly shuffle files and split according to given
+    # values
+    else:
+        if not (isinstance(split, tuple) and len(split) == 3):
+            raise ValueError("Split proportions tuple must have length 3")
+        p_train, p_val, p_test = split
+        total = float(p_train + p_val + p_test)
+        if not np.isclose(total, 1.0, atol=1e-6):
+            raise ValueError(f"Split proportions must sum to 1.0 (got {total}).")
+
+        rs = random_state(seed)
+        idx = np.array(rs.permutation(len(all_files)))
+        n = len(idx)
+        n_train = int(np.floor(p_train * n))
+        n_val = int(np.floor(p_val * n))
+        n_test = n - n_train - n_val
+
+        train_idx = idx[:n_train]
+        val_idx = idx[n_train:n_train + n_val]
+        test_idx = idx[n_train + n_val:]
+
+        for i in train_idx:
+            splits["train"].append(all_files[int(i)])
+        for i in val_idx:
+            splits["val"].append(all_files[int(i)])
+        for i in test_idx:
+            splits["test"].append(all_files[int(i)])
+
+    columns = list(attributes.keys())
+
+    # Write CSVs
+    out_paths: Dict[str, Path] = {}
+    for s in ("train", "val", "test"):
+        out_csv = output_dir / f"{s}.csv"
+        out_paths[s] = out_csv
+
+        with out_csv.open("w", newline="") as f:
+            writer = csv.DictWriter(f, fieldnames=columns)
+            writer.writeheader()
+
+            for p in track(
+                splits[s],
+                description=f"Writing {s}.csv",
+                total=len(splits[s])
+            ):
+
+                try:
+                    row = {}
+                    for col, fn in attributes.items():
+                        row[col] = fn(p)
+                    writer.writerow(row)
+                except Exception as e:
+                    print(
+                        f"Error at path {p}:\n"
+                        f"{e}\n"
+                        f"Skipping..."
+                    )
+
+    return out_paths
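
Editorial note: a minimal sketch of create_manifests with a proportional split. The data root, extension, and single-column attributes dict are illustrative assumptions:

from tria.data.preprocess import create_manifests

paths = create_manifests(
    data_dir="data/drum_stems",              # hypothetical dataset root
    ext=".wav",
    output_dir="manifests/drum_stems",
    split=(0.8, 0.1, 0.1),                   # train/val/test, must sum to 1
    attributes={"drums": lambda p: str(p)},  # one manifest column per attribute
)
print(paths["train"])  # manifests/drum_stems/train.csv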
tria/features.py ADDED
@@ -0,0 +1,187 @@
+import torch
+from audiotools import AudioSignal
+
+
+################################################################################
+# Utilities for extracting rhythm feature representations
+################################################################################
+
+
+def _moving_average(x: torch.Tensor, window_length: int):
+    """
+    Smooth features with a moving average over frames.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Shape (n_batch, n_feats, n_frames)
+    window_length : int
+        Smoothing window length
+    """
+    if window_length <= 1:
+        return x
+    n_feats = x.shape[1]
+    kernel = torch.ones(
+        (n_feats, 1, window_length),
+        device=x.device, dtype=x.dtype
+    ) / window_length
+
+    pad_left = (window_length - 1) // 2
+    pad_right = window_length // 2
+    x_pad = torch.nn.functional.pad(x, (pad_left, pad_right), mode="reflect")
+
+    # Smooth separately over feature channels
+    return torch.nn.functional.conv1d(x_pad, kernel, groups=n_feats)
+
+
+# The 'original' TRIA features can be recovered using:
+# * `slow_ma_ms` = None
+# * `post_smooth_ms` = None
+# * `legacy_normalize` = True
+def rhythm_features(
+    signal: AudioSignal,
+    sample_rate: int = 44_100,
+    n_bands: int = 2,
+    n_mels: int = 80,
+    window_length: int = 1024,
+    hop_length: int = 512,
+    normalize_quantile: float = 0.98,
+    quantization_levels: int = 33,
+    clamp_max: float = 50.0,
+    eps: float = 1e-8,
+    slow_ma_ms: float = 100.0,
+    post_smooth_ms: float = 10.0,
+    legacy_normalize: bool = False,
+):
+    """
+    Extract multi-band 'rhythm' features from audio by adaptively splitting the
+    spectrogram along the frequency axis and applying normalization,
+    quantization, and smoothing / sparsity filtering.
+
+    Parameters
+    ----------
+    signal : AudioSignal
+        Audio from which to extract features
+    sample_rate : int
+        Sample rate at which to extract features
+    n_bands : int
+        Number of frequency bands into which to adaptively divide spectrogram
+    n_mels : int
+        Number of base mel frequency bins in spectrogram
+    window_length : int
+        Spectrogram window length
+    hop_length : int
+        Spectrogram hop length
+    normalize_quantile : float
+        Optionally normalize each band relative to top-p largest magnitude
+        rather than absolute max
+    quantization_levels : int
+        Number of bins into which feature magnitudes are quantized
+    clamp_max : float
+        Maximum allowed spectrogram magnitude
+    eps : float
+        For numerical stability
+    slow_ma_ms : float
+        Smoothing filter length in milliseconds for transient emphasis (smoothed
+        features are subtracted)
+    post_smooth_ms : float
+        Smoothing filter length in milliseconds for transient smoothing
+    legacy_normalize : bool
+        If `True`, use mean/std and sigmoid normalization as described in the
+        original TRIA paper
+    """
+
+    assert n_bands >= 1
+    assert quantization_levels >= 2
+
+    # Loudness normalization
+    signal = signal.clone().to_mono().resample(sample_rate).normalize(-16.)
+    signal.ensure_max_of_audio()
+
+    # Clamped mel spectrogram
+    mel = signal.mel_spectrogram(
+        n_mels=n_mels,
+        hop_length=hop_length,
+        window_length=window_length,
+    ).mean(1)  # (n_batch, n_mels, n_frames)
+    mel = torch.clamp(mel, 0.0, clamp_max)
+
+    n_batch, _, n_frames = mel.shape
+
+    if legacy_normalize:
+        # Original normalization: divide by number of mels
+        mel = mel / n_mels
+    else:
+        # Compress logarithmically
+        mel = torch.log1p(mel) / torch.log1p(torch.tensor(clamp_max, device=mel.device, dtype=mel.dtype))
+
+    # Split spectrogram into bands adaptively
+    energy_per_bin = mel.mean(dim=-1)  # (n_batch, n_mels)
+    cum = energy_per_bin.cumsum(dim=1)  # (n_batch, n_mels)
+    total = cum[:, -1:]  # (n_batch, 1)
+
+    if n_bands == 1:
+        bands = mel.sum(dim=1, keepdim=True)  # (n_batch, 1, n_frames)
+    else:
+        targets = torch.linspace(
+            1.0 / n_bands, (n_bands - 1) / n_bands, n_bands - 1,
+            device=mel.device, dtype=mel.dtype
+        )[None, :] * total  # (n_batch, n_bands-1)
+
+        edges = torch.searchsorted(cum, targets, right=False)  # (n_batch, n_bands-1)
+
+        cuts = torch.cat(
+            [
+                torch.zeros(n_batch, 1, dtype=torch.long, device=mel.device),
+                edges + 1,
+                torch.full((n_batch, 1), mel.size(1), dtype=torch.long, device=mel.device),
+            ],
+            dim=1
+        )  # (n_batch, n_bands+1)
+
+        prefix = mel.cumsum(dim=1)  # (n_batch, n_mels, n_frames)
+        prefix_pad = torch.cat(
+            [torch.zeros(n_batch, 1, n_frames, device=mel.device, dtype=mel.dtype), prefix],
+            dim=1
+        )
+
+        a_idx = cuts[:, :-1].unsqueeze(-1).expand(n_batch, n_bands, n_frames)
+        b_idx = cuts[:, 1:].unsqueeze(-1).expand(n_batch, n_bands, n_frames)
+        bands = prefix_pad.gather(1, b_idx) - prefix_pad.gather(1, a_idx)  # (n_batch, n_bands, n_frames)
+
+    # Emphasize transients by subtracting smoothed features
+    transient = bands.clone()
+    to_frames = lambda ms: max(1, int(round((ms / 1000.0) * sample_rate / hop_length)))
+
+    if slow_ma_ms is not None:
+        slow_win = to_frames(slow_ma_ms)
+        bands_slow = _moving_average(bands, slow_win)  # (n_batch, n_bands, n_frames)
+        transient = torch.relu(bands - bands_slow)
+
+    # Apply additional smoothing to transients
+    if post_smooth_ms is not None:
+        ps_win = to_frames(post_smooth_ms)
+        if ps_win > 1:
+            transient = _moving_average(transient, ps_win)
+
+    # Normalize features across time per band
+    if legacy_normalize:
+        # Original normalization (mean/std with sigmoid compression)
+        mean = transient.mean(dim=-1, keepdim=True)
+        std = transient.std(dim=-1, keepdim=True).clamp_min(eps)
+        transient = torch.sigmoid((transient - mean) / std)
+
+    else:
+        # Quantile-based normalization
+        q = torch.quantile(
+            transient.clamp_min(0.0),
+            q=normalize_quantile,
+            dim=-1,
+            keepdim=True
+        ).clamp_min(eps)
+        transient = (transient / q).clamp(0.0, 1.0)
+
+    # Quantize feature intensities into bins to ensure a tight information
+    # bottleneck
+    steps = quantization_levels - 1
+    return torch.round(transient * steps) / steps
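
Editorial note: a minimal sketch of feature extraction, mirroring the feature_cfg used in app.py above; the input path is an illustrative assumption:

from audiotools import AudioSignal
from tria.features import rhythm_features

sig = AudioSignal("examples/taps.wav")  # hypothetical rhythm prompt
feats = rhythm_features(
    sig,
    sample_rate=16_000,
    n_bands=2,
    n_mels=40,
    window_length=384,
    hop_length=192,
    quantization_levels=5,  # magnitudes snapped to 5 discrete levels
)
print(feats.shape)  # (n_batch, n_bands, n_frames)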
tria/model/__init__.py ADDED
@@ -0,0 +1 @@
+from .tria import TRIA
tria/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (182 Bytes).
 
tria/model/__pycache__/mask.cpython-310.pyc ADDED
Binary file (5.87 kB).
 
tria/model/__pycache__/sample.cpython-310.pyc ADDED
Binary file (4.68 kB).
 
tria/model/__pycache__/tria.cpython-310.pyc ADDED
Binary file (7.21 kB).
 
tria/model/mask.py ADDED
@@ -0,0 +1,263 @@
+from typing import Iterable
+from typing import Union
+
+import torch
+from audiotools.core.util import random_state
+
+################################################################################
+# Utilities for masked language modeling
+################################################################################
+
+
+def cosine_schedule(t: torch.Tensor) -> torch.Tensor:
+    """
+    Map timestep in [0, 1] to masking ratio in (0, 1] via the cosine schedule
+    proposed by Chang et al. in "MaskGIT: Masked Generative Image
+    Transformer" (2022).
+
+    Parameters
+    ----------
+    t : torch.Tensor
+        Timestep in [0, 1]
+
+    Returns
+    -------
+    torch.Tensor
+        Mask proportion in (0, 1]
+    """
+    return (t * torch.pi / 2).cos().clamp(1e-10, 1.0)
+
+
+def format_seed(seed):
+    if isinstance(seed, (int, float)):
+        seed = [seed]
+    elif isinstance(seed, torch.Tensor):
+        seed = seed.tolist()
+    elif isinstance(seed, Iterable):
+        pass
+    else:
+        raise ValueError(f"Invalid random seed of type {type(seed)}")
+
+    return [random_state(s) for s in seed]
+
+
+def get_span_mask(
+    tokens: torch.Tensor,
+    min_prop: float,
+    max_prop: float,
+    seed: Union[int, Iterable[int]],
+) -> torch.Tensor:
+    """
+    Mask a random span of consecutive frames across all codebooks, varying
+    across the batch.
+
+    Parameters
+    ----------
+    tokens : torch.Tensor
+        Tokens to be masked, shape (n_batch, n_codebooks, n_frames)
+    min_prop : float
+        Minimum proportion of frames to mask
+    max_prop : float
+        Maximum proportion of frames to mask
+    seed : Union[int, Iterable[int]]
+        One or more random seeds to determine masks
+
+    Returns
+    -------
+    torch.Tensor
+        Mask of shape (n_batch, n_frames)
+    """
+    assert min_prop >= 0.0
+    assert max_prop <= 1.0
+
+    n_batch, n_codebooks, n_frames = tokens.shape
+
+    states = format_seed(seed)
+    assert len(states) == n_batch
+
+    mask = torch.ones(
+        n_batch,
+        n_frames,
+        device=tokens.device,
+        dtype=torch.bool,
+    )  # (n_batch, n_frames)
+
+    for i, s in enumerate(states):
+        prop = s.uniform(min_prop, max_prop) if min_prop < max_prop else min_prop
+
+        if prop >= 1.0:
+            mask[i] = False
+        else:
+            span = int(prop * n_frames)
+            st = s.randint(0, max(n_frames - span, 1))
+            mask[i, st : st + span] = False
+
+    return mask
+
+
+def get_current_codebook_mask(
+    tokens: torch.Tensor, codebooks: torch.Tensor
+) -> torch.Tensor:
+    """
+    Given tokens and a batch of selected codebooks, select only the current
+    codebook (masking all codebooks "above" and "below" it).
+
+    Parameters
+    ----------
+    tokens : torch.Tensor
+        Tokens to be masked, shape (n_batch, n_codebooks, n_frames)
+    codebooks : torch.Tensor
+        Selected codebooks for which to construct the indicator, shape
+        (n_batch,)
+
+    Returns
+    -------
+    torch.Tensor
+        Mask of shape (n_batch, n_codebooks)
+    """
+
+    n_batch, n_codebooks, n_frames = tokens.shape
+
+    assert codebooks.ndim == 1
+    assert codebooks.shape[0] in [1, n_batch]
+    codebooks = codebooks.repeat(n_batch // codebooks.shape[0])
+
+    mask = (
+        torch.arange(
+            n_codebooks,
+            dtype=codebooks.dtype,
+            device=codebooks.device,
+        )[None, :]
+        == codebooks[:, None]
+    )  # (n_batch, n_codebooks)
+
+    return mask
+
+
+def get_next_codebooks_mask(
+    tokens: torch.Tensor, codebooks: torch.Tensor
+) -> torch.Tensor:
+    """
+    Given tokens and a batch of selected codebooks, mask all codebooks "above"
+    the selected codebooks.
+
+    Parameters
+    ----------
+    tokens : torch.Tensor
+        Tokens to be masked, shape (n_batch, n_codebooks, n_frames)
+    codebooks : torch.Tensor
+        Selected codebooks "above" which tokens should be masked, shape
+        (n_batch,)
+
+    Returns
+    -------
+    torch.Tensor
+        Mask of shape (n_batch, n_codebooks)
+    """
+
+    n_batch, n_codebooks, n_frames = tokens.shape
+
+    assert codebooks.ndim == 1
+    assert codebooks.shape[0] in [1, n_batch]
+    codebooks = codebooks.repeat(n_batch // codebooks.shape[0])
+
+    mask = (
+        torch.arange(
+            n_codebooks,
+            dtype=codebooks.dtype,
+            device=codebooks.device,
+        )[None, :]
+        <= codebooks[:, None]
+    )  # (n_batch, n_codebooks)
+
+    return mask
+
+
+def get_random_mask(
+    tokens: torch.Tensor,
+    prop: Union[float, Iterable[float]],
+    seed: Union[int, Iterable[int]],
+) -> torch.Tensor:
+    """
+    Parameters
+    ----------
+    tokens : torch.Tensor
+        Tokens to be masked, shape (n_batch, n_codebooks, n_frames)
+    prop : Union[float, Iterable[float]]
+        Proportion of tokens to be masked, shape (n_batch,)
+    seed : Union[int, Iterable[int]]
+        One or more random seeds to determine masks
+
+    Returns
+    -------
+    torch.Tensor
+        Random mask of shape (n_batch, n_codebooks, n_frames)
+    """
+    n_batch, n_codebooks, n_frames = tokens.shape
+
+    if isinstance(prop, torch.Tensor):
+        prop = prop.tolist()
+    assert len(prop) == n_batch
+
+    states = format_seed(seed)
+    assert len(states) == n_batch
+
+    mask = torch.ones(
+        n_batch,
+        n_codebooks,
+        n_frames,
+        device=tokens.device,
+        dtype=torch.bool,
+    )  # (n_batch, n_codebooks, n_frames)
+
+    for i, (s, p) in enumerate(zip(states, prop)):
+        mask[i] = torch.from_numpy(s.rand(n_codebooks, n_frames)).to(mask.device) > p
+
+    return mask
+
+
+def combine_masks(
+    mask_span: torch.Tensor,
+    mask_current_codebook: torch.Tensor,
+    mask_next_codebooks: torch.Tensor,
+    mask_random: torch.Tensor,
+    leak: bool = False,
+) -> torch.Tensor:
+    """
+    Combine sampled masks to allow for application to a token buffer.
+
+    Parameters
+    ----------
+    mask_span : torch.Tensor
+        Shape (n_batch, n_frames)
+    mask_current_codebook : torch.Tensor
+        Shape (n_batch, n_codebooks)
+    mask_next_codebooks : torch.Tensor
+        Shape (n_batch, n_codebooks)
+    mask_random : torch.Tensor
+        Shape (n_batch, n_codebooks, n_frames)
+
+    Returns
+    -------
+    torch.Tensor
+        Combined mask, shape (n_batch, n_codebooks, n_frames)
+
+    """
+
+    mask_current_level = mask_current_codebook[:, :, None] & (~mask_random)
+
+    if leak:
+        # Allow leakage from "higher" codebooks inside masked span
+        higher = (~mask_next_codebooks[:, :, None]) & (~mask_random)
+    else:
+        # Strictly mask "higher" codebooks inside masked span
+        higher = ~mask_next_codebooks[:, :, None]
+
+    # Inside span, unmask everything except "higher" codebooks and masked
+    # positions in current codebook
+    mask = ~(higher | mask_current_level)
+
+    # Outside span, fully unmask
+    mask = mask | mask_span[:, None, :]
+
+    return mask
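
Editorial note: a minimal sketch combining the masking utilities above on a toy token buffer (all values illustrative):

import torch
from tria.model.mask import (
    combine_masks, cosine_schedule, get_current_codebook_mask,
    get_next_codebooks_mask, get_random_mask, get_span_mask,
)

tokens = torch.zeros(2, 9, 100, dtype=torch.long)  # (n_batch, n_codebooks, n_frames)
codebooks = torch.tensor([0, 0])                   # predict codebook 0 for both items
prop = float(cosine_schedule(torch.tensor(0.5)))   # masking ratio at mid-schedule

mask = combine_masks(
    mask_span=get_span_mask(tokens, 0.25, 0.75, seed=[0, 1]),
    mask_current_codebook=get_current_codebook_mask(tokens, codebooks),
    mask_next_codebooks=get_next_codebooks_mask(tokens, codebooks),
    mask_random=get_random_mask(tokens, prop=[prop, prop], seed=[0, 1]),
)
print(mask.shape)  # (2, 9, 100); True marks unmasked positions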
tria/model/sample.py ADDED
@@ -0,0 +1,168 @@
+import math
+import torch
+import torch.nn.functional as F
+
+from typing import Iterable, Union, Optional
+import numpy as np
+from numpy.random import RandomState
+
+from .mask import cosine_schedule, format_seed
+
+################################################################################
+# Utilities for sampling from trained TRIA model
+################################################################################
+
+
+def top_p_top_k(
+    logits: torch.Tensor,
+    top_p: float = None,
+    top_k: int = None,
+):
+    """
+    Adapted from `vampnet.modules.transformer.sample_from_logits` by Hugo Flores
+    Garcia. See: https://github.com/hugofloresgarcia/vampnet/
+
+    Parameters
+    ----------
+    logits : torch.Tensor
+        Shape (..., n_classes)
+    """
+    logits = logits.clone()
+    n_classes = logits.shape[-1]
+
+    # Mask logits outside top-k by setting to -inf
+    if top_k is not None and 0 < top_k < n_classes:
+        thresh = logits.topk(top_k, dim=-1).values[..., -1:]  # (..., 1)
+        logits[logits < thresh] = float("-inf")
+
+    # Mask logits outside top-p by setting to -inf
+    if top_p is not None and 0.0 < top_p < 1.0:
+        # Sort descending
+        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)  # (..., n_classes)
+        sorted_probs = F.softmax(sorted_logits, dim=-1)  # (..., n_classes)
+        cumsum = sorted_probs.cumsum(dim=-1)  # (..., n_classes)
+
+        # Keep at least one logit
+        to_remove = cumsum > top_p
+        to_remove[..., 0] = False
+        remove_idx = torch.zeros_like(to_remove).scatter(-1, sorted_idx, to_remove)
+        logits[remove_idx] = float("-inf")
+
+    return logits
+
+
+def sample(
+    logits: torch.Tensor,
+    temp: float,
+    argmax: bool = False,
+):
+    """
+    Adapted from `vampnet.modules.transformer.sample_from_logits` by Hugo Flores
+    Garcia. See: https://github.com/hugofloresgarcia/vampnet/
+
+    Parameters
+    ----------
+    logits : torch.Tensor
+        Shape (..., n_classes)
+
+    Returns
+    -------
+    torch.Tensor
+        Sampled tokens, shape of `logits` with trailing `n_classes` dimension
+        removed
+    torch.Tensor
+        Probabilities of sampled tokens, shape of `logits` with trailing
+        `n_classes` dimension removed
+    """
+    if temp <= 0:
+        argmax = True
+        temp = 1.0
+
+    if argmax:
+        sampled = logits.argmax(dim=-1)
+        probs = F.softmax(
+            logits, dim=-1
+        ).take_along_dim(sampled.unsqueeze(-1), dim=-1).squeeze(-1)
+        return sampled, probs
+
+    probs = F.softmax(logits / temp, dim=-1)
+    flat = probs.reshape(-1, probs.shape[-1])
+    draws = torch.multinomial(flat, 1).squeeze(-1)
+    sampled = draws.view(*probs.shape[:-1])
+    chosen = probs.take_along_dim(sampled.unsqueeze(-1), dim=-1).squeeze(-1)
+    return sampled, chosen
+
+
+def mask_by_confidence(
+    probs: torch.Tensor,
+    n: torch.Tensor,
+    temp: float,
+    causal_bias: float,
+    state: Iterable[RandomState],
+    eligible: Optional[torch.Tensor] = None,
+):
+    """
+    Re-mask predicted tokens in a single codebook such that `n` previously-
+    masked tokens are left unmasked, using confidence (probability assigned to
+    tokens during sampling) to select which tokens remain. This confidence can
+    be mediated by random noise and a bias to unmask early (leftward) positions
+    first.
+
+    Parameters
+    ----------
+    probs : torch.Tensor
+        Probabilities assigned to sampled tokens, shape (n_batch, n_frames)
+    n : torch.Tensor
+        Target number of unmasked tokens, shape (n_batch,)
+    temp : float
+        Mask temperature, corresponding to randomness in unmasking process
+    causal_bias : float
+        Bias towards unmasking early (leftward) token positions first; typically
+        in (0, 1]. Note that large values of `temp` can effectively "wash out"
+        this causal bias
+    state : Iterable[RandomState]
+        Random seeds for reproducibility
+    eligible : torch.Tensor
+        Optional indicator for positions eligible for unmasking, shape
+        (n_batch, n_frames)
+    """
+
+    n_batch, n_frames = probs.shape
+    device = probs.device
+
+    if eligible is None:
+        eligible = torch.isfinite(probs) & (probs > 0)
+    else:
+        eligible = eligible.to(torch.bool)
+
+    # Masked token count and target
+    n_masked = eligible.long().sum(dim=-1)
+    n_unmask = (n_masked - n).clamp_min(0)
+
+    # Gumbel noise to introduce randomness into unmasking
+    u = torch.stack([
+        torch.from_numpy(s.uniform(1e-6, 1 - 1e-6, n_frames)) for s in state
+    ], dim=0).to(probs)
+    gumbel = -torch.log(-torch.log(u))
+
+    # Log-confidences + noise
+    s = probs.clamp_min(1e-12)
+    confs = torch.log(s) + temp * gumbel
+
+    # Optional causal bias in log-domain
+    if causal_bias > 0:
+        frame_relpos = (1 - (torch.arange(n_frames, device=device, dtype=confs.dtype) + 1) / n_frames).view(1, -1)
+        confs = confs + causal_bias * frame_relpos
+
+    # Only eligible positions can be chosen
+    confs_masked = confs.masked_fill(~eligible, float("-inf"))
+    sorted_vals, sorted_idx = confs_masked.sort(dim=-1, descending=True)
+    rank = torch.arange(n_frames, device=device).view(1, n_frames).expand_as(confs_masked)
+    k = n_unmask.view(n_batch, 1)
+    pick_sorted = rank < k
+    pick = torch.zeros_like(pick_sorted, dtype=torch.bool).scatter(-1, sorted_idx, pick_sorted)
+
+    # Return tokens_mask semantics (True = unmasked/keep)
+    mask = ~(eligible & (~pick))
+    return mask
+
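
Editorial note: a minimal sketch of the logit filtering and sampling path above, with random logits standing in for model output:

import torch
from tria.model.sample import sample, top_p_top_k

logits = torch.randn(2, 100, 1024)          # (n_batch, n_frames, codebook_size)
filtered = top_p_top_k(logits, top_p=0.95)  # nucleus filtering; -inf outside top-p
tokens, confs = sample(filtered, temp=1.0)  # draw tokens and their probabilities
print(tokens.shape, confs.shape)            # torch.Size([2, 100]) twice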
tria/model/tria.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from typing import Optional, Union, Iterable
4
+
5
+ from ..nn.transformer import Transformer
6
+ from .mask import cosine_schedule, format_seed
7
+ from .sample import mask_by_confidence, top_p_top_k, sample
8
+
9
+ ################################################################################
10
+ # TRIA masked language model
11
+ ################################################################################
12
+
13
+
14
+ class TRIA(torch.nn.Module):
15
+
16
+ def __init__(
17
+ self,
18
+ codebook_size: int = 1024,
19
+ n_codebooks: int = 9,
20
+ n_feats: int = 2,
21
+ n_channels: int = 512,
22
+ n_heads: int = 8,
23
+ n_layers: int = 12,
24
+ mult: int = 4,
25
+ p_dropout: float = 0.0,
26
+ p_token_dropout: float = 0.0,
27
+ bias: bool = False,
28
+ max_len: int = 8192,
29
+ pos_enc: Optional[str] = "rope",
30
+ qk_norm: bool = True,
31
+ use_sdpa: bool = True,
32
+ interp: str = "nearest",
33
+ share_emb: bool = True,
34
+ ):
35
+ super().__init__()
36
+
37
+ assert interp in ["nearest", "linear"]
38
+
39
+ self.adapter = torch.nn.Linear(n_feats, n_channels, bias=bias)
40
+ self.in_proj = torch.nn.Linear(2 * n_channels, n_channels, bias=bias)
41
+
42
+ self.backbone = Transformer(
43
+ n_channels=n_channels,
44
+ n_heads=n_heads,
45
+ n_layers=n_layers,
46
+ mult=mult,
47
+ p_dropout=p_dropout,
48
+ bias=False,
49
+ max_len=max_len,
50
+ pos_enc_self_attn=pos_enc,
51
+ qk_norm=qk_norm,
52
+ use_sdpa=use_sdpa,
53
+ )
54
+
55
+ self.tokens_emb = torch.nn.Embedding(codebook_size * n_codebooks, n_channels)
56
+ self.head = torch.nn.Linear(n_channels, codebook_size * n_codebooks, bias=False) # No bias on head, to allow weight-sharing
57
+ if share_emb:
58
+ self.tokens_emb.weight = self.head.weight
59
+
60
+ # Masked token embedding
61
+ self.tokens_mask_emb = torch.nn.Parameter(torch.zeros(n_channels))
62
+
63
+ # Attributes
64
+ self.p_token_dropout = p_token_dropout
65
+ self.codebook_size = codebook_size
66
+ self.n_codebooks = n_codebooks
67
+ self.n_feats = n_feats
68
+ self.n_channels = n_channels
69
+ self.n_layers = n_layers
70
+ self.interp = interp
71
+
72
+ def forward(
73
+ self,
74
+ tokens: torch.Tensor,
75
+ feats: torch.Tensor,
76
+ codebook: torch.Tensor,
77
+ tokens_mask: torch.Tensor,
78
+ feats_mask: torch.Tensor,
79
+ ) -> torch.Tensor:
80
+ """
81
+ Parameters
82
+ ----------
83
+ tokens : torch.Tensor
84
+ Acoustic tokens, fully or partially masked; shape
85
+ (n_batch, n_codebooks, n_frames)
86
+ feats : torch.Tensor
87
+ Aligned features to guide generation; shape (n_batch, n_feats, n_frames)
88
+ codebook : torch.Tensor
89
+ Codebook in which to predict masked tokens; shape (n_batch,)
90
+ tokens_mask : torch.Tensor
91
+ Boolean tensor indicating umasked token positions (True where
92
+ unmasked, False where masked); shape (n_batch, n_codebooks, n_frames)
93
+ feats_mask : torch.Tensor
94
+ """
95
+
96
+ assert tokens.ndim == 3 # (n_batch, n_codebooks, n_frames)
97
+ assert feats.ndim == 3 # (n_batch, n_feats, n_frames')
98
+ assert tokens_mask.ndim == 3 # (n_batch, n_codebooks, n_frames)
99
+ assert feats_mask.ndim == 2 # (n_batch, n_frames')
100
+ assert tokens.shape[1] == self.n_codebooks
101
+
102
+ n_batch, n_codebooks, n_frames = tokens.shape
103
+
104
+ # Interpolate features and mask to tokens resulution
105
+ feats = torch.nn.functional.interpolate(feats, n_frames, mode=self.interp)
106
+ feats_mask = torch.nn.functional.interpolate(
107
+ feats_mask[:, None, :].float(), n_frames, mode="nearest")
108
+
109
+ # Adapt features
110
+ feats = self.adapter(feats.transpose(1, 2)) # (n_batch, n_frames, n_channels)
111
+
112
+ # Embed tokens
113
+ codebook_offsets = torch.arange(
114
+ n_codebooks, dtype=tokens.dtype, device=tokens.device
115
+ ).reshape(1, -1, 1) * self.codebook_size # (1, n_codebooks, 1)
116
+ tokens = tokens + codebook_offsets # (n_batch, n_codebooks, n_frames)
117
+ tokens_emb = self.tokens_emb(tokens) # (n_batch, n_codebooks, n_frames, n_channels)
118
+
119
+ # Zero masked token embeddings
120
+ tokens_emb = tokens_emb * tokens_mask.unsqueeze(-1).float()
121
+
122
+ # Apply learned embedding to masked token positions in current codebook
123
+ mask_pos = torch.arange(
124
+ n_codebooks, dtype=tokens.dtype, device=tokens.device
125
+ )[None, :] == codebook[:, None] # (n_batch, n_codebooks)
126
+ mask_pos = torch.logical_and(mask_pos.unsqueeze(-1), ~tokens_mask) # (n_batch, n_codebooks, n_frames)
127
+
128
+ tokens_emb = tokens_emb + (
129
+ mask_pos.unsqueeze(-1).float()
130
+ ) * self.tokens_mask_emb.reshape(1, 1, 1, -1) # (n_batch, n_codebooks, n_frames, n_channels)
131
+
132
+ # Token dropout (encourage attention to unmasked frames)
133
+ if self.training and self.p_token_dropout > 0.0:
134
+
135
+ # Apply dropout within masked frames and "below" current codebook
136
+ below = torch.arange(
137
+ n_codebooks, device=tokens.device
138
+ )[None, :, None] < codebook[:, None, None] # (n_batch, n_codebooks, 1)
139
+ eligible = below & feats_mask.bool() # (n_batch, n_codebooks, n_frames)
140
+ drop = (
141
+ torch.rand(
142
+ n_batch, 1, n_frames, 1, device=tokens.device
143
+ ) < self.p_token_dropout) & eligible[..., None]
144
+ tokens_emb = tokens_emb.masked_fill(drop, 0.0)
145
+
146
+ # Zero "ignored" features
147
+ feats = feats * feats_mask.transpose(1, 2)
148
+
149
+ # Sum embedded tokens across codebooks
150
+ tokens_emb = tokens_emb.sum(dim=1) # (n_batch, n_frames, n_channels)
151
+
152
+ # Sum embedded tokens and adapted features
153
+ x = torch.cat([feats, tokens_emb], dim=-1) # (n_batch, n_frames, 2 * n_channels)
154
+ x = self.in_proj(x) # (n_batch, n_frames, n_channels)
155
+
156
+ # Process with transformer
157
+ x = self.backbone(x=x) # (n_batch, n_frames, n_channels)
158
+
159
+ # Predict token logits
160
+ logits = self.head(x) # (n_batch, n_frames, n_codebooks * codebook_size)
161
+ logits = logits.reshape(
162
+ n_batch, n_frames, n_codebooks, self.codebook_size
163
+ ).permute(0, 2, 1, 3) # (n_batch, n_codebooks, n_frames, codebook_size)
164
+
165
+ return logits
166
+
167
+ @torch.inference_mode()
168
+ def inference(
169
+ self,
170
+ tokens: torch.Tensor,
171
+ feats: torch.Tensor,
172
+ tokens_mask: torch.Tensor,
173
+ feats_mask: torch.Tensor,
174
+ top_p: Union[float, Iterable[float]] = 1.0,
175
+ top_k: Union[int, Iterable[int]] = None,
176
+ temp: Union[float, Iterable[float]] = 1.0,
177
+ mask_temp: Union[float, Iterable[float]] = 10.5,
178
+ iterations: Union[int, Iterable[int]] = 8,
179
+ guidance_scale: Union[float, Iterable[float]] = None,
180
+ causal_bias: Union[float, Iterable[float]] = None,
181
+ seed: Union[int, Iterable[int]] = None,
182
+ ):
183
+
184
+ assert not self.training
185
+ device = next(self.parameters()).device
186
+
187
+ # Avoid overwriting
188
+ tokens = tokens.clone().to(device)
189
+ tokens_mask = tokens_mask.clone().to(device)
190
+
191
+ assert tokens.ndim == 3
192
+ n_batch, n_codebooks, n_frames = tokens.shape
193
+
194
+ assert feats.ndim == 3
195
+ _, n_feats, _ = feats.shape
196
+
197
+ assert n_codebooks == self.n_codebooks
198
+ assert n_feats == self.n_feats
199
+
200
+ # Interpolate features to token resolution
201
+ feats = torch.nn.functional.interpolate(
202
+ feats.to(device), n_frames, mode=self.interp,
203
+ )
204
+ feats_mask = torch.nn.functional.interpolate(
205
+ feats_mask.unsqueeze(1).float().to(device), n_frames, mode="nearest",
206
+ ).squeeze(1).to(feats_mask.dtype)
207
+
208
+ # Account for per-codebook args
209
+ def _to_codebooks(v):
210
+ if isinstance(v, torch.Tensor):
211
+ v = v.tolist()
212
+ elif isinstance(v, Iterable):
213
+ pass
214
+ else:
215
+ v = [v]
216
+
217
+ if len(v) == n_codebooks:
218
+ return v
219
+ elif len(v) == 1:
220
+ return v * n_codebooks
221
+ else:
222
+ raise ValueError(
223
+ f"Sampling parameters must be scalars, "
224
+ f"length-1 iterable, or length-n_codebooks ({n_codebooks})"
225
+ )
226
+
227
+ # Construct `n_codebooks` state lists of length `n_batch` each
228
+ seed = seed or 0
229
+ if not isinstance(seed, Iterable):
230
+ seed = [seed]
231
+ assert len(seed) in [1, n_batch]
232
+ seed = seed * (n_batch // len(seed))
233
+ state = [format_seed([s + 10007 * cb for s in seed]) for cb in range(n_codebooks)]
234
+
235
+ top_p, top_k = _to_codebooks(top_p), _to_codebooks(top_k)
236
+ temp, mask_temp = _to_codebooks(temp), _to_codebooks(mask_temp)
237
+ iterations = _to_codebooks(iterations)
238
+ guidance_scale = _to_codebooks(guidance_scale)
239
+ causal_bias = _to_codebooks(causal_bias)
240
+
241
+ # Track initial masked token counts
242
+ n_masked_init = (~tokens_mask).long().sum(dim=-1) # (n_batch, n_codebooks)
243
+
244
+ # Generate one codebook at a time
245
+ for codebook_idx, (
246
+ _state, _top_p, _top_k, _temp, _mask_temp,
247
+ _iterations, _guidance_scale, _causal_bias,
248
+ ) in enumerate(zip(
249
+ state, top_p, top_k, temp, mask_temp,
250
+ iterations, guidance_scale, causal_bias,
251
+ )):
252
+ _causal_bias = _causal_bias or 0.
253
+ assert 0. <= _causal_bias
254
+
255
+ _temp = _temp or 1.0
256
+ assert 0. < _temp
257
+
258
+ _mask_temp = _mask_temp or 0.0
259
+ assert 0. <= _mask_temp
260
+
261
+ _iterations = max(_iterations or 1, 1)
262
+
263
+ for _iter in range(_iterations):
264
+
265
+ # CFG on features by masking
266
+ if _guidance_scale:
267
+ tokens_cfg = torch.cat([tokens, tokens], dim=0)
268
+ tokens_mask_cfg = torch.cat([tokens_mask, tokens_mask], dim=0)
269
+
270
+ feats_cfg = torch.cat([feats, feats], dim=0)
271
+ feats_mask_cfg = torch.cat([feats_mask, torch.zeros_like(feats_mask)], dim=0)
272
+
273
+ logits_cond, logits_uncond = self.forward(
274
+ tokens_cfg,
275
+ feats_cfg,
276
+ torch.full(
277
+ (tokens_cfg.shape[0],),
278
+ codebook_idx,
279
+ dtype=torch.long,
280
+ device=device,
281
+ ),
282
+ tokens_mask_cfg,
283
+ feats_mask_cfg,
284
+ ).chunk(2, dim=0) # (n_batch, n_codebooks, n_frames, codebook_size) x2
285
+
286
+ logits = logits_uncond + _guidance_scale * (logits_cond - logits_uncond) # (n_batch, n_codebooks, n_frames, codebook_size)
287
+
288
+ else:
289
+ logits = self.forward(
290
+ tokens,
291
+ feats,
292
+ torch.full(
293
+ (tokens.shape[0],),
294
+ codebook_idx,
295
+ dtype=torch.long,
296
+ device=device,
297
+ ),
298
+ tokens_mask,
299
+ feats_mask,
300
+ ) # (n_batch, n_codebooks, n_frames, codebook_size)
301
+
302
+ # Truncate logits and sample tokens at masked positions
303
+ logits = top_p_top_k(
304
+ logits[:, codebook_idx:codebook_idx+1, ...], _top_p, _top_k
305
+ ) # (n_batch, 1, n_frames, codebook_size)
306
+ sampled, probs = sample(
307
+ logits, _temp, argmax=(_iter==_iterations-1),
308
+ ) # (n_batch, 1, n_frames) x2
309
+ write_idx = ~(tokens_mask[:, codebook_idx, :]) # (n_batch, n_frames)
310
+ tokens[:, codebook_idx, :][write_idx] = sampled[:, 0, :][write_idx]
311
+
312
+ # Compute implied generation timestep and corresponding target mask
313
+ # ratio
314
+ t = (_iter + 1) / _iterations
315
+ tgt_p_mask = cosine_schedule(torch.tensor([t]*n_batch, device=device)) # (n_batch,)
316
+
317
+ # Compute target and actual number of masked positions in current
318
+ # codebook
319
+ tgt_n_masked = torch.floor(tgt_p_mask * n_masked_init[:, codebook_idx]).long() # (n_batch,)
320
+ n_masked = write_idx.long().sum(dim=-1) # (n_batch,)
321
+
322
+ # Before the final iteration, unmask at least one token per step but
323
+ # always leave at least one token masked
324
+ if _iter < _iterations - 1:
325
+ tgt_n_masked = torch.minimum(n_masked - 1, tgt_n_masked).clamp_min(1)
326
+
327
+ # Select which tokens to unmask via confidence (assigned probability),
328
+ # mediated by causal bias and random noise
329
+ _probs = torch.full_like(probs[:, 0, :], torch.inf) # (n_batch, n_frames)
330
+ _probs[write_idx] = probs[:, 0, :][write_idx]
331
+ tokens_mask[:, codebook_idx, :] = mask_by_confidence(
332
+ probs=_probs,
333
+ n=tgt_n_masked,
334
+ temp=_mask_temp * (1 - t), # Mask temperature annealing
335
+ causal_bias=_causal_bias or 0.0,
336
+ state=_state,
337
+ eligible=write_idx,
338
+ )
339
+
340
+ # Re-apply span and codebook masks
341
+ tokens_mask = ~torch.logical_and(~tokens_mask, feats_mask.unsqueeze(1))
342
+ tokens_mask[:, :codebook_idx, :] = True
343
+
344
+ return tokens
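A minimal usage sketch for the iterative `inference` method above (illustrative only, not part of the commit). It assumes a trained instance `model` of this class with `n_codebooks=9`; every name and shape here is hypothetical.

import torch

# Hypothetical setup: 9 codebooks, 100 token frames, batch of 1
n_batch, n_codebooks, n_frames = 1, 9, 100

tokens = torch.zeros(n_batch, n_codebooks, n_frames, dtype=torch.long)
# False = masked (to be generated); mask everything to generate the full span
tokens_mask = torch.zeros(n_batch, n_codebooks, n_frames, dtype=torch.bool)
feats = torch.randn(n_batch, model.n_feats, n_frames)         # conditioning features
feats_mask = torch.ones(n_batch, n_frames, dtype=torch.bool)  # all frames conditioned

model.eval()
out = model.inference(
    tokens, feats, tokens_mask, feats_mask,
    top_p=0.95,           # scalar, broadcast to every codebook
    temp=[1.0] * 9,       # or one value per codebook
    iterations=8,         # refinement steps per codebook
    guidance_scale=3.0,   # classifier-free guidance over features
    seed=0,
)
print(out.shape)  # (1, 9, 100)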
tria/nn/__init__.py ADDED
File without changes
tria/nn/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (146 Bytes). View file
 
tria/nn/__pycache__/attention.cpython-310.pyc ADDED
Binary file (6.67 kB). View file
 
tria/nn/__pycache__/norm.cpython-310.pyc ADDED
Binary file (2.25 kB). View file
 
tria/nn/__pycache__/pos_enc.cpython-310.pyc ADDED
Binary file (2.87 kB). View file
 
tria/nn/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (6.7 kB). View file
 
tria/nn/attention.py ADDED
@@ -0,0 +1,280 @@
1
+ import math
2
+ from typing import Optional
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from .norm import QKNorm
10
+ from .pos_enc import apply_rope
11
+ from .pos_enc import apply_sinusoidal
12
+ from .pos_enc import build_rope_cache
13
+ from .pos_enc import build_sinusoidal_cache
14
+
15
+ ################################################################################
16
+ # Multihead attention operation
17
+ ################################################################################
18
+
19
+
20
+ def ensure_masks(
21
+ n_batch: int,
22
+ seq_len_q: int,
23
+ seq_len_k: int,
24
+ device,
25
+ mask_q: Optional[torch.Tensor],
26
+ mask_k: Optional[torch.Tensor],
27
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
28
+ """
29
+ Parameters
30
+ ----------
31
+ n_batch : int
32
+ seq_len_q : int
33
+ seq_len_k : int
34
+ mask_q : torch.Tensor
35
+ Shape (n_batch, seq_len_q)
36
+ mask_k : torch.Tensor
37
+ Shape (n_batch, seq_len_k)
38
+ """
39
+ if mask_q is None:
40
+ mask_q = torch.ones(n_batch, seq_len_q, dtype=torch.bool, device=device)
41
+ if mask_k is None:
42
+ mask_k = torch.ones(n_batch, seq_len_k, dtype=torch.bool, device=device)
43
+ return mask_q, mask_k
44
+
45
+
46
+ def make_attn_mask(
47
+ mask_q: torch.Tensor,
48
+ mask_k: torch.Tensor,
49
+ dtype,
50
+ ) -> torch.Tensor:
51
+ """
52
+ Use "key padding mask" convention to prevent empty rows in attention score
53
+ matrix (and thus softmax issues).
54
+
55
+ Parameters
56
+ ----------
57
+ mask_q : torch.Tensor
58
+ Query sequence mask, shape (n_batch, seq_len_q)
59
+ mask_k : torch.Tensor
60
+ Key sequence mask, shape (n_batch, seq_len_k)
61
+
62
+ Returns
63
+ -------
64
+ torch.Tensor
65
+ Additive attention mask for scaled_dot_product_attention, shape
66
+ (n_batch, 1, seq_len_q, seq_len_k)
67
+ """
68
+ n_batch, seq_len_q = mask_q.shape
69
+ seq_len_k = mask_k.shape[1]
70
+
71
+ exclude = (
72
+ (~mask_k)[:, None, :].expand(n_batch, seq_len_q, seq_len_k).unsqueeze(1)
73
+ ) # (n_batch, 1, seq_len_q, seq_len_k)
74
+ mask = exclude.to(dtype=dtype).masked_fill(exclude, float("-inf"))
75
+
76
+ return mask # (n_batch, 1, seq_len_q, seq_len_k)
77
+
78
+
79
+ def sdpa_with_fallback(
80
+ q: torch.Tensor,
81
+ k: torch.Tensor,
82
+ v: torch.Tensor,
83
+ attn_mask: Optional[torch.Tensor],
84
+ p_dropout: float,
85
+ training: bool,
86
+ use_sdpa: bool = True,
87
+ ) -> torch.Tensor:
88
+ """
89
+ Optionally use PyTorch scaled_dot_product_attention (SDPA), which picks
90
+ efficient attention implementations (e.g. flash attention) if available
91
+
92
+ Parameters
93
+ ----------
94
+ q : torch.Tensor
95
+ Query, shape (n_batch, n_heads, seq_len_q, head_channels)
96
+ k : torch.Tensor
97
+ Key, shape (n_batch, n_heads, seq_len_k, head_channels)
98
+ v : torch.Tensor
99
+ Value, shape (n_batch, n_heads, seq_len_k, head_channels)
100
+ attn_mask : torch.Tensor
101
+ Additive attention mask (0 or -inf), shape (n_batch, 1, seq_len_q, seq_len_k)
102
+
103
+ Returns
104
+ -------
105
+ torch.Tensor
106
+ Shape (n_batch, n_heads, seq_len_q, head_channels)
107
+ """
108
+
109
+ n_batch, n_heads, seq_len_q, head_channels = q.shape
110
+ seq_len_k = k.shape[2]
111
+
112
+ if use_sdpa and q.is_cuda:
113
+ if attn_mask is not None and (
114
+ (attn_mask.dtype == torch.bool and attn_mask.all())
115
+ or (attn_mask.dtype != torch.bool and not attn_mask.ne(0).any())
116
+ ):
117
+ attn_mask = None
118
+
119
+ out = F.scaled_dot_product_attention(
120
+ q,
121
+ k,
122
+ v,
123
+ attn_mask=attn_mask,
124
+ dropout_p=p_dropout if training else 0.0,
125
+ is_causal=False,
126
+ )
127
+ return out
128
+
129
+ # Fallback
130
+ scale = 1.0 / math.sqrt(head_channels)
131
+ scores = torch.einsum("bhtd,bhsd->bhts", q, k) * scale
132
+ if attn_mask is not None:
133
+ scores = scores + attn_mask # Additive mask
134
+ attn = scores.softmax(dim=-1)
135
+ if training and p_dropout > 0.0:
136
+ attn = F.dropout(attn, p=p_dropout)
137
+ out = torch.einsum("bhts,bhsd->bhtd", attn, v)
138
+ return out
139
+
140
+
141
+ class MultiheadAttention(nn.Module):
142
+ def __init__(
143
+ self,
144
+ n_channels: int,
145
+ n_heads: int,
146
+ p_dropout: float = 0.0,
147
+ bias: bool = True,
148
+ max_len: int = 8192,
149
+ pos_enc: Optional[str] = "rope",
150
+ qk_norm: bool = True,
151
+ use_sdpa: bool = True,
152
+ ):
153
+ super().__init__()
154
+ assert n_channels % n_heads == 0, "`n_channels` must be divisible by `n_heads`"
155
+ assert pos_enc in ("rope", "absolute", "none", None)
156
+
157
+ self.n_channels = n_channels
158
+ self.n_heads = n_heads
159
+ self.head_channels = n_channels // n_heads
160
+ self.p_dropout = p_dropout
161
+ self.pos_enc = pos_enc
162
+ self.max_len = max_len
163
+ self.use_sdpa = use_sdpa
164
+
165
+ self.q_proj = nn.Linear(n_channels, n_channels, bias=bias)
166
+ self.k_proj = nn.Linear(n_channels, n_channels, bias=bias)
167
+ self.v_proj = nn.Linear(n_channels, n_channels, bias=bias)
168
+ self.o_proj = nn.Linear(n_channels, n_channels, bias=bias)
169
+
170
+ self.o_dropout = nn.Dropout(p_dropout)
171
+
172
+ self.qk_norm = QKNorm(self.head_channels) if qk_norm else None
173
+ self.pos_cache = None
174
+
175
+ def _maybe_build_pos_cache(self, device, dtype):
176
+ if self.pos_enc in [None, "none"] or self.pos_cache is not None:
177
+ return
178
+ if self.pos_enc == "absolute":
179
+ self.pos_cache = build_sinusoidal_cache(
180
+ self.max_len, self.head_channels, device, dtype=torch.float32
181
+ )
182
+ elif self.pos_enc == "rope":
183
+ cos, sin = build_rope_cache(
184
+ self.max_len, self.head_channels, device, dtype=torch.float32
185
+ )
186
+ self.pos_cache = (cos, sin)
187
+
188
+ def forward(
189
+ self,
190
+ q: torch.Tensor,
191
+ k: torch.Tensor,
192
+ v: torch.Tensor,
193
+ mask_q: Optional[torch.Tensor] = None,
194
+ mask_k: Optional[torch.Tensor] = None,
195
+ attn_mask: Optional[torch.Tensor] = None,
196
+ ) -> torch.Tensor:
197
+ """
198
+ Parameters
199
+ ----------
200
+ q : torch.Tensor
201
+ Query, shape (n_batch, seq_len_q, n_channels)
202
+ k : torch.Tensor
203
+ Key, shape (n_batch, seq_len_k, n_channels)
204
+ v : torch.Tensor
205
+ Value, shape (n_batch, seq_len_k, n_channels)
206
+ mask_q : torch.Tensor
207
+ Boolean mask, `True` for valid positions; shape (n_batch, seq_len_q)
208
+ mask_k : torch.Tensor
209
+ Boolean mask, `True` for valid positions; shape (n_batch, seq_len_k)
210
+ attn_mask : torch.Tensor
211
+ Additive (0, -inf) mask; shape (n_batch, 1, seq_len_q, seq_len_k)
212
+ """
213
+
214
+ n_batch, seq_len_q, _ = q.shape
215
+ seq_len_k = k.shape[1]
216
+ device, dtype = q.device, q.dtype
217
+
218
+ # Projections (n_batch, seq_len, n_channels) -> (n_batch, n_heads, seq_len, head_channels)
219
+ q = (
220
+ self.q_proj(q)
221
+ .view(n_batch, seq_len_q, self.n_heads, self.head_channels)
222
+ .transpose(1, 2)
223
+ )
224
+ k = (
225
+ self.k_proj(k)
226
+ .view(n_batch, seq_len_k, self.n_heads, self.head_channels)
227
+ .transpose(1, 2)
228
+ )
229
+ v = (
230
+ self.v_proj(v)
231
+ .view(n_batch, seq_len_k, self.n_heads, self.head_channels)
232
+ .transpose(1, 2)
233
+ )
234
+
235
+ # Positional encoding
236
+ self._maybe_build_pos_cache(device=device, dtype=dtype)
237
+ if self.pos_enc == "absolute":
238
+ cache = self.pos_cache # (max_seq_len, head_channels)
239
+ q = apply_sinusoidal(q, cache)
240
+ k = apply_sinusoidal(k, cache)
241
+ elif self.pos_enc == "rope":
242
+ cos, sin = self.pos_cache # (max_seq_len, head_channels/2)
243
+ q = apply_rope(q, cos, sin)
244
+ k = apply_rope(k, cos, sin)
245
+
246
+ # QK-Norm
247
+ if self.qk_norm is not None:
248
+ q, k = self.qk_norm(q, k)
249
+
250
+ # Masks
251
+ mask_q, mask_k = ensure_masks(
252
+ n_batch, seq_len_q, seq_len_k, device, mask_q, mask_k
253
+ )
254
+ pad_mask = make_attn_mask(
255
+ mask_q, mask_k, dtype
256
+ ) # (n_batch, 1, seq_len_q, seq_len_k)
257
+
258
+ if attn_mask is not None:
259
+ pad_mask = pad_mask + attn_mask
260
+
261
+ # Attention
262
+ y = sdpa_with_fallback(
263
+ q,
264
+ k,
265
+ v,
266
+ attn_mask=pad_mask,
267
+ p_dropout=self.p_dropout,
268
+ training=self.training,
269
+ use_sdpa=self.use_sdpa,
270
+ ) # (n_batch, n_heads, seq_len_q, head_channels)
271
+
272
+ y = y.transpose(1, 2).contiguous().view(n_batch, seq_len_q, self.n_channels)
273
+ y = self.o_proj(y) # (n_batch, seq_len_q, n_channels)
274
+ y = self.o_dropout(y)
275
+
276
+ # Mask outputs
277
+ if mask_q is not None:
278
+ with torch.no_grad():
279
+ y.masked_fill_(~mask_q[:, :, None], 0.0)
280
+ return y
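A small self-contained check of the attention module above (illustrative, not part of the commit). On CPU the manual fallback path runs; padded key positions receive -inf scores and padded query rows are zeroed on output.

import torch
from tria.nn.attention import MultiheadAttention

attn = MultiheadAttention(n_channels=64, n_heads=4, pos_enc="rope")
attn.eval()

x = torch.randn(2, 10, 64)
mask = torch.tensor([[True] * 10, [True] * 6 + [False] * 4])  # second sequence padded

with torch.no_grad():
    y = attn(x, x, x, mask_q=mask, mask_k=mask)
print(y.shape)                      # (2, 10, 64)
print(y[1, 6:].abs().max().item())  # 0.0, padded outputs are zeroed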
tria/nn/norm.py ADDED
@@ -0,0 +1,53 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ ################################################################################
7
+ # Normalization layers
8
+ ################################################################################
9
+
10
+
11
+ class RMSNorm(nn.Module):
12
+ def __init__(self, n_channels: int, eps: float = 1e-6):
13
+ super().__init__()
14
+ self.eps = eps
15
+ self.weight = nn.Parameter(torch.ones(n_channels))
16
+
17
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
18
+ """
19
+ Normalize over final dimension
20
+ """
21
+ rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
22
+ return self.weight * x * rms # Broadcast targets final dimension
23
+
24
+
25
+ class QKNorm(nn.Module):
26
+ """
27
+ RMS-normalize query and key across channel dimension with a learnable gain.
28
+ Applied per-head, per-position.
29
+ """
30
+
31
+ def __init__(self, head_channels: int, eps: float = 1e-6):
32
+ super().__init__()
33
+ self.eps = eps
34
+ self.g_q = nn.Parameter(torch.ones(head_channels))
35
+ self.g_k = nn.Parameter(torch.ones(head_channels))
36
+
37
+ def forward(
38
+ self, q: torch.Tensor, k: torch.Tensor
39
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
40
+ """
41
+ Parameters
42
+ ----------
43
+ q : torch.Tensor
44
+ Query, shape (n_batch, n_heads, seq_len_q, head_channels)
45
+ k : torch.Tensor
46
+ Key, shape (n_batch, n_heads, seq_len_k, head_channels)
47
+ """
48
+
49
+ def _rmsnorm(x, g):
50
+ rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
51
+ return x * rms * g # Broadcast targets final dimension
52
+
53
+ return _rmsnorm(q, self.g_q), _rmsnorm(k, self.g_k)
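A quick illustrative check (not part of the commit): with the default unit gain, RMSNorm rescales every position to unit root-mean-square.

import torch
from tria.nn.norm import RMSNorm

norm = RMSNorm(n_channels=16)
x = torch.randn(4, 10, 16) * 5.0
y = norm(x)

rms = y.pow(2).mean(dim=-1).sqrt()
print(rms.min().item(), rms.max().item())  # both close to 1.0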
tria/nn/pos_enc.py ADDED
@@ -0,0 +1,101 @@
1
+ import torch
2
+
3
+ ################################################################################
4
+ # Utilities for positional encoding
5
+ ################################################################################
6
+
7
+
8
+ def build_sinusoidal_cache(seq_len: int, n_channels: int, device, dtype):
9
+ """
10
+ Returns
11
+ -------
12
+ torch.Tensor
13
+ Cache, shape (seq_len, n_channels)
14
+ """
15
+ assert n_channels % 2 == 0
16
+ pos = torch.arange(seq_len, device=device, dtype=dtype).unsqueeze(1) # (seq_len, 1)
17
+ i = torch.arange(n_channels // 2, device=device, dtype=dtype).unsqueeze(
18
+ 0
19
+ ) # (1, n_channels/2)
20
+ inv_freq = 1.0 / (10000 ** (i / (n_channels // 2)))
21
+ ang = pos * inv_freq # (seq_len, n_channels/2)
22
+ emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=1) # (seq_len, n_channels)
23
+ return emb
24
+
25
+
26
+ def apply_sinusoidal(x: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
27
+ """
28
+ Parameters
29
+ ----------
30
+ x : torch.Tensor
31
+ Shape (n_batch, n_heads, seq_len, head_channels) or (n_batch, seq_len, n_channels)
32
+ cache: torch.Tensor
33
+ Shape (seq_len, n_channels)
34
+
35
+ Returns
36
+ -------
37
+ torch.Tensor
38
+ Shape (n_batch, n_heads, seq_len, head_channels) or (n_batch, seq_len, n_channels)
39
+ """
40
+ if x.ndim == 4:
41
+ n_batch, n_heads, seq_len, head_channels = x.shape
42
+ return x + cache.to(x.dtype)[None, None, :seq_len, :head_channels]
43
+ elif x.ndim == 3:
44
+ n_batch, seq_len, n_channels = x.shape
45
+ return x + cache.to(x.dtype)[None, :seq_len, :n_channels]
46
+ else:
47
+ raise ValueError(
48
+ f"Invalid input shape {tuple(x.shape)}; "
49
+ f"expected (n_batch, [n_heads], seq_len, n_channels)"
50
+ )
51
+
52
+
53
+ def build_rope_cache(
54
+ seq_len: int, n_channels: int, device, dtype, base: float = 10000.0
55
+ ):
56
+ """
57
+ Returns
58
+ -------
59
+ torch.Tensor, torch.Tensor
60
+ Caches, shape (seq_len, n_channels/2)
61
+ """
62
+ assert n_channels % 2 == 0
63
+ theta = 1.0 / (
64
+ base
65
+ ** (torch.arange(0, n_channels, 2, device=device, dtype=dtype) / n_channels)
66
+ )
67
+ seq = torch.arange(seq_len, device=device, dtype=dtype)
68
+ freqs = torch.einsum("t,d->td", seq, theta) # (seq_len, n_channels/2)
69
+ return torch.cos(freqs), torch.sin(
70
+ freqs
71
+ ) # (seq_len, n_channels/2), (seq_len, n_channels/2)
72
+
73
+
74
+ def apply_rope(
75
+ q_or_k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
76
+ ) -> torch.Tensor:
77
+ """
78
+ Parameters
79
+ ----------
80
+ q_or_k : torch.Tensor
81
+ Shape (n_batch, n_heads, seq_len, head_channels) where head_channels even
82
+ cos : torch.Tensor
83
+ Shape (seq_len, head_channels/2)
84
+ sin : torch.Tensor
85
+ Shape (seq_len, head_channels/2)
86
+
87
+ Returns
88
+ -------
89
+ torch.Tensor
90
+ Shape (n_batch, n_heads, seq_len, head_channels)
91
+ """
92
+ n_batch, n_heads, seq_len, head_channels = q_or_k.shape
93
+ q = q_or_k.reshape(n_batch, n_heads, seq_len, head_channels // 2, 2)
94
+ q1, q2 = q[..., 0], q[..., 1] # (n_batch, n_heads, seq_len, head_channels / 2)
95
+ c = cos[:seq_len].to(q_or_k.dtype)[None, None, :, :]
96
+ s = sin[:seq_len].to(q_or_k.dtype)[None, None, :, :]
97
+ out1 = q1 * c - q2 * s
98
+ out2 = q1 * s + q2 * c
99
+ return torch.stack([out1, out2], dim=-1).reshape(
100
+ n_batch, n_heads, seq_len, head_channels
101
+ )
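An illustrative property check for the RoPE helpers above (not part of the commit): rotating each (even, odd) channel pair preserves per-position norms, so only relative phase between positions changes.

import torch
from tria.nn.pos_enc import build_rope_cache, apply_rope

cos, sin = build_rope_cache(seq_len=32, n_channels=8, device="cpu", dtype=torch.float32)
q = torch.randn(1, 2, 32, 8)  # (n_batch, n_heads, seq_len, head_channels)
q_rot = apply_rope(q, cos, sin)

# Rotation is norm-preserving at every position
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True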
tria/nn/transformer.py ADDED
@@ -0,0 +1,259 @@
1
+ from typing import Optional
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .attention import MultiheadAttention
8
+ from .norm import RMSNorm
9
+
10
+ ################################################################################
11
+ # Transformer
12
+ ################################################################################
13
+
14
+
15
+ def lengths_to_mask(
16
+ lengths: torch.Tensor, max_len: Optional[int] = None
17
+ ) -> torch.Tensor:
18
+ """
19
+ Parameters
20
+ ----------
21
+ lengths : torch.Tensor
22
+ Shape (n_batch,)
23
+ max_len : int
24
+ """
25
+ if max_len is None:
26
+ max_len = int(lengths.amax())
27
+ rng = torch.arange(max_len, device=lengths.device)
28
+ return rng[None, :] < lengths[:, None] # (n_batch, max_len)
29
+
30
+
31
+ class MLP(nn.Module):
32
+ def __init__(
33
+ self, n_channels: int, mult: int = 4, p_dropout: float = 0.1, bias: bool = True
34
+ ):
35
+ super().__init__()
36
+
37
+ self.mlp = nn.Sequential(
38
+ nn.Linear(n_channels, n_channels * mult, bias=bias),
39
+ nn.GELU(),
40
+ nn.Linear(n_channels * mult, n_channels, bias=bias),
41
+ nn.Dropout(p_dropout),
42
+ )
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ assert x.ndim == 3 # (n_batch, seq_len, n_channels)
46
+ return self.mlp(x) # (n_batch, seq_len, n_channels)
47
+
48
+
49
+ class TransformerBlock(nn.Module):
50
+ def __init__(
51
+ self,
52
+ n_channels: int,
53
+ n_heads: int,
54
+ mult: int = 4,
55
+ p_dropout: float = 0.0,
56
+ bias: bool = True,
57
+ max_len: int = 8192,
58
+ pos_enc_self_attn: Optional[str] = "rope",
59
+ pos_enc_cross_attn: Optional[str] = "absolute",
60
+ qk_norm: bool = True,
61
+ use_sdpa: bool = True,
62
+ cross_attn: bool = False,
63
+ norm: str = "layer",
64
+ ):
65
+ super().__init__()
66
+
67
+ assert norm in ["layer", "rms", "none", None]
68
+ if norm == "rms":
69
+ norm_cls = RMSNorm
70
+ elif norm == "layer":
71
+ norm_cls = nn.LayerNorm
72
+ else:
73
+ norm_cls = nn.Identity
74
+
75
+ self.norm_1 = norm_cls(n_channels)
76
+ self.self_attn = MultiheadAttention(
77
+ n_channels=n_channels,
78
+ n_heads=n_heads,
79
+ p_dropout=p_dropout,
80
+ bias=bias,
81
+ max_len=max_len,
82
+ pos_enc=pos_enc_self_attn,
83
+ qk_norm=qk_norm,
84
+ use_sdpa=use_sdpa,
85
+ )
86
+
87
+ self.cross_attn = cross_attn
88
+ if cross_attn:
89
+ self.norm_x = norm_cls(n_channels)
90
+ self.norm_c = norm_cls(n_channels)
91
+ self.cross = MultiheadAttention(
92
+ n_channels=n_channels,
93
+ n_heads=n_heads,
94
+ p_dropout=p_dropout,
95
+ bias=bias,
96
+ max_len=max_len,
97
+ pos_enc=pos_enc_cross_attn,
98
+ qk_norm=qk_norm,
99
+ use_sdpa=use_sdpa,
100
+ )
101
+
102
+ self.norm_2 = norm_cls(n_channels)
103
+ self.mlp = MLP(n_channels=n_channels, mult=mult, p_dropout=p_dropout, bias=bias)
104
+
105
+ def forward(
106
+ self,
107
+ x: torch.Tensor,
108
+ c: Optional[torch.Tensor] = None,
109
+ mask_x: Optional[torch.Tensor] = None,
110
+ mask_c: Optional[torch.Tensor] = None,
111
+ ) -> torch.Tensor:
112
+ """
113
+ Parameters
114
+ ----------
115
+ x : torch.Tensor
116
+ Input sequence, shape (n_batch, seq_len_x, n_channels)
117
+ c : torch.Tensor
118
+ Conditioning sequence, shape (n_batch, seq_len_c, n_channels)
119
+ mask_x : torch.Tensor
120
+ Boolean mask indicating valid positions in input sequence, shape
121
+ (n_batch, seq_len_x)
122
+ mask_c : torch.Tensor
123
+ Boolean mask indicating valid positions in conditioning sequence,
124
+ shape (n_batch, seq_len_c)
125
+ """
126
+
127
+ if self.cross_attn:
128
+ assert c is not None
129
+
130
+ # Self-attention
131
+ y = self.norm_1(x)
132
+ y = self.self_attn(y, y, y, mask_q=mask_x, mask_k=mask_x)
133
+ x = x + y
134
+
135
+ # Cross-attention
136
+ if self.cross_attn and c is not None:
137
+ q = self.norm_x(x)
138
+ k = self.norm_c(c)
139
+ v = k
140
+ y = self.cross(q, k, v, mask_q=mask_x, mask_k=mask_c)
141
+ x = x + y
142
+
143
+ # MLP
144
+ y = self.norm_2(x)
145
+ y = self.mlp(y)
146
+ x = x + y
147
+
148
+ # Zero invalid outputs
149
+ if mask_x is not None:
150
+ with torch.no_grad():
151
+ x.masked_fill_(~mask_x[:, :, None], 0.0)
152
+
153
+ return x
154
+
155
+
156
+ class Transformer(nn.Module):
157
+ def __init__(
158
+ self,
159
+ n_channels: int,
160
+ n_heads: int,
161
+ n_layers: int,
162
+ mult: int,
163
+ p_dropout: float = 0.0,
164
+ bias: bool = True,
165
+ max_len: int = 8192,
166
+ pos_enc_self_attn: Optional[str] = "rope",
167
+ pos_enc_cross_attn: Optional[str] = "absolute",
168
+ qk_norm: bool = True,
169
+ use_sdpa: bool = True,
170
+ cross_attn: bool = False,
171
+ ):
172
+ super().__init__()
173
+ self.layers = nn.ModuleList(
174
+ [
175
+ TransformerBlock(
176
+ n_channels=n_channels,
177
+ n_heads=n_heads,
178
+ mult=mult,
179
+ p_dropout=p_dropout,
180
+ bias=bias,
181
+ max_len=max_len,
182
+ pos_enc_self_attn=pos_enc_self_attn,
183
+ pos_enc_cross_attn=pos_enc_cross_attn,
184
+ qk_norm=qk_norm,
185
+ use_sdpa=use_sdpa,
186
+ cross_attn=cross_attn,
187
+ )
188
+ for _ in range(n_layers)
189
+ ]
190
+ )
191
+ self.n_channels = n_channels
192
+ self.max_len = max_len
193
+ self.pos_enc_self_attn = pos_enc_self_attn
194
+ self.pos_enc_cross_attn = pos_enc_cross_attn
195
+
196
+ @torch.no_grad()
197
+ def _masks_from_lengths(
198
+ self,
199
+ mask_x: Optional[torch.Tensor],
200
+ mask_c: Optional[torch.Tensor],
201
+ lengths_x: Optional[torch.Tensor],
202
+ lengths_c: Optional[torch.Tensor],
203
+ seq_len_x: int,
204
+ seq_len_c: Optional[int],
205
+ device,
206
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
207
+ if mask_x is None and lengths_x is not None:
208
+ mask_x = lengths_to_mask(lengths_x.to(device), seq_len_x)
209
+ if mask_c is None and lengths_c is not None:
210
+ assert seq_len_c is not None
211
+ mask_c = lengths_to_mask(lengths_c.to(device), seq_len_c)
212
+ if mask_x is not None:
213
+ mask_x = mask_x.bool()
214
+ if mask_c is not None:
215
+ mask_c = mask_c.bool()
216
+ return mask_x, mask_c
217
+
218
+ def forward(
219
+ self,
220
+ x: torch.Tensor,
221
+ c: Optional[torch.Tensor] = None,
222
+ mask_x: Optional[torch.Tensor] = None,
223
+ mask_c: Optional[torch.Tensor] = None,
224
+ lengths_x: Optional[torch.Tensor] = None,
225
+ lengths_c: Optional[torch.Tensor] = None,
226
+ ) -> torch.Tensor:
227
+ """
228
+ Parameters
229
+ ----------
230
+ x : torch.Tensor
231
+ Input sequence, shape (n_batch, seq_len_x, n_channels)
232
+ c : torch.Tensor
233
+ Conditioning sequence, shape (n_batch, seq_len_c, n_channels)
234
+ mask_x : torch.Tensor
235
+ Boolean mask indicating valid positions in input sequence, shape
236
+ (n_batch, seq_len_x)
237
+ mask_c : torch.Tensor
238
+ Boolean mask indicating valid positions in conditioning sequence,
239
+ shape (n_batch, seq_len_c)
240
+ lengths_x : torch.Tensor
241
+ Valid lengths of input sequences, shape (n_batch,)
242
+ lengths_c : torch.Tensor
243
+ Valid lengths of conditioning sequences, shape (n_batch,)
244
+ """
245
+
246
+ assert x.ndim == 3
247
+ n_batch, seq_len_x, n_channels = x.shape
248
+ assert n_channels == self.n_channels
249
+ seq_len_c = c.shape[1] if c is not None else None
250
+
251
+ # Create valid masks from lengths if necessary
252
+ mask_x, mask_c = self._masks_from_lengths(
253
+ mask_x, mask_c, lengths_x, lengths_c, seq_len_x, seq_len_c, x.device
254
+ )
255
+
256
+ for block in self.layers:
257
+ x = block(x=x, c=c, mask_x=mask_x, mask_c=mask_c)
258
+
259
+ return x
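A minimal usage sketch for the Transformer wrapper (illustrative, not part of the commit), deriving the padding mask from `lengths_x`:

import torch
from tria.nn.transformer import Transformer

backbone = Transformer(n_channels=64, n_heads=4, n_layers=2, mult=4)
backbone.eval()

x = torch.randn(2, 12, 64)
lengths = torch.tensor([12, 7])  # second sequence has 5 padded frames

with torch.no_grad():
    y = backbone(x=x, lengths_x=lengths)
print(y.shape)               # (2, 12, 64)
print(y[1, 7:].abs().sum())  # tensor(0.), padded positions zeroed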
tria/pipelines/__init__.py ADDED
File without changes
tria/pipelines/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (153 Bytes). View file
 
tria/pipelines/tokenizer/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .tokenizer import Tokenizer
2
+ from .tokenizer import TokenSequence
tria/pipelines/tokenizer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (242 Bytes). View file
 
tria/pipelines/tokenizer/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (4.87 kB). View file
 
tria/pipelines/tokenizer/dac/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023-present, Descript
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
tria/pipelines/tokenizer/dac/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .dac import DAC
tria/pipelines/tokenizer/dac/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (198 Bytes). View file
 
tria/pipelines/tokenizer/dac/__pycache__/dac.cpython-310.pyc ADDED
Binary file (5.77 kB). View file
 
tria/pipelines/tokenizer/dac/__pycache__/modules.cpython-310.pyc ADDED
Binary file (4.04 kB). View file
 
tria/pipelines/tokenizer/dac/dac.py ADDED
@@ -0,0 +1,203 @@
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+
9
+ from .modules import Decoder
10
+ from .modules import Encoder
11
+ from .modules import init_weights
12
+ from .nn.quantize import ResidualVectorQuantize
13
+
14
+ ################################################################################
15
+ # Descript Audio Codec (DAC)
16
+ ################################################################################
17
+
18
+
19
+ class DAC(torch.nn.Module):
20
+ """
21
+ Descript Audio Codec (DAC) proposed by Kumar et al. in "High-Fidelity Audio
22
+ Compression with Improved RVQGAN" (2023). Code adapted from:
23
+ https://github.com/descriptinc/descript-audio-codec
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ sample_rate: int = 44_100,
29
+ encoder_dim: int = 64,
30
+ encoder_rates: List[int] = (2, 4, 8, 8),
31
+ latent_dim: int = None,
32
+ decoder_dim: int = 1536,
33
+ decoder_rates: List[int] = (8, 8, 4, 2),
34
+ n_codebooks: int = 9,
35
+ codebook_size: int = 1024,
36
+ codebook_dim: Union[int, list] = 8,
37
+ quantizer_dropout: bool = False,
38
+ ):
39
+ super().__init__()
40
+
41
+ self.encoder_dim = encoder_dim
42
+ self.encoder_rates = encoder_rates
43
+ self.decoder_dim = decoder_dim
44
+ self.decoder_rates = decoder_rates
45
+ self.sample_rate = sample_rate
46
+
47
+ if latent_dim is None:
48
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
49
+ self.latent_dim = latent_dim
50
+
51
+ self.hop_length = np.prod(encoder_rates)
52
+
53
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
54
+
55
+ self.n_codebooks = n_codebooks
56
+ self.codebook_size = codebook_size
57
+ self.codebook_dim = codebook_dim
58
+ self.quantizer = ResidualVectorQuantize(
59
+ input_dim=latent_dim,
60
+ n_codebooks=n_codebooks,
61
+ codebook_size=codebook_size,
62
+ codebook_dim=codebook_dim,
63
+ quantizer_dropout=quantizer_dropout,
64
+ )
65
+
66
+ self.decoder = Decoder(
67
+ latent_dim,
68
+ decoder_dim,
69
+ decoder_rates,
70
+ )
71
+ self.apply(init_weights)
72
+
73
+ self.delay = self.get_delay()
74
+
75
+ # As long as we don't run chunked/segmented encoding and decoding,
76
+ # we can keep padding on
77
+ self.padding = True
78
+
79
+ @property
80
+ def padding(self):
81
+ if not hasattr(self, "_padding"):
82
+ self._padding = True
83
+ return self._padding
84
+
85
+ @padding.setter
86
+ def padding(self, value: bool):
87
+ assert isinstance(value, bool)
88
+
89
+ layers = [
90
+ l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
91
+ ]
92
+
93
+ for layer in layers:
94
+ if value:
95
+ if hasattr(layer, "original_padding"):
96
+ layer.padding = layer.original_padding
97
+ else:
98
+ layer.original_padding = layer.padding
99
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
100
+
101
+ self._padding = value
102
+
103
+ def get_delay(self):
104
+ # Any number works here; delay is invariant to input length
105
+ l_out = self.get_output_length(0)
106
+ L = l_out
107
+
108
+ layers = []
109
+ for layer in self.modules():
110
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
111
+ layers.append(layer)
112
+
113
+ for layer in reversed(layers):
114
+ d = layer.dilation[0]
115
+ k = layer.kernel_size[0]
116
+ s = layer.stride[0]
117
+
118
+ if isinstance(layer, nn.ConvTranspose1d):
119
+ L = ((L - d * (k - 1) - 1) / s) + 1
120
+ elif isinstance(layer, nn.Conv1d):
121
+ L = (L - 1) * s + d * (k - 1) + 1
122
+
123
+ L = math.ceil(L)
124
+
125
+ l_in = L
126
+
127
+ return (l_in - l_out) // 2
128
+
129
+ def get_output_length(self, input_length: int):
130
+ L = input_length
131
+ # Calculate output length
132
+ for layer in self.modules():
133
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
134
+ d = layer.dilation[0]
135
+ k = layer.kernel_size[0]
136
+ s = layer.stride[0]
137
+
138
+ if isinstance(layer, nn.Conv1d):
139
+ L = ((L - d * (k - 1) - 1) / s) + 1
140
+ elif isinstance(layer, nn.ConvTranspose1d):
141
+ L = (L - 1) * s + d * (k - 1) + 1
142
+
143
+ L = math.floor(L)
144
+ return L
145
+
146
+ def encode(
147
+ self,
148
+ audio_data: torch.Tensor,
149
+ ):
150
+ """
151
+ Encode given audio data and return quantized latent codes.
152
+
153
+ Parameters
154
+ ----------
155
+ audio_data : torch.Tensor
156
+ Audio data to encode, shape (batch_size, 1, n_samples)
157
+
158
+ Returns
159
+ -------
160
+ codes:
161
+ Codebook indices across all quantizer levels, shape
162
+ (n_batch, n_quantizers, n_frames)
163
+ z_O: torch.Tensor
164
+ Quantized output obtained by summing projected quantized residuals
165
+ (z_o) over all quantizer levels, shape (n_batch, latent_dim, n_frames)
166
+ z_i: torch.Tensor
167
+ Continuous representation of inputs projected into codebook space,
168
+ shape (n_batch, n_quantizers, codebook_dim, n_frames). Note that
169
+ each quantizer level represents a predicted residual.
170
+ z_q: torch.Tensor
171
+ Quantized representation of input in codebook space, shape
172
+ (n_batch, n_quantizers, codebook_dim, n_frames). Note that each
173
+ quantizer level represents a quantized predicted residual.
174
+ z_o: torch.Tensor
175
+ Continuous representation of quantized input, projected back into
176
+ latent space, shape (n_batch, n_quantizers, latent_dim, n_frames).
177
+ Note that each quantizer level represents a projected quantized
178
+ predicted residual.
179
+ """
180
+ # Predict continuous latents
181
+ z = self.encoder(audio_data) # (n_batch, latent_dim, n_frames)
182
+ return *self.quantizer(z, n_quantizers=None), z
183
+
184
+ def decode(
185
+ self,
186
+ codes: torch.Tensor,
187
+ ):
188
+ """
189
+ Decode given quantized latent codes and return audio data
190
+
191
+ Parameters
192
+ ----------
193
+ codes : torch.Tensor
194
+ Quantized latent codes, shape (n_batch, n_quantizers, n_frames)
195
+
196
+ Returns
197
+ -------
198
+ torch.Tensor
199
+ Decoded audio data, shape (n_batch, 1, n_samples)
200
+ """
201
+ z_O = self.quantizer.from_codes(codes) # (n_batch, latent_dim, n_frames)
202
+ recons = self.decoder(z_O) # (n_batch, 1, n_samples)
203
+ return recons
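A shape-level sketch of the codec round trip (illustrative, not part of the commit). With the default encoder strides (2, 4, 8, 8) the hop length is 512 samples, so one second at 44.1 kHz yields roughly 86 token frames.

import torch
from tria.pipelines.tokenizer.dac import DAC

dac = DAC()  # defaults: 44.1 kHz, 9 codebooks, hop length 512
dac.eval()

audio = torch.randn(1, 1, 44_100)  # one second of audio
with torch.no_grad():
    codes, *_ = dac.encode(audio)  # codebook indices first, per the docstring above
    recons = dac.decode(codes)

print(codes.shape)   # roughly (1, 9, 86)
print(recons.shape)  # (1, 1, ~44_100)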