CocoBro commited on
Commit
c14d03d
·
1 Parent(s): cdbb4cf

init space

Browse files
Files changed (43) hide show
  1. app.py +358 -63
  2. example/content.jsonl +1 -0
  3. losses/base.py +22 -0
  4. models/__pycache__/common.cpython-310.pyc +0 -0
  5. models/__pycache__/content_adapter.cpython-310.pyc +0 -0
  6. models/__pycache__/diffusion.cpython-310.pyc +0 -0
  7. models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc +0 -0
  8. models/autoencoder/autoencoder_base.py +22 -0
  9. models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc +0 -0
  10. models/autoencoder/waveform/dac.py +0 -0
  11. models/autoencoder/waveform/stable_vae.py +586 -0
  12. models/common.py +79 -0
  13. models/content_adapter.py +430 -0
  14. models/content_encoder/__pycache__/content_encoder.cpython-310.pyc +0 -0
  15. models/content_encoder/__pycache__/llm_encoder.cpython-310.pyc +0 -0
  16. models/content_encoder/content_encoder.py +133 -0
  17. models/content_encoder/llm_encoder.py +215 -0
  18. models/content_encoder/text_encoder.py +76 -0
  19. models/diffusion.py +401 -0
  20. models/dit/__init__.py +0 -0
  21. models/dit/__pycache__/__init__.cpython-310.pyc +0 -0
  22. models/dit/__pycache__/mmdit_back.cpython-310.pyc +0 -0
  23. models/dit/__pycache__/mmdit_layers.cpython-310.pyc +0 -0
  24. models/dit/__pycache__/modules.cpython-310.pyc +0 -0
  25. models/dit/attention.py +350 -0
  26. models/dit/mmdit_back.py +346 -0
  27. models/dit/mmdit_layers.py +421 -0
  28. models/dit/modules.py +445 -0
  29. models/dit/rotary.py +88 -0
  30. models/dit/span_mask.py +149 -0
  31. models/flow_matching.py +1082 -0
  32. requirements.txt +28 -0
  33. stabilityai/stable-diffusion-2-1/scheduler/scheduler_config.json +14 -0
  34. utils/__pycache__/config.cpython-310.pyc +0 -0
  35. utils/__pycache__/torch_utilities.cpython-310.pyc +0 -0
  36. utils/accelerate_utilities.py +13 -0
  37. utils/audio.py +58 -0
  38. utils/config.py +53 -0
  39. utils/diffsinger_utilities.py +551 -0
  40. utils/general.py +68 -0
  41. utils/logging.py +23 -0
  42. utils/lr_scheduler_utilities.py +154 -0
  43. utils/torch_utilities.py +288 -0
app.py CHANGED
@@ -1,70 +1,365 @@
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
-
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 
 
16
  """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
-
19
- messages = [{"role": "system", "content": system_message}]
20
-
21
- messages.extend(history)
22
-
23
- messages.append({"role": "user", "content": message})
24
-
25
- response = ""
26
-
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
67
 
68
 
69
  if __name__ == "__main__":
70
- demo.launch()
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ import time
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import Tuple, Optional, Dict, Any
9
+
10
  import gradio as gr
11
+ import numpy as np
12
+ import soundfile as sf
13
+ import torch
14
+ import torchaudio
15
+ import librosa
16
+
17
+ import hydra
18
+ from omegaconf import OmegaConf
19
+ from safetensors.torch import load_file
20
+ import diffusers.schedulers as noise_schedulers
21
+ from huggingface_hub import snapshot_download
22
+
23
+ from models.common import LoadPretrainedBase
24
+ from utils.config import register_omegaconf_resolvers
25
+
26
+
27
+ # -----------------------------
28
+ # Logging
29
+ # -----------------------------
30
+ logging.basicConfig(
31
+ level=logging.INFO,
32
+ format="%(asctime)s - %(levelname)s - %(message)s"
33
+ )
34
+ logger = logging.getLogger("mmedit_space")
35
+
36
+ register_omegaconf_resolvers()
37
+
38
+
39
+ # ---------------------------------------------------------
40
+ # HF Repo IDs(按你的默认需求)
41
+ # ---------------------------------------------------------
42
+ MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
43
+ MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
44
+
45
+ QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
46
+ QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
47
+
48
+ OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
49
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
50
+
51
+ USE_AMP = os.environ.get("USE_AMP", "0") == "1"
52
+ AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16") # "bf16" or "fp16"
53
+
54
+ _PIPELINE_CACHE: Dict[str, Tuple[LoadPretrainedBase, object, int, torch.device]] = {}
55
+
56
+
57
+ # ---------------------------------------------------------
58
+ # 下载 repo
59
+ # ---------------------------------------------------------
60
def resolve_model_dirs() -> Tuple[Path, Path]:
    """Download (or reuse the cached snapshots of) both model repos.

    Returns:
        repo_root: local directory of the MMEdit repo
            (contains config.yaml / model.safetensors / vae/).
        qwen_root: local directory of the Qwen2-Audio repo.
    """
    logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
    mmedit_dir = snapshot_download(
        repo_id=MMEDIT_REPO_ID,
        revision=MMEDIT_REVISION,
        local_dir=None,
        local_dir_use_symlinks=False,
    )

    logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID} (revision={QWEN_REVISION})")
    qwen_dir = snapshot_download(
        repo_id=QWEN_REPO_ID,
        revision=QWEN_REVISION,
        local_dir=None,
        local_dir_use_symlinks=False,
    )

    return Path(mmedit_dir).resolve(), Path(qwen_dir).resolve()
85
+
86
+
87
+ # ---------------------------------------------------------
88
+ # 你的音频加载(按你要求:orig -> 16k -> target_sr)
89
+ # ---------------------------------------------------------
90
def load_and_process_audio(audio_path: str, target_sr: int) -> torch.Tensor:
    """Load audio as a mono 1-D float tensor resampled to ``target_sr``.

    Resampling deliberately goes through an intermediate 16 kHz stage
    (orig_sr -> 16k -> target_sr) to mirror the training preprocessing.

    Raises:
        FileNotFoundError: if ``audio_path`` does not exist.
    """
    path = Path(audio_path)
    if not path.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    waveform, orig_sr = torchaudio.load(str(path))  # (C, T)

    # Down-mix to mono.
    if waveform.ndim == 2:
        waveform = waveform.mean(dim=0)  # (T,)
    elif waveform.ndim > 2:
        waveform = waveform.reshape(-1)

    if target_sr and int(target_sr) != int(orig_sr):
        samples = waveform.cpu().numpy()

        # Stage 1: bring everything to 16 kHz first.
        mid_sr = 16000
        if int(orig_sr) != mid_sr:
            samples = librosa.resample(
                samples,
                orig_sr=int(orig_sr),
                target_sr=mid_sr
            )
            current_sr = mid_sr
        else:
            current_sr = int(orig_sr)

        # Stage 2: 16 kHz -> target_sr (e.g. 24 kHz).
        if int(target_sr) != current_sr:
            samples = librosa.resample(
                samples,
                orig_sr=current_sr,
                target_sr=int(target_sr)
            )

        waveform = torch.from_numpy(samples)

    return waveform
129
+
130
+
131
+ # ---------------------------------------------------------
132
+ # 校验 repo 结构
133
+ # ---------------------------------------------------------
134
def assert_repo_layout(repo_root: Path) -> None:
    """Fail fast if the downloaded MMEdit repo is missing required files.

    Raises:
        FileNotFoundError: when config.yaml, model.safetensors, the vae/
            directory, or any vae/*.ckpt file is absent.
    """
    required = (
        repo_root / "config.yaml",
        repo_root / "model.safetensors",
        repo_root / "vae",
    )
    for item in required:
        if not item.exists():
            raise FileNotFoundError(f"Missing required path: {item}")

    if not list((repo_root / "vae").glob("*.ckpt")):
        raise FileNotFoundError(f"No .ckpt found under: {repo_root/'vae'}")
147
+
148
+
149
+ # ---------------------------------------------------------
150
+ # 关键:适配你这个 config.yaml 的路径写法
151
+ # ---------------------------------------------------------
152
def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_root: Path) -> None:
    """Rewrite local checkpoint paths in the experiment config for this Space.

    - model.autoencoder.pretrained_ckpt: "ckpt/mmedit/vae/epoch=xx.ckpt"
      is remapped to repo_root/vae/epoch=xx.ckpt.
    - model.content_encoder.text_encoder.model_path is pointed at the
      locally downloaded Qwen2-Audio snapshot.

    Mutates ``exp_cfg`` in place.

    Raises:
        FileNotFoundError: if the remapped VAE checkpoint does not exist.
    """
    # ---- 1) VAE ckpt ----
    vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
    if vae_ckpt:
        vae_ckpt = str(vae_ckpt).replace("\\", "/")

        # Most robust mapping: keep everything from the "vae/" component on.
        # e.g. ckpt/mmedit/vae/epoch=13-step=1000000.ckpt -> vae/epoch=13-step=1000000.ckpt
        marker = vae_ckpt.find("vae/")
        if marker != -1:
            vae_rel = vae_ckpt[marker:]  # truncate from "vae/"
        elif vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
            # Fallback: a bare filename is assumed to live under repo_root/vae/.
            vae_rel = f"vae/{vae_ckpt}"
        else:
            vae_rel = vae_ckpt

        vae_path = (repo_root / vae_rel).resolve()
        exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(vae_path)

        if not vae_path.exists():
            raise FileNotFoundError(
                f"VAE ckpt not found after patch:\n"
                f" original: {vae_ckpt}\n"
                f" patched : {vae_path}\n"
                f"Repo root: {repo_root}\n"
                f"Expected: {repo_root/'vae'/'*.ckpt'}"
            )

    # ---- 2) Qwen2-Audio model_path ----
    # The config references a local ckpt dir; on the Space we substitute the
    # snapshot_download result instead.
    exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
193
+
194
+
195
+ # ---------------------------------------------------------
196
+ # Scheduler(与你 exp_cfg.model.noise_scheduler_name 对齐)
197
+ # ---------------------------------------------------------
198
def build_scheduler(exp_cfg: Dict[str, Any]):
    """Build the DDIM noise scheduler named by ``exp_cfg.model.noise_scheduler_name``."""
    scheduler_name = exp_cfg["model"].get(
        "noise_scheduler_name", "stabilityai/stable-diffusion-2-1"
    )
    return noise_schedulers.DDIMScheduler.from_pretrained(
        scheduler_name, subfolder="scheduler"
    )
202
+
203
+
204
def _amp_ctx(device: torch.device):
    """Return an autocast context honoring the USE_AMP / AMP_DTYPE env settings.

    Autocast is enabled only when USE_AMP is set AND the device is CUDA;
    in every other case a disabled context is returned.
    """
    if not USE_AMP:
        return torch.autocast("cuda", enabled=False)
    if device.type != "cuda":
        return torch.autocast("cpu", enabled=False)
    amp_dtype = torch.bfloat16 if AMP_DTYPE.lower() == "bf16" else torch.float16
    return torch.autocast("cuda", dtype=amp_dtype, enabled=True)
211
+
212
+
213
+ # ---------------------------------------------------------
214
+ # 冷启动:load+cache pipeline
215
+ # ---------------------------------------------------------
216
def load_pipeline() -> Tuple[LoadPretrainedBase, object, int, torch.device]:
    """Cold-start helper: download, build, and cache the inference pipeline.

    Returns (model, scheduler, target_sample_rate, device).  The tuple is
    memoized in the module-level _PIPELINE_CACHE keyed by repo ids and
    revisions, so calls after the first are cheap.
    """
    cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
    if cache_key in _PIPELINE_CACHE:
        return _PIPELINE_CACHE[cache_key]

    repo_root, qwen_root = resolve_model_dirs()
    assert_repo_layout(repo_root)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"repo_root = {repo_root}")
    logger.info(f"device = {device}")
    logger.info(f"qwen_root = {qwen_root}")

    # Load the experiment config, resolve interpolations, and rewrite its
    # checkpoint paths to point at the downloaded snapshots.
    exp_cfg = OmegaConf.load(repo_root / "config.yaml")
    exp_cfg = OmegaConf.to_container(exp_cfg, resolve=True)

    patch_paths_in_exp_config(exp_cfg, repo_root, qwen_root)
    logger.info(f"patched pretrained_ckpt = {exp_cfg['model']['autoencoder'].get('pretrained_ckpt')}")
    logger.info(f"patched qwen model_path = {exp_cfg['model']['content_encoder']['text_encoder'].get('model_path')}")

    # Instantiate the model tree declaratively from the config via hydra.
    model: LoadPretrainedBase = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")

    # Load the main diffusion weights (safetensors state dict).
    ckpt_path = repo_root / "model.safetensors"
    sd = load_file(str(ckpt_path))
    model.load_pretrained(sd)

    model = model.to(device).eval()

    scheduler = build_scheduler(exp_cfg)
    # Sample rate used for both input resampling and output writing.
    target_sr = int(exp_cfg.get("sample_rate", 24000))

    _PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr, device)
    logger.info("Pipeline loaded and cached.")
    return model, scheduler, target_sr, device
250
+
251
+
252
+ # ---------------------------------------------------------
253
+ # 推理:audio + caption -> edited audio
254
+ # ---------------------------------------------------------
255
@torch.no_grad()
def run_edit(
    audio_file: str,
    caption: str,
    num_steps: int,
    guidance_scale: float,
    guidance_rescale: float,
    seed: int,
) -> Tuple[Optional[str], str]:
    """Run audio editing: (input audio, caption) -> path of the edited wav.

    Returns (output_path, status_message); on user error output_path is
    None and the message describes the problem.
    """
    if audio_file is None or not Path(audio_file).exists():
        return None, "Error: please upload an audio file."

    caption = (caption or "").strip()
    if not caption:
        return None, "Error: caption is empty."

    model, scheduler, target_sr, device = load_pipeline()

    # Seed CPU torch and NumPy RNGs for reproducible sampling.
    seed = int(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    wav = load_and_process_audio(audio_file, target_sr=target_sr).to(device)

    batch = {
        "audio_id": [Path(audio_file).stem],
        "content": [{"audio": wav, "caption": caption}],
        "task": ["audio_editing"],
    }

    # Keep these kwargs aligned with the offline infer.config.
    kwargs = {
        "num_steps": int(num_steps),
        "guidance_scale": float(guidance_scale),
        "guidance_rescale": float(guidance_rescale),
        "use_gt_duration": False,
        "mask_time_aligned_content": False,
    }
    kwargs.update(batch)

    t0 = time.time()
    with _amp_ctx(device):
        out = model.inference(scheduler=scheduler, **kwargs)
    dt = time.time() - t0

    # NOTE(review): indexing [0, 0] assumes out is (batch, channel, time)
    # — confirm against model.inference's return contract.
    out_audio = out[0, 0].detach().float().cpu().numpy()
    out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
    sf.write(str(out_path), out_audio, samplerate=target_sr)

    return str(out_path), f"OK | saved={out_path.name} | time={dt:.2f}s | sr={target_sr} | seed={seed}"
305
+
306
+
307
+ # ---------------------------------------------------------
308
+ # UI
309
+ # ---------------------------------------------------------
310
def build_demo():
    """Construct the Gradio Blocks UI: inputs, sampling controls, outputs."""
    with gr.Blocks(title="MMEdit Space Simulator") as demo:
        gr.Markdown("# MMEdit Space 模拟(audio + caption → edited audio)")
        gr.Markdown(
            "点下面的示例即可自动填充音频路径与编辑指令,然后点击 Run Editing。"
        )

        with gr.Row():
            with gr.Column():
                audio_in = gr.Audio(label="Input Audio", type="filepath")
                caption = gr.Textbox(label="Caption (Edit Instruction)", lines=3)

                # One-click example: selecting it fills audio_in + caption together.
                gr.Examples(
                    label="example inputs",
                    examples=[
                        ["example/Ym8O802VvJes.wav", "Mix in dog barking in the middle."],
                    ],
                    inputs=[audio_in, caption],
                    cache_examples=False,  # more robust locally and on Spaces: no pre-caching
                )

                with gr.Row():
                    num_steps = gr.Slider(1, 100, value=50, step=1, label="num_steps")
                    guidance_scale = gr.Slider(1.0, 12.0, value=5.0, step=0.5, label="guidance_scale")

                with gr.Row():
                    guidance_rescale = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="guidance_rescale")
                    seed = gr.Number(value=42, precision=0, label="seed")

                run_btn = gr.Button("Run Editing", variant="primary")

            with gr.Column():
                audio_out = gr.Audio(label="Edited Audio", type="filepath")
                status = gr.Textbox(label="Status")

        run_btn.click(
            fn=run_edit,
            inputs=[audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed],
            outputs=[audio_out, status],
        )

        gr.Markdown(
            "## 注意事项\n"
            "- 首次加载较慢\n"
            "- Space 上有一些bug,某些情况会损失原始音频\n"
        )

    return demo
359
 
 
 
 
 
360
 
361
 
362
if __name__ == "__main__":
    # Spaces serve on port 7860 by default; allow an override via $PORT.
    app = build_demo()
    serve_port = int(os.environ.get("PORT", "7860"))
    app.launch(server_name="0.0.0.0", server_port=serve_port, share=False)
example/content.jsonl ADDED
@@ -0,0 +1 @@
 
 
1
+ {"audio_id": "add_audiocaps_1", "content": "example/Ym8O802VvJes.wav", "caption": "Mix in dog barking in the middle."}
losses/base.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class IndentityWrapper(nn.Module):
    """Wrap a raw loss tensor into the {"loss": ...} dict the trainer expects.

    NOTE(review): the class name is a typo for "IdentityWrapper"; it is kept
    unchanged because configs may instantiate it by name.
    """

    def forward(self, loss: torch.Tensor) -> dict[str, torch.Tensor]:
        return {"loss": loss}
8
+
9
+
10
class LossSumWrapper(nn.Module):
    """Combine named losses into a weighted total.

    The output dict carries "loss" (the weighted sum) plus every entry of
    the input dict, so individual terms can still be logged.
    """

    def __init__(self, weights: dict[str, float]):
        super().__init__()
        self.weights = weights

    def forward(self,
                loss_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # sum() starts from 0 and accumulates in insertion order,
        # matching a manual loop.
        total = sum(
            value * self.weights[name] for name, value in loss_dict.items()
        )
        output = {"loss": total}
        output.update(loss_dict)
        return output
models/__pycache__/common.cpython-310.pyc ADDED
Binary file (3.32 kB). View file
 
models/__pycache__/content_adapter.cpython-310.pyc ADDED
Binary file (12 kB). View file
 
models/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (9.77 kB). View file
 
models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc ADDED
Binary file (1.04 kB). View file
 
models/autoencoder/autoencoder_base.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod, ABC
2
+ from typing import Sequence
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
class AutoEncoderBase(ABC):
    """Interface for waveform autoencoders that emit fixed-rate latents.

    Attributes:
        downsampling_ratio: waveform samples per latent frame.
        sample_rate: waveform sample rate in Hz.
        latent_token_rate: latent frames per second (sample_rate // ratio).
        latent_shape: per-sample latent shape, with None marking the
            variable-length time axis.
        time_dim: index of the time axis in a *batched* latent tensor.
    """
    def __init__(
        self, downsampling_ratio: int, sample_rate: int,
        latent_shape: Sequence[int | None]
    ):
        self.downsampling_ratio = downsampling_ratio
        self.sample_rate = sample_rate
        self.latent_token_rate = sample_rate // downsampling_ratio
        self.latent_shape = latent_shape
        # +1 because the first dim of a batched tensor is the batch axis.
        self.time_dim = latent_shape.index(None) + 1  # the first dim is batch

    @abstractmethod
    def encode(
        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        ...
models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc ADDED
Binary file (13.4 kB). View file
 
models/autoencoder/waveform/dac.py ADDED
File without changes
models/autoencoder/waveform/stable_vae.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Literal, Callable
2
+ import math
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.utils import weight_norm
8
+ import torchaudio
9
+ from alias_free_torch import Activation1d
10
+
11
+ from models.common import LoadPretrainedBase
12
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
13
+ from utils.torch_utilities import remove_key_prefix_factory, create_mask_from_length
14
+
15
+
16
# jit script make it 1.4x faster and save GPU memory
@torch.jit.script
def snake_beta(x, alpha, beta):
    """Snake-beta activation: x + sin^2(alpha * x) / beta (guarded against beta == 0)."""
    sin_term = torch.sin(x * alpha)
    return x + sin_term * sin_term * (1.0 / (beta + 0.000000001))
20
+
21
+
22
class SnakeBeta(nn.Module):
    """Snake-beta activation with learnable per-channel alpha (frequency)
    and beta (magnitude) parameters.

    When ``alpha_logscale`` is True, parameters are stored in log space
    (initialized to zero) and exponentiated in forward; otherwise they are
    stored linearly (initialized to one).
    """

    def __init__(
        self,
        in_features,
        alpha=1.0,
        alpha_trainable=True,
        alpha_logscale=True
    ):
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log-scale parameters start at zero, linear-scale at one.
        init = torch.zeros if alpha_logscale else torch.ones
        self.alpha = nn.Parameter(init(in_features) * alpha)
        self.beta = nn.Parameter(init(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

    def forward(self, x):
        # Broadcast (C,) -> (1, C, 1) to line up with x of shape [B, C, T].
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        return snake_beta(x, alpha, beta)
59
+
60
+
61
def WNConv1d(*args, **kwargs):
    """Conv1d wrapped with weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
63
+
64
+
65
def WNConvTranspose1d(*args, **kwargs):
    """ConvTranspose1d wrapped with weight normalization."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
67
+
68
+
69
def get_activation(
    activation: Literal["elu", "snake", "none"],
    antialias=False,
    channels=None
) -> nn.Module:
    """Build the requested activation, optionally wrapped for anti-aliasing.

    Args:
        activation: one of "elu", "snake", "none".
        antialias: wrap the activation in an alias-free ``Activation1d``.
        channels: channel count, required by the "snake" activation.

    Raises:
        ValueError: for any other ``activation`` value.
    """
    if activation == "snake":
        act: nn.Module = SnakeBeta(channels)
    elif activation == "none":
        act = nn.Identity()
    elif activation == "elu":
        act = nn.ELU()
    else:
        raise ValueError(f"Unknown activation {activation}")

    return Activation1d(act) if antialias else act
87
+
88
+
89
class ResidualUnit(nn.Module):
    """Dilated residual block: act -> dilated 7-tap conv -> act -> 1x1 conv,
    with an identity skip connection.

    The module layout (a single ``layers`` Sequential) must stay stable:
    its names define the checkpoint state-dict keys.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        dilation,
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        self.dilation = dilation

        # "same" padding for kernel_size=7 at this dilation.
        padding = (dilation * (7 - 1)) // 2

        self.layers = nn.Sequential(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=out_channels
            ),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=7,
                dilation=dilation,
                padding=padding
            ),
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=out_channels
            ),
            WNConv1d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1
            )
        )

    def forward(self, x):
        res = x

        #x = checkpoint(self.layers, x)
        x = self.layers(x)

        return x + res
136
+
137
+
138
class EncoderBlock(nn.Module):
    """Encoder stage: three ResidualUnits (dilations 1, 3, 9) followed by an
    activation and a strided downsampling convolution."""
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        self.layers = nn.Sequential(
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=1,
                use_snake=use_snake
            ),
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=3,
                use_snake=use_snake
            ),
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=9,
                use_snake=use_snake
            ),
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels
            ),
            # Downsample by `stride`; kernel 2*stride keeps window overlap.
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2)
            ),
        )

    def forward(self, x):
        return self.layers(x)
184
+
185
+
186
class DecoderBlock(nn.Module):
    """Decoder stage: activation, upsampling (nearest+conv or transposed
    conv), then three ResidualUnits (dilations 1, 3, 9)."""
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False
    ):
        super().__init__()

        if use_nearest_upsample:
            # Nearest-neighbor upsample + conv avoids transposed-conv
            # checkerboard artifacts.
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=2 * stride,
                    stride=1,
                    bias=False,
                    padding='same'
                )
            )
        else:
            upsample_layer = WNConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2)
            )

        self.layers = nn.Sequential(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels
            ),
            upsample_layer,
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=1,
                use_snake=use_snake
            ),
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=3,
                use_snake=use_snake
            ),
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=9,
                use_snake=use_snake
            ),
        )

    def forward(self, x):
        return self.layers(x)
248
+
249
+
250
class OobleckEncoder(nn.Module):
    """Convolutional waveform encoder (Stability "Oobleck" architecture).

    Channel widths are ``c_mults[i] * channels``; each EncoderBlock
    downsamples time by the matching entry of ``strides``.  A final conv
    projects to ``latent_dim`` channels.
    """
    def __init__(
        self,
        in_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        # Rebinding (not mutating) the default list, so the shared mutable
        # default is safe here.
        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        # Input projection: kernel 7 with "same" padding.
        layers = [
            WNConv1d(
                in_channels=in_channels,
                out_channels=c_mults[0] * channels,
                kernel_size=7,
                padding=3
            )
        ]

        for i in range(self.depth - 1):
            layers += [
                EncoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i + 1] * channels,
                    stride=strides[i],
                    use_snake=use_snake
                )
            ]

        # Final activation + projection into the latent space.
        layers += [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=c_mults[-1] * channels
            ),
            WNConv1d(
                in_channels=c_mults[-1] * channels,
                out_channels=latent_dim,
                kernel_size=3,
                padding=1
            )
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
304
+
305
+
306
class OobleckDecoder(nn.Module):
    """Convolutional waveform decoder mirroring OobleckEncoder.

    Walks ``c_mults`` in reverse, upsampling by the matching ``strides``
    entries, and ends with a conv back to ``out_channels`` waveform
    channels (optionally squashed by tanh).
    """
    def __init__(
        self,
        out_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False,
        final_tanh=True
    ):
        super().__init__()

        # Rebinding (not mutating) the default list keeps the shared
        # mutable default safe.
        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        # Latent projection up to the widest channel count.
        layers = [
            WNConv1d(
                in_channels=latent_dim,
                out_channels=c_mults[-1] * channels,
                kernel_size=7,
                padding=3
            ),
        ]

        # Upsample from widest to narrowest, mirroring the encoder.
        for i in range(self.depth - 1, 0, -1):
            layers += [
                DecoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i - 1] * channels,
                    stride=strides[i - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample
                )
            ]

        layers += [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=c_mults[0] * channels
            ),
            WNConv1d(
                in_channels=c_mults[0] * channels,
                out_channels=out_channels,
                kernel_size=7,
                padding=3,
                bias=False
            ),
            # Clamp output to [-1, 1] when requested.
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
366
+
367
+
368
class Bottleneck(nn.Module):
    """Base class for latent bottlenecks (VAE, VQ, ...).

    Subclasses implement ``encode``/``decode``; ``is_discrete`` tells
    callers whether the bottleneck emits discrete tokens.
    """

    def __init__(self, is_discrete: bool = False):
        super().__init__()
        self.is_discrete = is_discrete

    def encode(self, x, return_info=False, **kwargs):
        raise NotImplementedError

    def decode(self, x):
        raise NotImplementedError
379
+
380
+
381
@torch.jit.script
def vae_sample(mean, scale) -> dict[str, torch.Tensor]:
    """Reparameterized sample from N(mean, softplus(scale)^2) plus its KL term.

    Returns a dict with "latents" (same shape as mean) and "kl", the KL
    divergence to the standard normal summed over dim 1 and averaged over
    the batch (no 1/2 factor, matching the original training objective).
    """
    stdev = nn.functional.softplus(scale) + 1e-4
    noise = torch.randn_like(mean)
    latents = mean + noise * stdev
    var = stdev * stdev
    kl = (mean * mean + var - torch.log(var) - 1).sum(1).mean()
    return {"latents": latents, "kl": kl}
390
+
391
+
392
class VAEBottleneck(Bottleneck):
    """Continuous VAE bottleneck: splits the channel dim into (mean, scale)
    and draws a reparameterized sample via ``vae_sample``."""

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self,
               x,
               return_info=False,
               **kwargs) -> dict[str, torch.Tensor] | torch.Tensor:
        # First half of dim 1 is the mean, second half the (pre-softplus) scale.
        mean, scale = x.chunk(2, dim=1)
        sampled = vae_sample(mean, scale)
        if return_info:
            return sampled["latents"], {"kl": sampled["kl"]}
        return sampled["latents"]

    def decode(self, x):
        # Latents pass straight through on the decode side.
        return x
410
+
411
+
412
def compute_mean_kernel(x, y):
    """Mean RBF kernel value over all row pairs of x and y."""
    pairwise = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
    return torch.exp(-pairwise).mean()
415
+
416
+
417
class Pretransform(nn.Module):
    """Base class for fixed transforms applied around the autoencoder.

    Subclasses fill in ``encode``/``decode`` (and ``tokenize`` /
    ``decode_tokens`` for discrete transforms).  ``enable_grad`` records
    whether gradients should flow through the transform; subclasses set
    ``encoded_channels`` and ``downsampling_ratio``.
    """

    def __init__(self, enable_grad, io_channels, is_discrete):
        super().__init__()

        self.is_discrete = is_discrete
        self.io_channels = io_channels
        self.enable_grad = enable_grad

        # Filled in by concrete subclasses.
        self.encoded_channels = None
        self.downsampling_ratio = None

    def encode(self, x):
        raise NotImplementedError

    def decode(self, z):
        raise NotImplementedError

    def tokenize(self, x):
        raise NotImplementedError

    def decode_tokens(self, tokens):
        raise NotImplementedError
439
+
440
+
441
class StableVAE(LoadPretrainedBase, AutoEncoderBase):
    """Stable-Audio-style waveform VAE: an encoder/decoder pair with an
    optional stochastic bottleneck, plus checkpoint loading that strips the
    ``autoencoder.`` key prefix used by the pretrained checkpoints.

    Args:
        encoder: maps waveforms to bottleneck inputs.
        decoder: maps latents [B, latent_dim, T'] back to waveforms.
        latent_dim: channel dimension of the latent sequence.
        downsampling_ratio: waveform samples per latent frame.
        sample_rate: audio sample rate the model operates at.
        io_channels: default channel count for both input and output.
        bottleneck: optional bottleneck (e.g. ``VAEBottleneck``).
        pretransform: optional transform applied around the autoencoder.
        in_channels / out_channels: optional asymmetric overrides.
        soft_clip: stored flag (not applied inside this class).
        pretrained_ckpt: checkpoint path to load on construction.
    """
    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: Bottleneck = None,
        pretransform: Pretransform = None,
        in_channels=None,
        out_channels=None,
        soft_clip=False,
        pretrained_ckpt: str | Path = None
    ):
        LoadPretrainedBase.__init__(self)
        AutoEncoderBase.__init__(
            self,
            downsampling_ratio=downsampling_ratio,
            sample_rate=sample_rate,
            latent_shape=(latent_dim, None)
        )

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels
        self.out_channels = io_channels
        # Shortest waveform that yields at least one latent frame.
        self.min_length = self.downsampling_ratio

        if in_channels is not None:
            self.in_channels = in_channels

        if out_channels is not None:
            self.out_channels = out_channels

        self.bottleneck = bottleneck
        self.encoder = encoder
        self.decoder = decoder
        self.pretransform = pretransform
        self.soft_clip = soft_clip
        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete

        # Pretrained checkpoints store every weight under "autoencoder.*".
        self.remove_autoencoder_prefix_fn: Callable = remove_key_prefix_factory(
            "autoencoder."
        )
        if pretrained_ckpt is not None:
            self.load_pretrained(pretrained_ckpt)

    def process_state_dict(self, model_dict, state_dict):
        """Unwrap the Lightning-style ``{"state_dict": ...}`` container and
        strip the ``autoencoder.`` prefix so keys match this module."""
        state_dict = state_dict["state_dict"]
        state_dict = self.remove_autoencoder_prefix_fn(model_dict, state_dict)
        return state_dict

    def encode(
        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor,
        pad_latent_len: int = 500
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode waveforms to latents zero-padded to ``pad_latent_len``.

        Args:
            waveform: input audio batch.
            waveform_lengths: per-sample valid lengths in samples.
            pad_latent_len: fixed latent length to pad up to.

        Returns:
            ``(z, z_mask)``: padded latents [B, C, pad_latent_len] and a
            mask of valid latent frames.
        """
        z = self.encoder(waveform)
        z = self.bottleneck.encode(z)
        z_length = waveform_lengths // self.downsampling_ratio
        z_mask = create_mask_from_length(z_length, max_length=pad_latent_len)

        B, C, L = z.shape
        if L < pad_latent_len:
            pad_size = pad_latent_len - L
            z = torch.cat(
                [z, torch.zeros(B, C, pad_size, device=z.device, dtype=z.dtype)],
                dim=-1
            )
        return z, z_mask

    def decode(self, latents: torch.Tensor, latent_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Decode latents back to waveforms.

        Args:
            latents: [B, C, T_latent]
            latent_mask: optional [B, T_latent]; 1 marks valid frames,
                0 marks padding.
        """
        if latent_mask is None:
            return self.decoder(latents)

        # Decode each sample on its valid frames only, then re-batch.
        # NOTE(review): torch.cat requires all decoded waveforms to have the
        # same length, i.e. equal valid lengths across the batch — confirm
        # callers guarantee this.
        outputs = []
        for b in range(latents.size(0)):
            valid_idx = latent_mask[b].bool()
            valid_latents = latents[b, :, valid_idx]  # [C, T_valid]
            outputs.append(self.decoder(valid_latents.unsqueeze(0)))
        return torch.cat(outputs, dim=0)
526
+
527
+
528
+
529
class StableVAEProjectorWrapper(nn.Module):
    """Uses a frozen ``StableVAE`` as a feature extractor and projects its
    latent channels to ``embed_dim``.

    Only ``self.proj`` is trainable here; the VAE runs under ``no_grad``.
    """
    def __init__(
        self,
        vae_dim: int,
        embed_dim: int,
        model: StableVAE | None = None,
    ):
        super().__init__()
        self.model = model
        self.proj = nn.Linear(vae_dim, embed_dim)

    def forward(
        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        """Encode ``waveform`` with the frozen VAE and project the latents.

        Returns:
            dict with "output" ([B, T_latent, embed_dim]) and "mask"
            ([B, T_latent] validity mask from the VAE encoder).
        """
        self.model.eval()
        with torch.no_grad():
            z, z_mask = self.model.encode(waveform, waveform_lengths, pad_latent_len=500)
        # [B, C, T] -> [B, T, C] so the linear layer projects channels.
        z = self.proj(z.transpose(1, 2))
        return {"output": z, "mask": z_mask}
548
+
549
+
550
if __name__ == '__main__':
    # Smoke test: build the autoencoder from the hydra config, round-trip
    # one audio file through encode/decode, and write both versions to disk.
    import hydra
    from utils.config import generate_config_from_command_line_overrides
    model_config = generate_config_from_command_line_overrides(
        "../../../configs"
    )
    autoencoder: StableVAE = hydra.utils.instantiate(model_config)
    autoencoder.eval()

    waveform, sr = torchaudio.load(
        "/edit/syn_7.wav"
    )
    # Downmix to mono and resample to the model's sample rate.
    waveform = waveform.mean(0, keepdim=True)
    waveform = torchaudio.functional.resample(
        waveform, sr, model_config["sample_rate"]
    )
    import soundfile as sf
    sf.write(
        "./torch_test.wav",
        waveform[0].numpy(),
        samplerate=model_config["sample_rate"]
    )
    print("waveform: ", waveform.shape)
    with torch.no_grad():
        # encode returns (latents, latent_mask); the mask is then reused as
        # the latent_mask argument of decode below.
        latent, latent_length = autoencoder.encode(
            waveform, torch.as_tensor([waveform.shape[-1]])
        )
        print("latent: ", latent.shape)
        print("latent_length: ", latent_length)
        reconstructed = autoencoder.decode(latent, latent_length)
        print("reconstructed: ", reconstructed.shape)

    sf.write(
        "./reconstructed.wav",
        reconstructed[0, 0].numpy(),
        samplerate=model_config["sample_rate"]
    )
models/common.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Sequence
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from utils.torch_utilities import (
8
+ load_pretrained_model, merge_matched_keys, create_mask_from_length,
9
+ loss_with_mask, create_alignment_path
10
+ )
11
+
12
+
13
class LoadPretrainedBase(nn.Module):
    """Mixin that loads pretrained checkpoints with a customizable
    state-dict preprocessing hook."""
    def process_state_dict(
        self, model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor]
    ):
        """
        Custom processing function of each model that transforms `state_dict`
        loaded from checkpoints into a form usable by `load_state_dict`.
        By default, uses `merge_matched_keys` to update only parameters whose
        names and shapes match the current model.

        Args
            model_dict:
                The state dict of the current model, which is going to load
                pretrained parameters.
            state_dict:
                A dictionary of parameters from a pre-trained model.

        Returns:
            dict[str, torch.Tensor]:
                The updated state dict, where parameters with matched keys
                and shapes are updated with values in `state_dict`.
        """
        state_dict = merge_matched_keys(model_dict, state_dict)
        return state_dict

    def load_pretrained(self, ckpt_path: str | Path):
        """Load parameters from ``ckpt_path``, routed through
        ``process_state_dict`` for key/shape reconciliation."""
        load_pretrained_model(
            self, ckpt_path, state_dict_process_fn=self.process_state_dict
        )
42
+
43
+
44
class CountParamsBase(nn.Module):
    """Mixin that reports parameter counts for a module."""

    def count_params(self):
        """Return ``(total, trainable)`` parameter counts."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        return total, trainable
53
+
54
+
55
class SaveTrainableParamsBase(nn.Module):
    """Mixin for checkpoints that keep only trainable parameters (plus all
    buffers), with a matching strictness check on load."""

    @property
    def param_names_to_save(self):
        """Names of trainable parameters followed by all buffer names."""
        trainable = [
            name for name, param in self.named_parameters()
            if param.requires_grad
        ]
        buffers = [name for name, _ in self.named_buffers()]
        return trainable + buffers

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict``, verifying every saved-name is present.

        Raises when ``strict`` and any trainable/buffer name is absent;
        otherwise warns and delegates to ``nn.Module.load_state_dict``.
        """
        missing_keys = [
            key for key in self.param_names_to_save if key not in state_dict
        ]

        if missing_keys:
            if strict:
                raise Exception(
                    f"{missing_keys} not found in either pre-trained models (e.g. BERT) or resumed checkpoints (e.g. epoch_40/model.pt)"
                )
            print(f"Warning: missing keys {missing_keys}, skipping them.")

        return super().load_state_dict(state_dict, strict)
models/content_adapter.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Any
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from utils.torch_utilities import concat_non_padding, restore_from_concat, create_mask_from_length
7
+ from models.content_encoder.content_encoder import ContentEncoder
8
+
9
+
10
+ ######################
11
+ # fastspeech modules
12
+ ######################
13
class LayerNorm(nn.LayerNorm):
    """Layer normalization over a configurable dimension.

    :param int nout: size of the normalized dimension
    :param int dim: dimension to normalize (``-1`` for the last)
    """
    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object (eps fixed at 1e-12)."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization over ``self.dim`` and return the result."""
        if self.dim != -1:
            # Swap the target dim to the end, normalize, swap back.
            return super(LayerNorm,
                         self).forward(x.transpose(1, -1)).transpose(1, -1)
        return super(LayerNorm, self).forward(x)


class DurationPredictor(nn.Module):
    """FastSpeech-style duration predictor: masked Conv1d blocks followed by
    a per-frame linear projection.

    Args:
        in_channels: input feature size.
        filter_channels: hidden channels of every conv block.
        n_layers: number of conv blocks.
        kernel_size: convolution kernel size.
        p_dropout: dropout probability inside each block.
        padding: "SAME" for centered padding, anything else for causal
            (left-only) padding.
    """
    def __init__(
        self,
        in_channels: int,
        filter_channels: int,
        n_layers: int = 2,
        kernel_size: int = 3,
        p_dropout: float = 0.1,
        padding: str = "SAME"
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        if padding == 'SAME':
            pad = ((kernel_size - 1) // 2, (kernel_size - 1) // 2)
        else:
            # Causal: all padding on the left so no future frames leak in.
            pad = (kernel_size - 1, 0)
        self.conv = nn.ModuleList()
        for layer_idx in range(n_layers):
            c_in = in_channels if layer_idx == 0 else filter_channels
            self.conv.append(
                nn.Sequential(
                    nn.ConstantPad1d(pad, 0),
                    nn.Conv1d(
                        c_in,
                        filter_channels,
                        kernel_size,
                        stride=1,
                        padding=0
                    ),
                    nn.ReLU(),
                    LayerNorm(filter_channels, dim=1),
                    nn.Dropout(p_dropout),
                )
            )
        self.linear = nn.Linear(filter_channels, 1)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
        """Predict one duration value per frame.

        Args:
            x: [B, T, E] frame features.
            x_mask: [B, T] validity mask (1 = valid).

        Returns:
            [B, T, 1] durations, exactly zero at masked positions.
        """
        h = x.transpose(1, -1)  # [B, E, T] for Conv1d
        chan_mask = x_mask.unsqueeze(1).to(h.device)  # [B, 1, T]
        for block in self.conv:
            h = block(h)
            h = h * chan_mask.float()

        out = self.linear(h.transpose(1, -1)
                          ) * chan_mask.transpose(1, -1).float()  # [B, T, 1]
        return out
+ return x
80
+
81
+
82
+ ######################
83
+ # adapter modules
84
+ ######################
85
+
86
+
87
class ContentAdapterBase(nn.Module):
    """Common base for content adapters; records the output embedding size."""

    def __init__(self, d_out):
        super().__init__()
        # Dimension of the adapted content produced by subclasses.
        self.d_out = d_out
+
92
+
93
class SinusoidalPositionalEmbedding(nn.Module):
    """Additive sinusoidal positional encoding for batch-first inputs.

    Adds the standard Transformer sin/cos position table to ``x`` of shape
    [B, T, D], then applies dropout.

    Args:
        d_model: embedding dimension D.
        dropout: dropout probability applied after the addition.
        max_len: maximum supported sequence length.
    """
    def __init__(self, d_model, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Store as [1, max_len, d_model] so it broadcasts over the batch dim
        # of batch-first inputs.  (The previous [max_len, 1, d_model] layout
        # only broadcast for sequence-first tensors, while forward indexes
        # with x.size(1) — the batch-first time axis.)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [B, T, D]
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
+ return self.dropout(x)
111
+
112
+
113
class ContentAdapter(ContentAdapterBase):
    """Transformer content adapter with a prepended [CLS] token.

    The [CLS] position yields a global duration prediction; the remaining
    positions yield per-token content embeddings and local durations.
    """
    def __init__(
        self,
        d_model: int,
        d_out: int,
        num_layers: int,
        num_heads: int,
        duration_predictor: DurationPredictor,
        dropout: float = 0.1,
        norm_first: bool = False,
        activation: str = "gelu",
        duration_grad_scale: float = 0.0,
    ):
        super().__init__(d_out)
        self.duration_grad_scale = duration_grad_scale
        self.cls_embed = nn.Parameter(torch.randn(d_model))
        # Nested-tensor fast path is not supported on Ascend NPU builds.
        use_nested = not (hasattr(torch, "npu") and torch.npu.is_available())
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation=activation,
            norm_first=norm_first,
            batch_first=True
        )
        self.encoder_layers = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=use_nested
        )
        self.duration_predictor = duration_predictor
        self.content_proj = nn.Conv1d(d_model, d_out, 1)

    def forward(self, x, x_mask):
        """Return (content, content_mask, global_duration, local_durations)."""
        bsz = x.size(0)
        # Prepend a learned [CLS] token (and a matching always-valid mask).
        cls_tok = self.cls_embed.view(1, 1, -1).expand(bsz, 1, -1).to(x.device)
        x = torch.cat([cls_tok, x], dim=1)
        x_mask = torch.cat(
            [torch.ones(bsz, 1, device=x_mask.device), x_mask], dim=1
        )
        x = self.encoder_layers(x, src_key_padding_mask=~x_mask.bool())
        # Attenuate gradients from the duration head into the encoder.
        scale = self.duration_grad_scale
        x_for_duration = x * scale + x.detach() * (1 - scale)
        duration = self.duration_predictor(x_for_duration, x_mask).squeeze(-1)
        content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
        # Split off [CLS]: global duration vs per-token outputs.
        return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
+ return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
164
+
165
+
166
class PrefixAdapter(ContentAdapterBase):
    """Adapter that packs instruction ("prefix") embeddings in front of
    content tokens (plus a [CLS] token), encodes the packed sequence with a
    Transformer, and predicts a global duration (from [CLS]) plus per-token
    local durations.

    Args:
        content_dim: raw content embedding size.
        d_model: Transformer model dimension.
        d_out: output content embedding size.
        prefix_dim: raw instruction embedding size.
        num_layers / num_heads: Transformer depth and heads.
        duration_predictor: per-frame duration module.
        dropout: dropout in the MLPs and Transformer.
        norm_first: pre-norm vs post-norm Transformer layers.
        use_last_norm: apply a final LayerNorm after the encoder.
        activation: Transformer feed-forward activation.
        duration_grad_scale: fraction of duration-loss gradient allowed to
            flow back into the shared features.
    """
    def __init__(
        self,
        content_dim: int,
        d_model: int,
        d_out: int,
        prefix_dim: int,
        num_layers: int,
        num_heads: int,
        duration_predictor: DurationPredictor,
        dropout: float = 0.1,
        norm_first: bool = False,
        use_last_norm: bool = True,
        activation: str = "gelu",
        duration_grad_scale: float = 0.1,
    ):
        super().__init__(d_out)
        self.duration_grad_scale = duration_grad_scale
        # Project instruction (prefix) and content streams into d_model.
        self.prefix_mlp = nn.Sequential(
            nn.Linear(prefix_dim, d_model), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_model, d_model)
        )
        self.content_mlp = nn.Sequential(
            nn.Linear(content_dim, d_model), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_model, d_model)
        )
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=norm_first
        )
        # Nested-tensor fast path is unavailable on Ascend NPU builds.
        if hasattr(torch, "npu") and torch.npu.is_available():
            enable_nested_tensor = False
        else:
            enable_nested_tensor = True
        self.cls_embed = nn.Parameter(torch.randn(d_model))
        # self.pos_embed = SinusoidalPositionalEmbedding(d_model, dropout)
        self.layers = nn.TransformerEncoder(
            encoder_layer=layer,
            num_layers=num_layers,
            enable_nested_tensor=enable_nested_tensor
        )
        self.use_last_norm = use_last_norm
        if self.use_last_norm:
            self.last_norm = nn.LayerNorm(d_model)
        self.duration_predictor = duration_predictor
        self.content_proj = nn.Conv1d(d_model, d_out, 1)
        nn.init.normal_(self.cls_embed, 0., 0.02)
        nn.init.xavier_uniform_(self.content_proj.weight)
        nn.init.constant_(self.content_proj.bias, 0.)

    def forward(self, content, content_mask, instruction, instruction_mask):
        """Fuse instruction and content; return
        (content, content_mask, global_duration, local_durations)."""
        batch_size = content.size(0)
        # Prepend a learned [CLS] token to the (projected) content stream.
        cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1)
        cls_embed = cls_embed.to(content.device).unsqueeze(1)
        content = self.content_mlp(content)
        x = torch.cat([cls_embed, content], dim=1)

        cls_mask = torch.ones(batch_size, 1,
                              dtype=bool).to(content_mask.device)
        x_mask = torch.cat([cls_mask, content_mask], dim=1)

        # Pack prefix and content, dropping padding between the two streams;
        # `perm` records how to undo the packing after encoding.
        prefix = self.prefix_mlp(instruction)
        seq, seq_mask, perm = concat_non_padding(
            prefix, instruction_mask, x, x_mask
        )
        # seq = self.pos_embed(seq)
        x = self.layers(seq, src_key_padding_mask=~seq_mask.bool())
        if self.use_last_norm:
            x = self.last_norm(x)
        # Restore the content part (prefix positions are discarded).
        _, x = restore_from_concat(x, instruction_mask, x_mask, perm)

        # Attenuate gradients from the duration head into the encoder.
        x_grad_rescaled = x * self.duration_grad_scale + x.detach(
        ) * (1 - self.duration_grad_scale)
        duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1)
        content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
        # Split off [CLS]: global duration vs per-token outputs.
        return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
+ return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
246
+
247
+
248
class CrossAttentionAdapter(ContentAdapterBase):
    """Adapts content tokens by cross-attending to a prefix sequence, then
    predicts a global duration and per-token local durations from the
    fused features."""
    def __init__(
        self,
        d_out: int,
        content_dim: int,
        prefix_dim: int,
        num_heads: int,
        duration_predictor: DurationPredictor,
        dropout: float = 0.1,
        duration_grad_scale: float = 0.1,
    ):
        super().__init__(d_out)
        self.attn = nn.MultiheadAttention(
            embed_dim=content_dim,
            num_heads=num_heads,
            dropout=dropout,
            kdim=prefix_dim,
            vdim=prefix_dim,
            batch_first=True,
        )
        self.duration_grad_scale = duration_grad_scale
        self.duration_predictor = duration_predictor
        self.global_duration_mlp = nn.Sequential(
            nn.Linear(content_dim, content_dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(content_dim, 1)
        )
        self.norm = nn.LayerNorm(content_dim)
        self.content_proj = nn.Conv1d(content_dim, d_out, 1)

    def forward(self, content, content_mask, prefix, prefix_mask):
        """Return (content, content_mask, global_duration, local_durations)."""
        fused, _ = self.attn(
            query=content,
            key=prefix,
            value=prefix,
            key_padding_mask=~prefix_mask.bool()
        )
        valid = content_mask.unsqueeze(-1).float()
        fused = fused * valid
        # Residual connection followed by LayerNorm.
        x = self.norm(fused + content)
        # Attenuate gradients flowing from the duration heads.
        g = self.duration_grad_scale
        x_for_duration = x * g + x.detach() * (1 - g)
        # Masked mean-pool over time for the global duration head.
        pooled = (x_for_duration * valid).sum(dim=1) / \
            content_mask.sum(dim=1, keepdim=True).float()
        global_duration = self.global_duration_mlp(pooled).squeeze(-1)
        local_duration = self.duration_predictor(
            x_for_duration, content_mask
        ).squeeze(-1)
        content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
        return content, content_mask, global_duration, local_duration
+ return content, content_mask, global_duration, local_duration
297
+
298
+
299
class ExperimentalCrossAttentionAdapter(ContentAdapterBase):
    """Cross-attention adapter with MLP pre-projections on both streams.

    Content tokens attend to prefix (instruction) tokens; the fused
    features drive a global duration head, a per-token duration
    predictor, and a projected content output.

    Args:
        d_out: output content embedding size.
        content_dim: content (query) feature size.
        prefix_dim: prefix (key/value) feature size.
        num_heads: number of attention heads.
        duration_predictor: per-frame duration module.
        dropout: dropout in the MLPs and attention.
        duration_grad_scale: fraction of duration-loss gradient allowed
            back into the shared features.
    """
    def __init__(
        self,
        d_out: int,
        content_dim: int,
        prefix_dim: int,
        num_heads: int,
        duration_predictor: DurationPredictor,
        dropout: float = 0.1,
        duration_grad_scale: float = 0.1,
    ):
        super().__init__(d_out)
        self.content_mlp = nn.Sequential(
            nn.Linear(content_dim, content_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(content_dim, content_dim),
        )
        self.content_norm = nn.LayerNorm(content_dim)
        self.prefix_mlp = nn.Sequential(
            nn.Linear(prefix_dim, prefix_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(prefix_dim, prefix_dim),
        )
        # Fix: the prefix branch outputs prefix_dim features, so its
        # LayerNorm must be over prefix_dim.  (It was LayerNorm(content_dim),
        # which crashes in forward whenever prefix_dim != content_dim.)
        self.prefix_norm = nn.LayerNorm(prefix_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=content_dim,
            num_heads=num_heads,
            dropout=dropout,
            kdim=prefix_dim,
            vdim=prefix_dim,
            batch_first=True,
        )
        self.duration_grad_scale = duration_grad_scale
        self.duration_predictor = duration_predictor
        self.global_duration_mlp = nn.Sequential(
            nn.Linear(content_dim, content_dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(content_dim, 1)
        )
        self.content_proj = nn.Sequential(
            nn.Linear(content_dim, d_out),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_out, d_out),
        )
        self.norm1 = nn.LayerNorm(content_dim)
        self.norm2 = nn.LayerNorm(d_out)
        self.init_weights()

    def init_weights(self):
        """Xavier-initialize every Linear layer; zero its bias."""
        def _init_weights(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.)

        self.apply(_init_weights)

    def forward(self, content, content_mask, prefix, prefix_mask):
        """Return (content, content_mask, global_duration, local_durations)."""
        content = self.content_mlp(content)
        content = self.content_norm(content)
        prefix = self.prefix_mlp(prefix)
        prefix = self.prefix_norm(prefix)
        attn_output, attn_weights = self.attn(
            query=content,
            key=prefix,
            value=prefix,
            key_padding_mask=~prefix_mask.bool(),
        )
        attn_output = attn_output * content_mask.unsqueeze(-1).float()
        # Residual connection + LayerNorm.
        x = attn_output + content
        x = self.norm1(x)
        # Attenuate gradients from the duration heads into the features.
        x_grad_rescaled = x * self.duration_grad_scale + x.detach(
        ) * (1 - self.duration_grad_scale)
        # Masked mean-pool over time for the global duration head.
        x_aggregated = (x_grad_rescaled * content_mask.unsqueeze(-1).float()
                        ).sum(dim=1) / content_mask.sum(dim=1,
                                                        keepdim=True).float()
        global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
        local_duration = self.duration_predictor(
            x_grad_rescaled, content_mask
        ).squeeze(-1)
        content = self.content_proj(x)
        content = self.norm2(content)
        return content, content_mask, global_duration, local_duration
+ return content, content_mask, global_duration, local_duration
384
+
385
+
386
class ContentEncoderAdapterMixin:
    """Mixin wiring a ContentEncoder and an optional ContentAdapter into a
    host model.

    NOTE(review): this is a plain mixin (not an ``nn.Module``); the class
    that mixes it in must be a Module for the assigned submodules to be
    registered — confirm against the consuming classes.
    """
    def __init__(
        self,
        content_encoder: ContentEncoder,
        content_adapter: ContentAdapterBase | None = None
    ):
        self.content_encoder = content_encoder
        self.content_adapter = content_adapter

    def encode_content(
        self,
        content: list[Any],
        task: list[str],
        device: str | torch.device,
        instruction: torch.Tensor | None = None,
        instruction_lengths: torch.Tensor | None = None
    ):
        """Encode raw content and optionally fuse it with instruction
        embeddings through the adapter (which also predicts durations).

        Args:
            content: per-sample raw content handled by the content encoder.
            task: per-sample task names.
            device: device for created tensors.
            instruction: optional padded instruction embeddings; when given,
                the adapter is applied and duration predictions are returned.
            instruction_lengths: valid lengths for ``instruction``.

        Returns:
            dict with "content", "content_mask", "length_aligned_content",
            plus "global_duration_pred"/"local_duration_pred" when an
            instruction is supplied.
        """
        content_output: dict[
            str, torch.Tensor] = self.content_encoder.encode_content(
                content, task, device=device
            )
        content, content_mask = content_output["content"], content_output[
            "content_mask"]

        if instruction is not None:
            instruction_mask = create_mask_from_length(instruction_lengths)
            (
                content,
                content_mask,
                global_duration_pred,
                local_duration_pred,
            ) = self.content_adapter(
                content, content_mask, instruction, instruction_mask
            )

        return_dict = {
            "content": content,
            "content_mask": content_mask,
            "length_aligned_content": content_output["length_aligned_content"],
        }
        if instruction is not None:
            return_dict["global_duration_pred"] = global_duration_pred
            return_dict["local_duration_pred"] = local_duration_pred

        return return_dict
+ return return_dict
models/content_encoder/__pycache__/content_encoder.cpython-310.pyc ADDED
Binary file (3.46 kB). View file
 
models/content_encoder/__pycache__/llm_encoder.cpython-310.pyc ADDED
Binary file (6.16 kB). View file
 
models/content_encoder/content_encoder.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
class ContentEncoder(nn.Module):
    """Routes raw task content to per-modality encoders and pads the
    per-sample outputs into batch tensors.

    Only the "audio_editing" path is implemented in ``encode_content``:
    ``text_encoder`` (actually an audio-LLM encoder) consumes both the
    caption and the waveform, while ``audio_encoder`` produces the
    length-aligned content stream.
    """
    def __init__(
        self,
        embed_dim: int,
        text_encoder: nn.Module = None,
        llm_encoder: nn.Module = None,
        video_encoder: nn.Module = None,
        midi_encoder: nn.Module = None,
        phoneme_encoder: nn.Module = None,
        pitch_encoder: nn.Module = None,
        audio_encoder: nn.Module = None
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.text_encoder = text_encoder
        self.midi_encoder = midi_encoder
        self.phoneme_encoder = phoneme_encoder
        self.pitch_encoder = pitch_encoder
        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
        # NOTE(review): ``llm_encoder`` is accepted but never stored —
        # confirm whether it should be kept like the other encoders.

    def encode_content(
        self, batch_content: list[Any], batch_task: list[str],
        device: str | torch.device
    ):
        """Encode a batch sample-by-sample and pad the results.

        Args:
            batch_content: per-sample dicts; for "audio_editing" each holds
                an "audio" array and a "caption" string.
            batch_task: task name per sample (only "audio_editing" handled).
            device: device for created tensors.

        Returns:
            dict with padded "content", "content_mask",
            "length_aligned_content", "time_aligned_content_mask".
        """
        batch_content_output = []
        batch_content_mask = []
        batch_la_content_output = []
        batch_la_content_output_mask = []
        # Fallback when an encoder returns no mask (see .get below).
        zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)

        for i, (content, task) in enumerate(zip(batch_content, batch_task)):
            if task == "audio_editing":
                raw_waveform = torch.as_tensor(content["audio"]).float()
                waveform_with_batch_dim = raw_waveform.unsqueeze(0).to(device)
                waveform_lengths = torch.as_tensor([raw_waveform.shape[0]])

                # Note: text encoder actually is audiollm encoder, encode both waveform and caption
                content_output_dict = self.text_encoder(
                    [content["caption"]], waveform_with_batch_dim
                )
                audio_dict = {
                    "waveform": waveform_with_batch_dim,
                    "waveform_lengths": waveform_lengths
                }
                audio_output_dict = self.audio_encoder(**audio_dict)
                la_content_output_dict = {
                    "output": audio_output_dict["output"],
                    "mask": audio_output_dict["mask"]
                }

            # Strip the batch dim; pad_sequence below re-batches.
            batch_content_output.append(content_output_dict["output"][0])
            batch_content_mask.append(content_output_dict["mask"][0])
            batch_la_content_output.append(la_content_output_dict["output"][0])
            # NOTE(review): the .get default is an *embedding* tensor used as
            # a mask fallback — looks suspicious; confirm intended shape.
            batch_la_content_output_mask.append(
                la_content_output_dict.get("mask", zero_la_content)[0]
            )

        batch_content_output = nn.utils.rnn.pad_sequence(
            batch_content_output, batch_first=True, padding_value=0
        )
        batch_content_mask = nn.utils.rnn.pad_sequence(
            batch_content_mask, batch_first=True, padding_value=False
        )
        batch_la_content_output = nn.utils.rnn.pad_sequence(
            batch_la_content_output, batch_first=True, padding_value=0
        )

        batch_la_content_output_mask = nn.utils.rnn.pad_sequence(
            batch_la_content_output_mask, batch_first=True, padding_value=False
        )
        return {
            "content": batch_content_output,
            "content_mask": batch_content_mask,
            "length_aligned_content": batch_la_content_output,
            "time_aligned_content_mask": batch_la_content_output_mask
        }
+
84
+
85
+
86
class BatchedContentEncoder(ContentEncoder):
    """ContentEncoder variant that sends the whole batch of captions and
    waveforms through the text (audio-LLM) encoder in one call, instead of
    looping sample-by-sample."""
    def encode_content(
        self, batch_content: list[dict], batch_task: list[str],
        device: str | torch.device
    ):
        """Batched encoding; only the "audio_editing" task is supported.

        Returns the same dict layout as ``ContentEncoder.encode_content``.
        """
        assert all(task == "audio_editing" for task in batch_task), \
            "BatchedContentEncoder now are only support audio_editing"

        # NOTE(review): defined but unused in this method.
        zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)

        captions = []
        waveforms = []
        waveform_lengths = []
        for content in batch_content:
            raw_waveform = torch.as_tensor(content["audio"]).float().to(device)
            captions.append(content["caption"])
            waveforms.append(raw_waveform)
            waveform_lengths.append(raw_waveform.shape[0])

        # One batched call through the audio-LLM text encoder.
        content_output_dict = self.text_encoder(
            captions, waveforms
        )

        # The audio encoder is still applied per sample, then padded.
        batch_la_content_output = []
        batch_la_content_output_mask = []
        for i in range(len(batch_content)):
            audio_dict = {
                "waveform": waveforms[i].unsqueeze(0),
                "waveform_lengths": torch.as_tensor([waveform_lengths[i]], device=device)
            }
            audio_output_dict = self.audio_encoder(**audio_dict)
            batch_la_content_output.append(audio_output_dict["output"][0])
            batch_la_content_output_mask.append(audio_output_dict["mask"][0])

        # Pad the per-sample audio-encoder outputs into batch tensors.
        batch_la_content_output = nn.utils.rnn.pad_sequence(
            batch_la_content_output, batch_first=True, padding_value=0
        )
        batch_la_content_output_mask = nn.utils.rnn.pad_sequence(
            batch_la_content_output_mask, batch_first=True, padding_value=False
        )

        return {
            "content": content_output_dict["output"],
            "content_mask": content_output_dict["mask"],
            "length_aligned_content": batch_la_content_output,
            "time_aligned_content_mask": batch_la_content_output_mask
        }
+ }
models/content_encoder/llm_encoder.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import librosa
4
+ import numpy as np
5
+ from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
6
+ import os
7
# NOTE(review): not used by this module yet; presumably the
# prompt-enhancement prefix intended for Qwen2-Audio generation — confirm
# against callers before relying on it.
QWEN_AUDIO_PREFIX = '''Given a user prompt and an audio clip, generate an "Enhanced prompt" that provides detailed descriptions suitable for audio generation. Evaluate the audio and user prompt:
- If the prompt is simple, focus on adding specifics about tones, instruments, rhythms, tempos, and audio characteristics to create vivid and concrete audio descriptions.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
Here are examples of how to transform or refine prompts:
- User Prompt: Piano music -> Enhanced: A gentle, melancholic piano piece with delicate arpeggios in a minor key, featuring subtle reverb that creates a sense of space and intimacy.
- User Prompt: City sounds -> Enhanced: A bustling urban soundscape with distant traffic noise, occasional car horns, footsteps on concrete sidewalks, and the murmur of crowd conversations, with subtle pigeons cooing in the background.\n
Please generate only the enhanced description for the audio and prompt below and avoid including any additional commentary or evaluations:
User Prompt:'''
+ User Prompt:'''
16
+
17
class Qwen2AudioEmbedder(nn.Module):
    """Embeds (text, audio) pairs with a frozen Qwen2-Audio backbone.

    Chat-formatted text plus audio are run through the frozen LM; the last
    hidden states are padded/truncated to ``max_length`` frames and
    projected to ``embed_dim`` by a trainable linear layer.
    """
    def __init__(self, model_path, embed_dim=256, max_length=320, dtype=torch.float, device="cuda"):
        super().__init__()
        self.max_length = max_length
        self.device = device
        self.embed_dim = embed_dim

        self.model = Qwen2AudioForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dtype,
            # Pin the backbone to this process's local-rank GPU.
            device_map={"": int(os.environ.get("LOCAL_RANK", 0))}
        )
        # Freeze the backbone: no gradients flow through Qwen2-Audio.
        self.model.requires_grad_(False)
        self.model.eval()
        self.processor = AutoProcessor.from_pretrained(model_path)

        # Projection from the LM hidden size (4096) down to embed_dim.
        # NOTE(review): this layer stays trainable even though the backbone
        # is frozen — confirm it is registered with the optimizer.
        self.proj = nn.Linear(4096, embed_dim, device=device, dtype=dtype)
        self.prefix = QWEN_AUDIO_PREFIX

    def forward(self, text, audio_data):
        """
        Args:
            text: list of text descriptions.
            audio_data: list of audio clips (numpy arrays or tensors).
        Returns:
            dict with "output": embedding tensor, "mask": boolean mask.
        """
        output, mask = self.encode(text, audio_data)
        output = self.projection(output)
        return {"output": output, "mask": mask}

    def encode(self, text, audio_data):
        """Encode text and audio into the embedding space.

        Returns ``(embs, masks)`` where both are truncated or zero/False
        padded to exactly ``self.max_length`` along the sequence dimension.
        """
        batch_size = len(text)

        # Resample every clip to the 16 kHz the processor expects.
        # NOTE(review): assumes all inputs are sampled at 24 kHz — confirm.
        processed_audios = []
        for audio in audio_data:
            if isinstance(audio, torch.Tensor):
                audio = audio.cpu().numpy()
            audio = librosa.resample(audio, orig_sr=24000, target_sr=16000)
            processed_audios.append(audio)

        # Build chat-formatted conversation strings for the whole batch.
        conversations = []
        for txt in text:
            conversation = [
                {"role": "user", "content": [
                    # Audio slot is a placeholder; the processor injects the
                    # real audio features in the batched call below.
                    {"type": "audio", "audio": None},
                    {"type": "text", "text": txt}
                ]}
            ]
            formatted_text = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
            conversations.append(formatted_text)

        with torch.no_grad():
            # One batched call; the processor pads text and audio jointly.
            inputs = self.processor(
                text=conversations,
                audio=processed_audios,
                return_tensors="pt",
                sampling_rate=16000,
                padding=True,
                truncation=True  # never exceed the model's max length
            )

            # Move all processor outputs to the target device.
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            outputs = self.model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                input_features=inputs["input_features"],
                feature_attention_mask=inputs["feature_attention_mask"],
                output_hidden_states=True,
            )

            # Last-layer hidden states for every position.
            hidden_states_full = outputs.hidden_states[-1]

            # Force the output length to exactly self.max_length:
            # 1. truncate or zero-pad the hidden states ...
            current_len = hidden_states_full.shape[1]
            if current_len > self.max_length:
                embs = hidden_states_full[:, :self.max_length, :]
            else:
                pad_width = self.max_length - current_len
                # (batch, pad_width, hidden) zeros appended on the right.
                padding = torch.zeros(
                    hidden_states_full.shape[0],
                    pad_width,
                    hidden_states_full.shape[2],
                    device=self.device,
                    dtype=hidden_states_full.dtype
                )
                embs = torch.cat([hidden_states_full, padding], dim=1)

            # 2. ... and truncate or False-pad the attention mask to match.
            attention_mask = inputs["attention_mask"]
            if current_len > self.max_length:
                masks = attention_mask[:, :self.max_length].bool()
            else:
                pad_width = self.max_length - current_len
                # (batch, pad_width) of False for the padded tail.
                mask_padding = torch.zeros(
                    attention_mask.shape[0],
                    pad_width,
                    device=self.device,
                    dtype=torch.bool
                )
                masks = torch.cat([attention_mask.bool(), mask_padding], dim=1)

        return embs, masks

    def projection(self, x):
        """Project embeddings to the configured output dimension."""
        return self.proj(x)
+ return self.proj(x)
151
+
152
+
153
+
154
+
155
+ if __name__ == "__main__":
156
+ import argparse
157
+
158
+ parser = argparse.ArgumentParser(description="Test Qwen Audio Encoder")
159
+ parser.add_argument("--model_path", type=str, default="/mnt/petrelfs/taoye/workspace/model/qwen25audio",
160
+ help="Path to Qwen Audio model")
161
+ parser.add_argument("--embed_dim", type=int, default=4096,
162
+ help="Target embedding dimension after projection")
163
+ args = parser.parse_args()
164
+
165
+ print(f"Loading model from {args.model_path}...")
166
+
167
+ # 初始化编码器
168
+ device = "cuda" if torch.cuda.is_available() else "cpu"
169
+ embedder = Qwen2AudioEmbedder(
170
+ model_path=args.model_path,
171
+ embed_dim=args.embed_dim,
172
+ max_length=640,
173
+ dtype=torch.float,
174
+ device=device
175
+ )
176
+
177
+ # 准备测试批次
178
+ captions = [
179
+ "Describe this audio",
180
+ "What musical instruments are being played in this recording?"
181
+ ]
182
+
183
+ # 直接加载音频数据
184
+ audio_path = "/mnt/petrelfs/taoye/workspace/editing/data/add/add_fore_audio_caps_begin_1/audio/edit/syn_5.wav"
185
+ audio_data = []
186
+ for _ in range(len(captions)):
187
+ waveform, sr = librosa.load(audio_path,sr=24000)
188
+ # print(sr)
189
+ audio_data.append(waveform)
190
+
191
+ # 获取嵌入
192
+ with torch.no_grad():
193
+ output = embedder(captions, audio_data)
194
+
195
+ # 打印结果
196
+ print("模型输出的字典:")
197
+ print(f"包含keys: {list(output.keys())}")
198
+
199
+ print("\n输出张量的形状:")
200
+ print(output['output'].shape)
201
+
202
+ print("\n掩码张量的形状:")
203
+ print(output['mask'].shape)
204
+
205
+ # 验证嵌入维度是否符合预期
206
+ assert output['output'].shape[-1] == args.embed_dim, f"输出维度 {output['output'].shape[-1]} 不等于预期维度 {args.embed_dim}"
207
+ print(f"\n成功验证:输出维度 = {args.embed_dim}")
208
+
209
+ # 显示样本嵌入值
210
+ print(f"样本嵌入值:\n{output['output'][0, :5, :5]}")
211
+ print(f"非零掩码位置数量: {output['mask'][0,:]}")
212
+ # 显示第一个样本中非零掩码位置的数量
213
+ print(f"第一个样本的非零掩码位置数量: {output['mask'][0].sum().item()}")
214
+
215
+
models/content_encoder/text_encoder.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel
4
+ from transformers.modeling_outputs import BaseModelOutput
5
+
6
+
7
+ DEVICE_TYPE = "cuda"
8
+
9
+
10
class TransformersTextEncoderBase(nn.Module):
    """Text encoder built on a HuggingFace ``AutoModel``.

    Tokenizes a batch of strings, runs the backbone, and projects the
    last hidden state to ``embed_dim``.  ``forward`` returns a dict with
    ``output`` (projected features) and ``mask`` (boolean attention mask).
    """

    def __init__(self, model_name: str, embed_dim: int):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Map the backbone's hidden size to the requested embedding width.
        self.proj = nn.Linear(self.model.config.hidden_size, embed_dim)

    def forward(
        self,
        text: list[str],
    ):
        hidden, mask = self.encode(text)
        return {"output": self.projection(hidden), "mask": mask}

    def encode(self, text: list[str]):
        """Tokenize and encode; returns (last_hidden_state, bool mask)."""
        device = self.model.device
        tokenized = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokenized.input_ids.to(device)
        attention_mask = tokenized.attention_mask.to(device)
        encoded: BaseModelOutput = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        )
        hidden = encoded.last_hidden_state
        # Boolean mask: True where the token is real (not padding).
        mask = (attention_mask == 1).to(device)
        return hidden, mask

    def projection(self, x):
        """Project backbone features to the target embedding dimension."""
        return self.proj(x)
45
+
46
+
47
class T5TextEncoder(TransformersTextEncoderBase):
    """Frozen FLAN-T5 encoder variant of :class:`TransformersTextEncoderBase`.

    The T5 backbone is frozen (no gradients) and always encodes under
    ``no_grad`` with autocast disabled, since T5 is numerically fragile
    in reduced precision.
    """

    def __init__(
        self, embed_dim: int, model_name: str = "google/flan-t5-large"
    ):
        # Deliberately skip the base __init__ to swap in T5-specific
        # tokenizer/model classes; only nn.Module's setup is needed.
        nn.Module.__init__(self)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5EncoderModel.from_pretrained(model_name)
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False
        self.proj = nn.Linear(self.model.config.hidden_size, embed_dim)

    def encode(
        self,
        text: list[str],
    ):
        # Force full precision and no autograd for the frozen backbone.
        no_amp = torch.amp.autocast(device_type=DEVICE_TYPE, enabled=False)
        with torch.no_grad(), no_amp:
            return super().encode(text)
67
+
68
+
69
if __name__ == "__main__":
    # Quick manual check: embed two captions and inspect the output shapes.
    encoder = T5TextEncoder(embed_dim=512)
    captions = ["a man is speaking", "a woman is singing while a dog is barking"]

    result = encoder(captions)
    print(result)
    print(result['output'].shape)
    print(result['mask'].shape)
models/diffusion.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+ import random
3
+ from typing import Any
4
+
5
+ from tqdm import tqdm
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import diffusers.schedulers as noise_schedulers
10
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
11
+ from diffusers.utils.torch_utils import randn_tensor
12
+
13
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
14
+ from models.content_encoder.content_encoder import ContentEncoder
15
+ from models.content_adapter import ContentAdapterBase, ContentEncoderAdapterMixin
16
+ import soundfile as sf
17
+ from models.common import (
18
+ LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
19
+ )
20
+ from utils.torch_utilities import (
21
+ create_alignment_path, create_mask_from_length, loss_with_mask,
22
+ trim_or_pad_length
23
+ )
24
+
25
+
26
class DiffusionMixin:
    """Shared DDPM training utilities.

    Provides scheduler construction, SNR computation for min-SNR loss
    weighting, timestep sampling, noising, and classifier-free-guidance
    rescaling.  Intended to be mixed into an ``nn.Module`` subclass;
    ``__init__`` only assigns plain attributes, so ``nn.Module.__init__``
    must have been called first by the concrete class.
    """
    def __init__(
        self,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float = None,
        cfg_drop_ratio: float = 0.2
    ) -> None:
        # snr_gamma: clamp for min-SNR loss weighting
        # (https://arxiv.org/abs/2303.09556); None disables SNR weighting.
        # cfg_drop_ratio: probability of dropping the condition during
        # training; any value > 0 enables classifier-free guidance.
        self.noise_scheduler_name = noise_scheduler_name
        self.snr_gamma = snr_gamma
        self.classifier_free_guidance = cfg_drop_ratio > 0.0
        self.cfg_drop_ratio = cfg_drop_ratio
        # DDPM scheduler loaded from the HF repo's "scheduler" subfolder.
        self.noise_scheduler = noise_schedulers.DDPMScheduler.from_pretrained(
            self.noise_scheduler_name, subfolder="scheduler"
        )

    def compute_snr(self, timesteps) -> torch.Tensor:
        """
        Compute the signal-to-noise ratio (alpha/sigma)^2 at each timestep.

        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
        """
        alphas_cumprod = self.noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod**0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod)**0.5

        # Expand the tensors.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(
            device=timesteps.device
        )[timesteps].float()
        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
            device=timesteps.device
        )[timesteps].float()
        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[
                ..., None]
        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

        # Compute SNR.
        snr = (alpha / sigma)**2
        return snr

    def get_timesteps(
        self,
        batch_size: int,
        device: torch.device,
        training: bool = True
    ) -> torch.Tensor:
        """Sample one diffusion timestep per batch element.

        Training draws uniformly over all train timesteps; validation uses
        the fixed midpoint so the metric is comparable across runs.
        """
        if training:
            timesteps = torch.randint(
                0,
                self.noise_scheduler.config.num_train_timesteps,
                (batch_size, ),
                device=device
            )
        else:
            # validation on half of the total timesteps
            timesteps = (self.noise_scheduler.config.num_train_timesteps //
                         2) * torch.ones((batch_size, ),
                                         dtype=torch.int64,
                                         device=device)

        timesteps = timesteps.long()
        return timesteps

    def get_input_target_and_timesteps(
        self,
        latent: torch.Tensor,
        training: bool,
    ):
        """Noise `latent` at sampled timesteps.

        Returns (noisy_latent, regression_target, timesteps); the target
        depends on the scheduler's prediction type (see `get_target`).
        """
        batch_size = latent.shape[0]
        device = latent.device
        num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
        self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
        timesteps = self.get_timesteps(batch_size, device, training=training)
        noise = torch.randn_like(latent)
        noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
        target = self.get_target(latent, noise, timesteps)
        return noisy_latent, target, timesteps

    def get_target(
        self, latent: torch.Tensor, noise: torch.Tensor,
        timesteps: torch.Tensor
    ) -> torch.Tensor:
        """
        Get the target for loss depending on the prediction type
        """
        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(
                latent, noise, timesteps
            )
        else:
            raise ValueError(
                f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
            )
        return target

    def loss_with_snr(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        timesteps: torch.Tensor,
        mask: torch.Tensor,
        reduce: bool = True
    ) -> torch.Tensor:
        """Masked MSE loss, optionally weighted by min(SNR, snr_gamma)/SNR."""
        if self.snr_gamma is None:
            loss = F.mse_loss(pred.float(), target.float(), reduction="none")
            loss = loss_with_mask(loss, mask, reduce=reduce)
        else:
            # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
            # Adapted from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L1006
            snr = self.compute_snr(timesteps)
            mse_loss_weights = torch.stack(
                [
                    snr,
                    self.snr_gamma * torch.ones_like(timesteps),
                ],
                dim=1,
            ).min(dim=1)[0]
            # division by (snr + 1) does not work well, not clear about the reason
            mse_loss_weights = mse_loss_weights / snr
            loss = F.mse_loss(pred.float(), target.float(), reduction="none")
            loss = loss_with_mask(loss, mask, reduce=False) * mse_loss_weights
            if reduce:
                loss = loss.mean()
        return loss

    def rescale_cfg(
        self, pred_cond: torch.Tensor, pred_cfg: torch.Tensor,
        guidance_rescale: float
    ):
        """
        Rescale `pred_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
        Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
        """
        std_cond = pred_cond.std(
            dim=list(range(1, pred_cond.ndim)), keepdim=True
        )
        std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True)

        # Match the guided prediction's per-sample std to the conditional one,
        # then blend with the unrescaled prediction by `guidance_rescale`.
        pred_rescaled = pred_cfg * (std_cond / std_cfg)
        pred_cfg = guidance_rescale * pred_rescaled + (
            1 - guidance_rescale
        ) * pred_cfg
        return pred_cfg
175
+
176
class SingleTaskCrossAttentionAudioDiffusion(
    LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
    DiffusionMixin, ContentEncoderAdapterMixin
):
    """Latent audio diffusion model conditioned via cross-attention.

    A frozen autoencoder maps waveforms to latents; a content encoder
    produces both a cross-attention context and a time-aligned condition;
    the trainable `backbone` predicts noise (or velocity) in latent space.
    """
    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        backbone: nn.Module,
        content_dim: int,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float = None,
        cfg_drop_ratio: float = 0.2,
    ):
        # NOTE(review): `content_dim` is accepted but never used in this
        # class — presumably kept for config compatibility; confirm.
        nn.Module.__init__(self)
        DiffusionMixin.__init__(
            self, noise_scheduler_name, snr_gamma, cfg_drop_ratio
        )
        ContentEncoderAdapterMixin.__init__(
            self, content_encoder=content_encoder
        )
        # The autoencoder is frozen; only the backbone (and adapters) train.
        self.autoencoder = autoencoder
        for param in self.autoencoder.parameters():
            param.requires_grad = False

        # Let an audio-conditioned content encoder reuse the same
        # (frozen) autoencoder instead of loading its own copy.
        if hasattr(self.content_encoder, "audio_encoder"):
            self.content_encoder.audio_encoder.model = self.autoencoder

        self.backbone = backbone
        # Zero-size parameter used only to discover the module's device.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(
        self, content: list[Any], task: list[str],
        waveform: torch.Tensor, waveform_lengths: torch.Tensor, **kwargs
    ):
        """Compute the diffusion training/validation loss for one batch."""
        device = self.dummy_param.device

        self.autoencoder.eval()
        self.content_encoder.eval()
        with torch.no_grad():
            # Encode waveforms to latents, padded to a fixed latent length.
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths, pad_latent_len=500
            )

        with torch.no_grad():
            content_dict = self.content_encoder.encode_content(content, task, device)
            context, context_mask = content_dict["content"], content_dict[
                "content_mask"]
            time_aligned_content = content_dict["length_aligned_content"]
            time_aligned_content_mask = content_dict[
                "time_aligned_content_mask"
            ]
            # NOTE(review): the mask returned by the autoencoder is
            # discarded and replaced by the time-aligned content mask —
            # this assumes both are at the same latent frame rate; confirm.
            latent_mask = time_aligned_content_mask.to(device)

        # Classifier-free guidance: randomly zero the cross-attention
        # context for a fraction of samples (the time-aligned condition
        # is deliberately kept — see comment below).
        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                context[mask_indices] = 0
                # dont mask!
                # time_aligned_content[mask_indices] = 0

        noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
            latent, self.training
        )

        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            timesteps=timesteps,
            time_aligned_context=time_aligned_content,
            context=context,
            x_mask=latent_mask,
            context_mask=context_mask
        )

        # Move the time axis to dim 1 so the mask broadcasts over it.
        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        loss = self.loss_with_snr(pred, target, timesteps, latent_mask)

        return loss

    def prepare_latent(
        self, batch_size: int, scheduler: SchedulerMixin,
        latent_shape: Sequence[int], dtype: torch.dtype, device: str
    ):
        """Draw the initial Gaussian latent for sampling."""
        shape = (batch_size, *latent_shape)
        latent = randn_tensor(
            shape, generator=None, device=device, dtype=dtype
        )
        # scale the initial noise by the standard deviation required by the scheduler
        latent = latent * scheduler.init_noise_sigma
        return latent

    def iterative_denoise(
        self,
        latent: torch.Tensor,
        scheduler: SchedulerMixin,
        verbose: bool,
        cfg: bool,
        cfg_scale: float,
        cfg_rescale: float,
        backbone_input: dict,
    ):
        """Run the scheduler's reverse loop, optionally with CFG.

        When `cfg` is True, `backbone_input` tensors must already hold the
        unconditional batch concatenated before the conditional batch.
        """
        timesteps = scheduler.timesteps
        num_steps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_steps * scheduler.order
        progress_bar = tqdm(range(num_steps), disable=not verbose)

        for i, timestep in enumerate(timesteps):
            # expand the latent if we are doing classifier free guidance
            if cfg:
                latent_input = torch.cat([latent, latent])
            else:
                latent_input = latent
            latent_input = scheduler.scale_model_input(latent_input, timestep)
            noise_pred = self.backbone(
                x=latent_input, timesteps=timestep, **backbone_input
            )

            # perform guidance
            if cfg:
                noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + cfg_scale * (
                    noise_pred_content - noise_pred_uncond
                )
                if cfg_rescale != 0.0:
                    noise_pred = self.rescale_cfg(
                        noise_pred_content, noise_pred, cfg_rescale
                    )

            # compute the previous noisy sample x_t -> x_t-1
            latent = scheduler.step(noise_pred, timestep, latent).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
                                           (i + 1) % scheduler.order == 0):
                progress_bar.update(1)

        progress_bar.close()

        return latent

    @torch.no_grad()
    def inference(
        self,
        content: list[Any],
        task: list[str],
        scheduler: SchedulerMixin,
        num_steps: int = 50,
        guidance_scale: float = 3.0,
        guidance_rescale: float = 0.0,
        disable_progress: bool = True,
        mask_time_aligned_content: bool = True,  # zero the time-aligned condition in the uncond branch
        **kwargs
    ):
        """Generate waveforms from content conditions.

        CFG is active when `guidance_scale` > 1; the latent length/width
        are taken from the time-aligned content (T frames, C channels).
        """
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0
        batch_size = len(content)

        content_dict = self.content_encoder.encode_content(content, task, device)

        context, context_mask = content_dict["content"], content_dict[
            "content_mask"]
        time_aligned_content = content_dict["length_aligned_content"]
        time_aligned_content_mask = content_dict[
            "time_aligned_content_mask"
        ]

        B, T, C = time_aligned_content.shape
        latent_shape = (C, T)  # e.g. (128, 500)
        latent_mask = time_aligned_content_mask.to(device)

        if classifier_free_guidance:
            # Build the unconditional branch: zeroed (or copied) conditions
            # concatenated *before* the conditional batch.
            if mask_time_aligned_content:
                uncond_time_aligned_content = torch.zeros_like(time_aligned_content)
            else:
                uncond_time_aligned_content = time_aligned_content.detach().clone()

            uncond_context = torch.zeros_like(context)
            uncond_context_mask = context_mask.detach().clone()
            time_aligned_content = torch.cat([
                uncond_time_aligned_content, time_aligned_content
            ])
            context = torch.cat([uncond_context, context])
            context_mask = torch.cat([uncond_context_mask, context_mask])
            latent_mask = torch.cat([
                latent_mask, latent_mask.detach().clone()
            ])

        scheduler.set_timesteps(num_steps, device=device)

        latent = self.prepare_latent(
            batch_size, scheduler, latent_shape, context.dtype, device
        )

        latent = self.iterative_denoise(
            latent=latent,
            scheduler=scheduler,
            verbose=not disable_progress,
            cfg=classifier_free_guidance,
            cfg_scale=guidance_scale,
            cfg_rescale=guidance_rescale,
            backbone_input={
                "x_mask": latent_mask,
                "context": context,
                "context_mask": context_mask,
                "time_aligned_context": time_aligned_content,
            }
        )
        # NOTE(review): under CFG, `latent_mask` has been doubled to
        # 2*batch_size while `latent` stays at batch_size — verify
        # `autoencoder.decode` tolerates (or broadcasts over) this.
        waveform = self.autoencoder.decode(latent, latent_mask)

        return waveform
400
+
401
+
models/dit/__init__.py ADDED
File without changes
models/dit/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes). View file
 
models/dit/__pycache__/mmdit_back.cpython-310.pyc ADDED
Binary file (8.63 kB). View file
 
models/dit/__pycache__/mmdit_layers.cpython-310.pyc ADDED
Binary file (11.5 kB). View file
 
models/dit/__pycache__/modules.cpython-310.pyc ADDED
Binary file (14 kB). View file
 
models/dit/attention.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint
5
+ import einops
6
+ from einops import rearrange, repeat
7
+ from inspect import isfunction
8
+ from .rotary import RotaryEmbedding
9
+ from .modules import RMSNorm
10
+
11
# Pick the attention implementation once at import time: use PyTorch's
# fused scaled_dot_product_attention when available, otherwise fall back
# to the explicit matmul/softmax path.
ATTENTION_MODE = (
    'flash'
    if hasattr(nn.functional, 'scaled_dot_product_attention') else 'math'
)
print(f'attention mode is {ATTENTION_MODE}')
16
+
17
+
18
def add_mask(sim, mask):
    """Fill masked-out attention scores with the most negative finite value.

    Args:
        sim: attention scores of shape (B, H, N, M).
        mask: boolean mask, True = keep; shape (B, N, M) or (N, M).
    Returns:
        `sim` with positions where mask is False set to -finfo.max, so
        they vanish after softmax.
    """
    b, ndim = sim.shape[0], mask.ndim
    # Use native tensor ops instead of einops for these trivial reshapes.
    if ndim == 3:
        # (B, N, M) -> (B, 1, N, M): broadcast over heads.
        mask = mask.unsqueeze(1)
    if ndim == 2:
        # (N, M) -> (B, 1, N, M): broadcast over batch and heads.
        mask = mask.unsqueeze(0).unsqueeze(0).expand(b, 1, -1, -1)
    max_neg_value = -torch.finfo(sim.dtype).max
    sim = sim.masked_fill(~mask, max_neg_value)
    return sim
27
+
28
+
29
def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
    """Build a boolean (B, 1, I, J) attention mask from query/key masks.

    Args:
        q_shape: shape of the query tensor; uses dims 0 (batch) and -2 (I).
        k_shape: shape of the key tensor; uses dim -2 (J).
        device: device to place/move the masks on.
        q_mask: optional (B, I) bool mask; defaults to all-True.
        k_mask: optional (B, J) bool mask; defaults to all-True.
    Returns:
        Bool tensor of shape (B, 1, I, J), True where both query i and
        key j are valid.
    """
    b, i, j = q_shape[0], q_shape[-2], k_shape[-2]
    # Default missing masks to all-valid (the original routed this through
    # a `default` helper with a dead isfunction branch).
    if q_mask is None:
        q_mask = torch.ones((b, i), device=device, dtype=torch.bool)
    if k_mask is None:
        k_mask = torch.ones((b, j), device=device, dtype=torch.bool)
    q_mask = q_mask.to(device)
    k_mask = k_mask.to(device)
    # Outer product of the two masks: (B,1,I,1) & (B,1,1,J) -> (B,1,I,J).
    attn_mask = q_mask[:, None, :, None] * k_mask[:, None, None, :]
    return attn_mask
44
+
45
+
46
class Attention(nn.Module):
    """Multi-head self-/cross-attention with optional RoPE and QK-norm.

    Self-attention when `context_dim` is None, cross-attention otherwise.
    `rope_mode` selects rotary-embedding behavior ('none' | 'shared' |
    'x_only' | 'dual'); cross-attention requires 'none'.  `extras` in
    `forward` is the number of leading prefix tokens treated specially
    by the 'x_only'/'dual' RoPE modes.
    """
    def __init__(
        self,
        dim,
        context_dim=None,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        attn_drop=0.,
        proj_drop=0.,
        rope_mode='none'
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Used only by the 'math' fallback; sdpa applies its own scaling.
        self.scale = qk_scale or head_dim**-0.5

        if context_dim is None:
            self.cross_attn = False
        else:
            self.cross_attn = True

        context_dim = dim if context_dim is None else context_dim

        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)

        if qk_norm is None:
            self.norm_q = nn.Identity()
            self.norm_k = nn.Identity()
        elif qk_norm == 'layernorm':
            self.norm_q = nn.LayerNorm(head_dim)
            self.norm_k = nn.LayerNorm(head_dim)
        elif qk_norm == 'rmsnorm':
            self.norm_q = RMSNorm(head_dim)
            self.norm_k = RMSNorm(head_dim)
        else:
            raise NotImplementedError

        self.attn_drop_p = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # RoPE is positional and only meaningful for self-attention.
        if self.cross_attn:
            assert rope_mode == 'none'
        self.rope_mode = rope_mode
        if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
            self.rotary = RotaryEmbedding(dim=head_dim)
        elif self.rope_mode == 'dual':
            self.rotary_x = RotaryEmbedding(dim=head_dim)
            self.rotary_c = RotaryEmbedding(dim=head_dim)

    def _rotary(self, q, k, extras):
        """Apply rotary embeddings according to `rope_mode`.

        `extras` leading tokens are left unrotated ('x_only') or rotated
        with a separate module ('dual').
        """
        if self.rope_mode == 'shared':
            q, k = self.rotary(q=q, k=k)
        elif self.rope_mode == 'x_only':
            q_x, k_x = self.rotary(
                q=q[:, :, extras:, :], k=k[:, :, extras:, :]
            )
            q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'dual':
            q_x, k_x = self.rotary_x(
                q=q[:, :, extras:, :], k=k[:, :, extras:, :]
            )
            q_c, k_c = self.rotary_c(
                q=q[:, :, :extras, :], k=k[:, :, :extras, :]
            )
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'none':
            pass
        else:
            raise NotImplementedError
        return q, k

    def _attn(self, q, k, v, mask_binary):
        """Scaled-dot-product attention; returns (B, L, H*D)."""
        if ATTENTION_MODE == 'flash':
            x = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary
            )
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'math':
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = add_mask(
                attn, mask_binary
            ) if mask_binary is not None else attn
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            # BUGFIX: the original did `(attn @ v).transpose(1, 2)` AND
            # rearranged with 'B H L D', swapping H and L twice and
            # producing a (B, H, L*D) tensor that breaks `self.proj`.
            # `attn @ v` is already (B, H, L, D); rearrange alone suffices.
            x = attn @ v
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        else:
            raise NotImplementedError
        return x

    def forward(self, x, context=None, context_mask=None, extras=0):
        """Attend `x` to itself (or to `context` when given).

        Args:
            x: (B, L, C) query sequence.
            context: optional (B, Lc, Cc) key/value sequence.
            context_mask: optional (B, Lc) bool mask over keys.
            extras: number of leading prefix tokens for RoPE handling.
        """
        B, L, C = x.shape
        if context is None:
            context = x

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        if context_mask is not None:
            mask_binary = create_mask(
                x.shape, context.shape, x.device, None, context_mask
            )
        else:
            mask_binary = None

        q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
        k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
        v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)

        q = self.norm_q(q)
        k = self.norm_k(k)

        q, k = self._rotary(q, k, extras)

        x = self._attn(q, k, v, mask_binary)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+
176
+
177
class JointAttention(nn.Module):
    """MMDiT-style joint attention over concatenated [context, x] tokens.

    Both streams get their own QKV projections and QK-norms; their
    queries/keys/values are concatenated (context first), attended
    jointly, then split back and projected separately.
    """
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        attn_drop=0.,
        proj_drop=0.,
        rope_mode='none'
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Used only by the 'math' fallback; sdpa applies its own scaling.
        self.scale = qk_scale or head_dim**-0.5

        self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(
            dim, qkv_bias
        )
        self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(
            dim, qkv_bias
        )

        self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
        self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)

        self.attn_drop_p = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj_x = nn.Linear(dim, dim)
        self.proj_drop_x = nn.Dropout(proj_drop)

        self.proj_c = nn.Linear(dim, dim)
        self.proj_drop_c = nn.Dropout(proj_drop)

        self.rope_mode = rope_mode
        if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
            self.rotary = RotaryEmbedding(dim=head_dim)
        elif self.rope_mode == 'dual':
            self.rotary_x = RotaryEmbedding(dim=head_dim)
            self.rotary_c = RotaryEmbedding(dim=head_dim)

    def _make_qkv_layers(self, dim, qkv_bias):
        """Return (to_q, to_k, to_v) linear projections for one stream."""
        return (
            nn.Linear(dim, dim,
                      bias=qkv_bias), nn.Linear(dim, dim, bias=qkv_bias),
            nn.Linear(dim, dim, bias=qkv_bias)
        )

    def _make_norm_layers(self, qk_norm, head_dim):
        """Return (norm_q, norm_k) per-head normalization layers."""
        if qk_norm is None:
            norm_q = nn.Identity()
            norm_k = nn.Identity()
        elif qk_norm == 'layernorm':
            norm_q = nn.LayerNorm(head_dim)
            norm_k = nn.LayerNorm(head_dim)
        elif qk_norm == 'rmsnorm':
            norm_q = RMSNorm(head_dim)
            norm_k = RMSNorm(head_dim)
        else:
            raise NotImplementedError
        return norm_q, norm_k

    def _rotary(self, q, k, extras):
        """Apply rotary embeddings according to `rope_mode`.

        `extras` leading (context) tokens are left unrotated ('x_only')
        or rotated with a separate module ('dual').
        """
        if self.rope_mode == 'shared':
            q, k = self.rotary(q=q, k=k)
        elif self.rope_mode == 'x_only':
            q_x, k_x = self.rotary(
                q=q[:, :, extras:, :], k=k[:, :, extras:, :]
            )
            q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'dual':
            q_x, k_x = self.rotary_x(
                q=q[:, :, extras:, :], k=k[:, :, extras:, :]
            )
            q_c, k_c = self.rotary_c(
                q=q[:, :, :extras, :], k=k[:, :, :extras, :]
            )
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'none':
            pass
        else:
            raise NotImplementedError
        return q, k

    def _attn(self, q, k, v, mask_binary):
        """Scaled-dot-product attention; returns (B, L, H*D)."""
        if ATTENTION_MODE == 'flash':
            x = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary
            )
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'math':
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = add_mask(
                attn, mask_binary
            ) if mask_binary is not None else attn
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            # BUGFIX: the original did `(attn @ v).transpose(1, 2)` AND
            # rearranged with 'B H L D', swapping H and L twice and
            # producing a (B, H, L*D) tensor that breaks the projections.
            # `attn @ v` is already (B, H, L, D); rearrange alone suffices.
            x = attn @ v
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        else:
            raise NotImplementedError
        return x

    def _cat_mask(self, x, context, x_mask=None, context_mask=None):
        """Concatenate key masks in [context, x] order, defaulting to all-True."""
        B = x.shape[0]
        if x_mask is None:
            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
        if context_mask is None:
            context_mask = torch.ones(
                B, context.shape[-2], device=context.device
            ).bool()
        mask = torch.cat([context_mask, x_mask], dim=1)
        return mask

    def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
        """Jointly attend over [context, x]; returns (x_out, context_out).

        Args:
            x: (B, Lx, C) main stream.
            context: (B, Lc, C) conditioning stream.
            x_mask / context_mask: optional (B, L*) bool masks.
            extras: number of leading prefix tokens for RoPE handling.
        """
        B, Lx, C = x.shape
        _, Lc, _ = context.shape
        if x_mask is not None or context_mask is not None:
            mask = self._cat_mask(
                x, context, x_mask=x_mask, context_mask=context_mask
            )
            shape = [B, Lx + Lc, C]
            mask_binary = create_mask(
                q_shape=shape,
                k_shape=shape,
                device=x.device,
                q_mask=None,
                k_mask=mask
            )
        else:
            mask_binary = None

        qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
        qc, kc, vc = self.to_qc(context), self.to_kc(context
                                                     ), self.to_vc(context)

        qx, kx, vx = map(
            lambda t: einops.
            rearrange(t, 'B L (H D) -> B H L D', H=self.num_heads),
            [qx, kx, vx]
        )
        qc, kc, vc = map(
            lambda t: einops.
            rearrange(t, 'B L (H D) -> B H L D', H=self.num_heads),
            [qc, kc, vc]
        )

        qx, kx = self.norm_qx(qx), self.norm_kx(kx)
        qc, kc = self.norm_qc(qc), self.norm_kc(kc)

        # Joint sequence: context tokens first, then x tokens.
        q, k, v = (
            torch.cat([qc, qx],
                      dim=2), torch.cat([kc, kx],
                                        dim=2), torch.cat([vc, vx], dim=2)
        )

        q, k = self._rotary(q, k, extras)

        x = self._attn(q, k, v, mask_binary)

        # Split the joint output back into the two streams.
        context, x = x[:, :Lc, :], x[:, Lc:, :]

        x = self.proj_x(x)
        x = self.proj_drop_x(x)

        context = self.proj_c(context)
        context = self.proj_drop_c(context)

        return x, context
models/dit/mmdit_back.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ # 假设这些是你原来的导入
10
+ from .mmdit_layers import compute_rope_rotations
11
+ from .mmdit_layers import TimestepEmbedder
12
+ from .mmdit_layers import MLP, ChannelLastConv1d, ConvMLP
13
+ from .mmdit_layers import (FinalBlock, MMDitSingleBlock, JointBlock_AT)
14
+
15
+ log = logging.getLogger()
16
+
17
+
18
@dataclass
class PreprocessedConditions:
    """Pre-computed text conditioning produced by MMAudio.preprocess_conditions.

    text_f: per-token text features after text_input_proj, (B, N_text, hidden_dim).
    text_f_c: sequence-pooled global text feature, (B, hidden_dim).
    """
    text_f: torch.Tensor
    text_f_c: torch.Tensor
22
+
23
+
24
class MMAudio(nn.Module):
    """
    A modified MMAudio whose interface is kept as close as possible to
    LayerFusionAudioDiT.
    """
    def __init__(self,
                 *,
                 latent_dim: int,
                 text_dim: int,
                 hidden_dim: int,
                 depth: int,
                 fused_depth: int,
                 num_heads: int,
                 mlp_ratio: float = 4.0,
                 latent_seq_len: int,
                 text_seq_len: int = 640,
                 # --- added parameters, aligned with LayerFusionAudioDiT ---
                 ta_context_dim: int,
                 ta_context_fusion: str = 'add',  # 'add' or 'concat'
                 ta_context_norm: bool = False,
                 # --- other pre-existing parameters ---
                 empty_string_feat: Optional[torch.Tensor] = None,
                 v2: bool = False) -> None:
        super().__init__()

        self.v2 = v2
        self.latent_dim = latent_dim
        self._latent_seq_len = latent_seq_len
        self._text_seq_len = text_seq_len
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # --- 1. Projection for time_aligned_context ---
        # A single projection defined here instead of one per block; this is
        # more efficient and matches the original intent ("currently each
        # layer projects; change to a single mapping"): project once, then
        # pass the result to every layer.
        self.ta_context_fusion = ta_context_fusion
        self.ta_context_norm_flag = ta_context_norm

        if self.ta_context_fusion == "add":
            # Additive fusion: project ta_context to the latent width (hidden_dim).
            self.ta_context_projection = nn.Linear(ta_context_dim, hidden_dim, bias=False)
            self.ta_context_norm = nn.LayerNorm(ta_context_dim) if self.ta_context_norm_flag else nn.Identity()
        elif self.ta_context_fusion == "concat":
            # Concat fusion is handled inside the blocks, so no main projection
            # here. (The original code also projected after concat; that could
            # be implemented in-block. For simplicity the main fusion logic is
            # assumed to live in the blocks.)
            self.ta_context_projection = nn.Identity()
            self.ta_context_norm = nn.Identity()
        else:
            raise ValueError(f"Unknown ta_context_fusion type: {ta_context_fusion}")


        # --- Original input projections (mostly unchanged) ---
        # The input is now an editing pair, so the conv takes latent_dim * 2.
        self.audio_input_proj = nn.Sequential(
            ChannelLastConv1d(latent_dim*2, hidden_dim, kernel_size=7, padding=3),
            nn.SELU(),
            ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
        )
        self.text_input_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            MLP(hidden_dim, hidden_dim * 4),
        )

        self.text_cond_proj = nn.Linear(hidden_dim, hidden_dim)
        self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4)

        # Timestep embedding (sinusoidal features + MLP).
        self.t_embed = TimestepEmbedder(hidden_dim, frequency_embedding_size=256, max_period=10000)

        # --- Transformer blocks (mostly unchanged) ---
        # IMPORTANT: JointBlock_AT / MMDitSingleBlock forward definitions would
        # need changes to receive `time_aligned_context` directly.
        self.joint_blocks = nn.ModuleList([
            JointBlock_AT(hidden_dim, num_heads, mlp_ratio=mlp_ratio, pre_only=(i == depth - fused_depth - 1))
            for i in range(depth - fused_depth)
        ])
        self.fused_blocks = nn.ModuleList([
            MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1)
            for i in range(fused_depth)
        ])

        # --- Output head (unchanged) ---
        self.final_layer = FinalBlock(hidden_dim, latent_dim)

        # Frozen placeholder text features used for the empty string.
        if empty_string_feat is None:
            empty_string_feat = torch.zeros((text_seq_len, text_dim))

        self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)

        self.initialize_weights()
        self.initialize_rotations()

    def initialize_rotations(self):
        """(Re)build the RoPE rotation table for the current latent length."""
        base_freq = 1.0

        # The only place the sequence length is actually needed.
        latent_rot = compute_rope_rotations(self._latent_seq_len,
                                            self.hidden_dim // self.num_heads,
                                            10000,
                                            freq_scaling=base_freq,
                                            device="cuda" if torch.cuda.is_available() else "cpu")

        # add to model buffers
        self.register_buffer('latent_rot', latent_rot, persistent=False)
        # self.clip_rot = nn.Buffer(clip_rot, persistent=False)

    def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
        # NOTE(review): clip/sync lengths are stored but only the latent length
        # feeds initialize_rotations() here — confirm they are used elsewhere.
        self._latent_seq_len = latent_seq_len
        self._clip_seq_len = clip_seq_len
        self._sync_seq_len = sync_seq_len
        self.initialize_rotations()

    def initialize_weights(self):
        """Xavier init for linears; zero-init AdaLN modulations and output head."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks (compatibility guard):
        for block in self.joint_blocks:
            nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0)
        for block in self.fused_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.conv.weight, 0)
        nn.init.constant_(self.final_layer.conv.bias, 0)

    def preprocess_conditions(self, text_f: torch.Tensor) -> PreprocessedConditions:
        """Project text features and pool a global text condition vector."""
        # Preprocess the text condition.
        # assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}'
        bs = text_f.shape[0]

        # The external LLM embedding is kept fixed; only this projection learns.
        text_f = self.text_input_proj(text_f)
        # Global (sequence-pooled) condition.
        text_f_c = self.text_cond_proj(text_f.mean(dim=1))
        return PreprocessedConditions(text_f=text_f, text_f_c=text_f_c)

    def predict_flow(self, x: torch.Tensor, timesteps: torch.Tensor,
                     conditions: PreprocessedConditions,
                     time_aligned_context: torch.Tensor) -> torch.Tensor:
        """
        Core prediction pass, now including time_aligned_context.

        ``x`` is channel-first (B, latent_dim, N); the returned flow is
        channel-last (B, N, latent_dim) — the caller transposes it back.
        """
        assert x.shape[2] == self._latent_seq_len, f'{x.shape=} {self._latent_seq_len=}'

        # 1. Unpack the preprocessed inputs.
        text_f = conditions.text_f
        text_f_c = conditions.text_f_c

        timesteps = timesteps.to(x.dtype)  # keep the same dtype as the input tensor

        global_c = self.global_cond_mlp(text_f_c)  # (B, D)

        # 2. Fuse the timestep embedding into the global condition.
        global_c = self.t_embed(timesteps).unsqueeze(1) + global_c.unsqueeze(1)  # (B, 1, D)
        extended_c = global_c  # used as the AdaLN condition
        """
        This determines x's shape — needs a debug check.
        """
        # 3. Handle time_aligned_context. First variant: fuse with the latent
        # directly along the feature axis, then project (128 -> 256).
        x = torch.cat([x.transpose(1, 2), time_aligned_context], dim=-1)
        latent = self.audio_input_proj(x)  # (B, N, D)

        # 4. Run the transformer blocks.
        for block in self.joint_blocks:
            # (JointBlock_AT.forward would need changes to consume ta context.)
            latent, text_f = block(latent, text_f, global_c, extended_c,
                                   self.latent_rot)

        for block in self.fused_blocks:
            # (MMDitSingleBlock.forward would need changes to consume ta context.)
            latent = block(latent, extended_c, self.latent_rot)

        # 5. Output head.
        flow = self.final_layer(latent, global_c)
        return flow

    def forward(self,
                x: torch.Tensor,
                timesteps: torch.Tensor,
                context: torch.Tensor,
                time_aligned_context: torch.Tensor,
                x_mask=None,
                context_mask=None,
                ) -> torch.Tensor:
        """
        Main entry point; interface aligned with LayerFusionAudioDiT.
        - x: noisy latent, shape (B, latent_dim, N_latent)
          (channel-first: predict_flow asserts x.shape[2] == latent_seq_len)
        - timesteps: diffusion timesteps, shape (B,) or scalar
        - context: text condition, shape (B, N_text, text_dim)
        - time_aligned_context: time-aligned condition, shape (B, N_latent, ta_context_dim)
        Returns the predicted flow, channel-first (B, latent_dim, N_latent).
        """

        # Broadcast a scalar timestep to the whole batch.
        if timesteps.dim() == 0:
            timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)

        text_conditions = self.preprocess_conditions(context)

        # Core flow prediction (channel-last internally).
        flow = self.predict_flow(x, timesteps, text_conditions, time_aligned_context)

        # Back to channel-first to match the input layout.
        flow = flow.transpose(1, 2)

        return flow

    @property
    def latent_seq_len(self) -> int:
        # Current latent sequence length the RoPE table was built for.
        return self._latent_seq_len
259
+
260
+
261
+ # latent(b,500,128)
262
+
263
def small_16k(**kwargs) -> MMAudio:
    """Factory for the 'small' 16 kHz MMAudio configuration.

    Extra keyword arguments (e.g. ta_context_dim, ta_context_fusion) are
    forwarded to the MMAudio constructor.
    """
    heads = 16
    base = dict(latent_dim=128,
                text_dim=1024,
                hidden_dim=64 * heads,
                depth=12,
                fused_depth=8,
                num_heads=heads,
                latent_seq_len=500)
    # Duplicate keys in kwargs raise TypeError, same as explicit keywords would.
    return MMAudio(**base, **kwargs)
273
+
274
+
275
+
276
+
277
if __name__ == '__main__':
    # Smoke test: build the small 16 kHz model and run a single forward pass
    # on dummy inputs, checking that the output shape is as expected.
    batch_size = 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    config = {
        "ta_context_dim": 128,
        "ta_context_fusion": "concat",
        "ta_context_norm": False
    }

    try:
        model = small_16k(**config).to(device)
        model.eval()  # evaluation mode
        print("Model instantiated successfully!")
    except Exception as e:
        print(f"Error during model instantiation: {e}")
        exit()

    num_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f'Number of parameters: {num_params:.2f}M')

    # Dummy input dimensions matching the small_16k configuration.
    latent_dim = 128
    latent_seq_len = 500
    text_dim = 1024
    text_seq_len = 640
    ta_context_dim = config["ta_context_dim"]

    dummy_x = torch.randn(batch_size,latent_dim, latent_seq_len, device=device)
    dummy_timesteps = torch.randint(0, 1000, (batch_size,), device=device)
    dummy_context = torch.randn(batch_size, text_seq_len, text_dim, device=device)

    # time_aligned_context must match x's sequence length so the two can be
    # concatenated along the feature dimension inside predict_flow.
    dummy_ta_context = torch.randn(batch_size, latent_seq_len, ta_context_dim, device=device)

    print("\n--- Input Shapes ---")
    print(f"x (latent): {dummy_x.shape}")
    print(f"timesteps: {dummy_timesteps.shape}")
    print(f"context (text): {dummy_context.shape}")
    print(f"time_aligned_context: {dummy_ta_context.shape}")
    print("--------------------\n")

    # 4. Run the forward pass.
    try:
        with torch.no_grad():  # no gradients needed for this check
            output = model(
                x=dummy_x,
                timesteps=dummy_timesteps,
                context=dummy_context,
                time_aligned_context=dummy_ta_context
            )
        print("✅ Forward pass successful!")
        print(f"Output shape: {output.shape}")

        # 5. Validate the output shape.
        # NOTE(review): model.forward transposes back to (B, latent_dim, N);
        # this expected shape looks inconsistent with that — verify which
        # layout is intended (the assert failure is caught and printed below).
        expected_shape = (batch_size, latent_seq_len, latent_dim)
        assert output.shape == expected_shape, \
            f"Output shape mismatch! Expected {expected_shape}, but got {output.shape}"
        print("✅ Output shape is correct!")

    except Exception as e:
        print(f"❌ Error during forward pass: {e}")
        import traceback
        traceback.print_exc()
models/dit/mmdit_layers.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from typing import Union
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+ from einops.layers.torch import Rearrange
11
+
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+
17
+ from .modules import RMSNorm
18
+
19
+ # https://github.com/facebookresearch/DiT
20
+ # Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
21
+ # Ref: https://github.com/lucidrains/rotary-embedding-torch
22
+
23
+
24
def compute_rope_rotations(length: int,
                           dim: int,
                           theta: int,
                           *,
                           freq_scaling: float = 1.0,
                           device: Union[torch.device, str] = 'cpu') -> Tensor:
    """Precompute RoPE rotation matrices.

    Returns a tensor of shape (1, length, dim // 2, 2, 2) holding, for each
    position and frequency pair, the rotation [[cos, -sin], [sin, cos]].
    Computed in float32 with autocast disabled for numerical stability.
    """
    assert dim % 2 == 0

    with torch.amp.autocast(device_type='cuda', enabled=False):
        positions = torch.arange(length, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        inv_freq = inv_freq * freq_scaling

        # Outer product: angle per (position, frequency) pair.
        angles = torch.outer(positions, inv_freq)
        rot = torch.stack(
            [torch.cos(angles), -torch.sin(angles),
             torch.sin(angles), torch.cos(angles)],
            dim=-1)
        # (length, dim//2, 4) -> (1, length, dim//2, 2, 2); same memory order
        # as rearrange('n d (i j) -> 1 n d i j', i=2, j=2).
        return rot.reshape(1, length, dim // 2, 2, 2)
41
+
42
+
43
def apply_rope(x: Tensor, rot: Tensor) -> Tensor:
    """Apply precomputed RoPE rotations to ``x``.

    The last dimension of ``x`` is interpreted as consecutive pairs; each pair
    is multiplied by the corresponding 2x2 rotation from ``rot`` (as produced
    by compute_rope_rotations). Math runs in float32 with autocast disabled,
    then the result is cast back to x's dtype.

    Fix: the original return annotation claimed ``tuple[Tensor, Tensor]`` but
    the function returns a single tensor of x's shape.
    """
    with torch.amp.autocast(device_type='cuda', enabled=False):
        _x = x.float()
        _x = _x.view(*_x.shape[:-1], -1, 1, 2)
        # Row-wise 2x2 matmul: out_i = rot[..., i, 0]*a + rot[..., i, 1]*b.
        x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
        return x_out.reshape(*x.shape).to(dtype=x.dtype)
49
+
50
+
51
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Sinusoidal features (with a precomputed frequency table scaled by
    10000 / max_period) are passed through a two-layer SiLU MLP.
    """

    def __init__(self, dim, frequency_embedding_size, max_period):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        self.dim = dim
        self.max_period = max_period
        assert dim % 2 == 0, 'dim must be even.'

        with torch.autocast('cuda', enabled=False):
            # 1. Compute the final frequency table first.
            initial_freqs = 1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
                                           frequency_embedding_size))
            freq_scale = 10000 / max_period
            freqs_tensor = freq_scale * initial_freqs

            # 2. Register the final tensor as a non-persistent buffer.
            self.register_buffer('freqs', freqs_tensor, persistent=False)

    def timestep_embedding(self, t):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :return: an (N, frequency_embedding_size) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py

        args = t[:, None].float() * self.freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return embedding

    def forward(self, t):
        # Cast features to t's dtype before the MLP (mixed-precision safety).
        t_freq = self.timestep_embedding(t).to(t.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
96
+
97
class ChannelLastConv1d(nn.Conv1d):
    """Conv1d accepting channel-last input: (B, T, C_in) -> (B, T, C_out)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Conv1d expects (B, C, T): transpose in, convolve, transpose out.
        y = super().forward(x.permute(0, 2, 1))
        return y.permute(0, 2, 1)
104
+
105
+
106
+ # https://github.com/Stability-AI/sd3-ref
107
class MLP(nn.Module):
    """SwiGLU feed-forward: w2(silu(w1(x)) * w3(x)).

    The hidden width is 2/3 of ``hidden_dim``, rounded up to the next
    multiple of ``multiple_of`` (LLaMA-style sizing). All projections are
    bias-free linears.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
    ):
        super().__init__()
        width = int(2 * hidden_dim / 3)
        # Round up to a multiple of `multiple_of` for hardware-friendly shapes.
        width = multiple_of * ((width + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, width, bias=False)
        self.w2 = nn.Linear(width, dim, bias=False)
        self.w3 = nn.Linear(dim, width, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
139
+
140
+
141
class ConvMLP(nn.Module):
    """SwiGLU feed-forward built from channel-last 1-D convolutions.

    Computes w2(silu(w1(x)) * w3(x)) on (B, T, C) input, mixing locally along
    the time axis. The hidden width is 2/3 of ``hidden_dim`` rounded up to a
    multiple of ``multiple_of``; all convs are bias-free.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        super().__init__()
        width = int(2 * hidden_dim / 3)
        # Round up to a multiple of `multiple_of` for hardware-friendly shapes.
        width = multiple_of * ((width + multiple_of - 1) // multiple_of)

        def conv(c_in, c_out):
            # All three branches share identical conv hyper-parameters.
            return ChannelLastConv1d(c_in, c_out, bias=False,
                                     kernel_size=kernel_size,
                                     padding=padding)

        self.w1 = conv(dim, width)
        self.w2 = conv(width, dim)
        self.w3 = conv(dim, width)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
187
+
188
+
189
+
190
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """AdaLN modulation: scale x around identity, then shift."""
    return shift + x * (scale + 1)
192
+
193
+
194
+
195
+
196
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Scaled-dot-product attention over (B, H, N, D), returning (B, N, H*D).

    The .contiguous() calls work around a cuDNN limitation — training crashes
    without them; believed related to
    https://github.com/pytorch/pytorch/issues/133974 (unresolved at the time
    of writing).
    """
    out = F.scaled_dot_product_attention(
        q.contiguous(), k.contiguous(), v.contiguous())
    # Merge heads back into the channel axis: (B, H, N, D) -> (B, N, H*D).
    b, h, n, d = out.shape
    return out.permute(0, 2, 1, 3).reshape(b, n, h * d).contiguous()
206
+
207
+
208
class SelfAttention(nn.Module):
    """Multi-head self-attention with per-head RMSNorm on q/k and optional RoPE.

    pre_attention() returns per-head (q, k, v) of shape (B, H, N, D) so
    callers can run joint attention across concatenated token streams;
    forward() is the plain self-contained path.
    """

    def __init__(self, dim: int, nheads: int):
        super().__init__()
        self.dim = dim
        self.nheads = nheads

        # Fused q/k/v projection; split into heads with the q/k/v axis last.
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.q_norm = RMSNorm(dim // nheads)
        self.k_norm = RMSNorm(dim // nheads)

        self.split_into_heads = Rearrange('b n (h d j) -> b h n d j',
                                          h=nheads,
                                          d=dim // nheads,
                                          j=3)

    def pre_attention(
            self, x: torch.Tensor,
            rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project x to per-head q/k/v, RMS-normalize q/k, apply RoPE if given.

        x: batch_size * n_tokens * n_channels.
        """
        qkv = self.qkv(x)
        q, k, v = self.split_into_heads(qkv).chunk(3, dim=-1)
        q = q.squeeze(-1)
        k = k.squeeze(-1)
        v = v.squeeze(-1)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if rot is not None:
            q = apply_rope(q, rot)
            k = apply_rope(k, rot)

        return q, k, v

    def forward(
        self,
        x: torch.Tensor,  # batch_size * n_tokens * n_channels
    ) -> torch.Tensor:
        # Fix: pre_attention requires the `rot` argument — the original call
        # self.pre_attention(x) raised TypeError. Pass None (no RoPE).
        q, k, v = self.pre_attention(x, None)
        out = attention(q, k, v)
        return out
249
+
250
+
251
class MMDitSingleBlock(nn.Module):
    """DiT block with AdaLN-Zero modulation.

    ``pre_only=True`` builds only the modulation needed to produce q/k/v (no
    residual or FFN path) — used for the text branch of the last joint block.
    ``kernel_size == 1`` selects token-wise Linear/MLP layers; larger kernels
    select channel-last convolutions that also mix locally along time.
    """

    def __init__(self,
                 dim: int,
                 nhead: int,
                 mlp_ratio: float = 4.0,
                 pre_only: bool = False,
                 kernel_size: int = 7,
                 padding: int = 3):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = SelfAttention(dim, nhead)

        self.pre_only = pre_only
        if pre_only:
            # Only shift/scale for the attention input are needed.
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        else:
            if kernel_size == 1:
                self.linear1 = nn.Linear(dim, dim)
            else:
                self.linear1 = ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding)
            self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)

            if kernel_size == 1:
                self.ffn = MLP(dim, int(dim * mlp_ratio))
            else:
                self.ffn = ConvMLP(dim,
                                   int(dim * mlp_ratio),
                                   kernel_size=kernel_size,
                                   padding=padding)

            # Full AdaLN-Zero: shift/scale/gate for both attention and FFN.
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]):
        """Modulate the input and produce q/k/v plus the FFN modulation params.

        x: BS * N * D; c (cond): BS * D.
        """
        modulation = self.adaLN_modulation(c)
        if self.pre_only:
            (shift_msa, scale_msa) = modulation.chunk(2, dim=-1)
            gate_msa = shift_mlp = scale_mlp = gate_mlp = None
        else:
            (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
             gate_mlp) = modulation.chunk(6, dim=-1)

        x = modulate(self.norm1(x), shift_msa, scale_msa)
        q, k, v = self.attn.pre_attention(x, rot)
        return (q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp)

    def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor, c: tuple[torch.Tensor]):
        """Gated residual add of the attention output, then the gated FFN."""
        if self.pre_only:
            # No residual/FFN path exists in pre-only mode.
            return x

        (gate_msa, shift_mlp, scale_mlp, gate_mlp) = c
        x = x + self.linear1(attn_out) * gate_msa
        r = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + self.ffn(r) * gate_mlp

        return x

    # NOTE: this forward does not appear to be used — callers drive the block
    # through pre_attention/post_attention for joint attention.
    def forward(self, x: torch.Tensor, cond: torch.Tensor,
                rot: Optional[torch.Tensor]) -> torch.Tensor:
        # x: BS * N * D
        # cond: BS * D
        x_qkv, x_conditions = self.pre_attention(x, cond, rot)
        attn_out = attention(*x_qkv)
        x = self.post_attention(x, attn_out, x_conditions)

        return x
319
+
320
+
321
+
322
+
323
class JointBlock_AT(nn.Module):
    """
    Audio + Text only JointBlock (clip branch removed).
    Returns (latent, text_f).
    """
    def __init__(self, dim: int, nhead: int, mlp_ratio: float = 4.0, pre_only: bool = False):
        super().__init__()
        self.pre_only = pre_only
        self.latent_block = MMDitSingleBlock(dim,
                                             nhead,
                                             mlp_ratio,
                                             pre_only=False,
                                             kernel_size=3,
                                             padding=1)
        # text_block keeps the pre_only flag (possibly a pre-only AdaLN).
        self.text_block = MMDitSingleBlock(dim, nhead, mlp_ratio, pre_only=pre_only, kernel_size=1)

    def forward(self, latent: torch.Tensor, text_f: torch.Tensor,
                global_c: torch.Tensor, extended_c: torch.Tensor, latent_rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        # latent: (B, N_latent, D)
        # text_f: (B, N_text, D)
        # global_c: (B, 1, D) or (B, D)
        # extended_c: (B, N_latent, D) or (B, 1, D)
        x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot)
        # Text gets no RoPE here — slightly odd; the audio-LLM embeddings may
        # already carry positional information (TODO confirm).
        t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None)

        latent_len = latent.shape[1]
        text_len = text_f.shape[1]

        # Joint attention over latent + text tokens only.
        joint_qkv = [torch.cat([x_qkv[i], t_qkv[i]], dim=2) for i in range(3)]  # dim=2 = token dim

        attn_out = attention(*joint_qkv)  # (B, latent_len + text_len, D)
        x_attn_out = attn_out[:, :latent_len]  # (B, latent_len, D)
        t_attn_out = attn_out[:, latent_len:]  # (B, text_len, D)

        latent = self.latent_block.post_attention(latent, x_attn_out, x_mod)
        if not self.pre_only:
            # In pre-only mode the text branch has no post path; text_f passes through.
            text_f = self.text_block.post_attention(text_f, t_attn_out, t_mod)

        return latent, text_f
366
+
367
+
368
+ # 改一下mask的逻辑
369
+ # def forward(self, latent, text_f, global_c, extended_c, latent_rot,
370
+ # latent_mask: torch.Tensor, text_mask: torch.Tensor):
371
+ # # latent_mask: (B, N_latent) {0,1}
372
+ # # text_mask: (B, N_text) {0,1}
373
+
374
+ # x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot)
375
+ # t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None)
376
+
377
+ # latent_len = latent.shape[1]
378
+ # text_len = text_f.shape[1]
379
+
380
+ # # 1) 拼 qkv
381
+ # joint_qkv = [torch.cat([x_qkv[i], t_qkv[i]], dim=2) for i in range(3)] # 这里假设 token 维=2
382
+
383
+ # # 2) 构造 key mask(拼接后的)
384
+ # key_mask = torch.cat([latent_mask, text_mask], dim=1).bool() # (B, N_total)
385
+
386
+ # # 3) 调用注意力(要求 attention 支持 key_mask)
387
+ # # 若你的 attention 不支持,需要自己在里面对 logits 做 -inf 掩码;示例见后
388
+ # attn_out = attention(*joint_qkv, key_mask=key_mask) # (B, N_total, D)
389
+
390
+ # # 4) 切回两段
391
+ # x_attn_out = attn_out[:, :latent_len, :]
392
+ # t_attn_out = attn_out[:, latent_len:, :]
393
+
394
+ # # 5) 对 query 端输出做屏蔽(避免 padding query 写回)
395
+ # x_attn_out = x_attn_out * latent_mask.unsqueeze(-1) # (B, N_latent, D)
396
+ # t_attn_out = t_attn_out * text_mask.unsqueeze(-1) # (B, N_text, D)
397
+
398
+ # # 6) post_attention 内部**还要**用 query mask 把残差和 FFN 的更新再屏蔽一次(见下一节)
399
+ # latent = self.latent_block.post_attention(latent, x_attn_out, x_mod,
400
+ # query_mask=latent_mask)
401
+ # if not self.text_block.pre_only:
402
+ # text_f = self.text_block.post_attention(text_f, t_attn_out, t_mod,
403
+ # query_mask=text_mask)
404
+
405
+ # return latent, text_f
406
+
407
+
408
+
409
class FinalBlock(nn.Module):
    """Output head: AdaLN-modulated LayerNorm followed by a channel-last conv
    that projects hidden features down to the output (latent) dimension."""

    def __init__(self, dim, out_dim):
        super().__init__()
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.conv = ChannelLastConv1d(dim, out_dim, kernel_size=7, padding=3)

    def forward(self, latent, c):
        # c: conditioning vector; produces shift/scale for the final norm.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        latent = modulate(self.norm(latent), shift, scale)
        latent = self.conv(latent)
        return latent
models/dit/modules.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+ from torch.cuda.amp import autocast
7
+ import math
8
+ import einops
9
+ from einops import rearrange, repeat
10
+ from inspect import isfunction
11
+
12
+
13
def trunc_normal_(tensor, mean, std, a, b):
    """In-place fill of ``tensor`` from a normal distribution truncated to [a, b].

    Cut & paste from the PyTorch official master (inverse-CDF method); see
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    Returns ``tensor`` for chaining.
    """
    def _cdf(v):
        # Standard normal cumulative distribution function.
        return (1. + math.erf(v / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2
        )

    with torch.no_grad():
        # Sample uniformly in the CDF image of [a, b], mapped to [2l-1, 2u-1]
        # so erfinv recovers truncated standard-normal values.
        lower = _cdf((a - mean) / std)
        upper = _cdf((b - mean) / std)
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Invert the standard-normal CDF.
        tensor.erfinv_()

        # Rescale to the requested mean / std.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Guard against numerical spill outside [a, b].
        tensor.clamp_(min=a, max=b)
    return tensor
49
+
50
+
51
+ # disable in checkpoint mode
52
+ # @torch.jit.script
53
def film_modulate(x, shift, scale):
    """FiLM-style conditioning: scale x around identity, then shift."""
    return shift + x * (scale + 1)
55
+
56
+
57
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    exponents = -math.log(max_period) * torch.arange(
        start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dims get a zero pad so the output is exactly ``dim`` wide.
        pad = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, pad], dim=-1)
    return embedding
79
+
80
+
81
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations: sinusoidal features
    followed by a two-layer SiLU MLP projecting to ``out_size``.
    """
    def __init__(
        self, hidden_size, frequency_embedding_size=256, out_size=None
    ):
        super().__init__()
        out_size = hidden_size if out_size is None else out_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t):
        # Cast the sinusoidal features to the MLP's parameter dtype
        # (matters under mixed precision).
        feats = timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(feats.type(self.mlp[0].weight.dtype))
104
+
105
+
106
def patchify(imgs, patch_size, input_type='2d'):
    """Split images (2d) or sequences (1d) into flattened patch tokens.

    2d: (B, C, H, W) -> (B, (H/p)*(W/p), p*p*C)
    1d: (B, C, L)    -> (B, L/p, p*C)

    NOTE(review): any other ``input_type`` falls through and raises
    UnboundLocalError on return — verify callers only pass '2d'/'1d'.
    """
    if input_type == '2d':
        x = einops.rearrange(
            imgs,
            'B C (h p1) (w p2) -> B (h w) (p1 p2 C)',
            p1=patch_size,
            p2=patch_size
        )
    elif input_type == '1d':
        x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size)
    return x
117
+
118
+
119
def unpatchify(x, channels=3, input_type='2d', img_size=None):
    """Inverse of patchify: reassemble flattened patch tokens.

    2d: (B, num_patches, p*p*C) -> (B, C, H, W); ``img_size`` = (H, W) is
        required to recover the patch grid.
    1d: (B, T, p*C) -> (B, C, T*p).
    """
    if input_type == '2d':
        # Patch size is inferred from the token width: p = sqrt(tokens // C).
        patch_size = int((x.shape[2] // channels)**0.5)
        # h = w = int(x.shape[1] ** .5)
        h, w = img_size[0] // patch_size, img_size[1] // patch_size
        assert h * w == x.shape[1] and patch_size**2 * channels == x.shape[2]
        x = einops.rearrange(
            x,
            'B (h w) (p1 p2 C) -> B C (h p1) (w p2)',
            h=h,
            p1=patch_size,
            p2=patch_size
        )
    elif input_type == '1d':
        # Patch size inferred directly from the token width.
        patch_size = int((x.shape[2] // channels))
        h = x.shape[1]
        assert patch_size * channels == x.shape[2]
        x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size)
    return x
138
+
139
+
140
class PatchEmbed(nn.Module):
    """Image (or 1-D sequence) to patch embedding via a strided convolution."""
    def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'):
        super().__init__()
        self.patch_size = patch_size
        self.input_type = input_type
        # Kernel == stride: non-overlapping patches, one embedding per patch.
        if input_type == '2d':
            self.proj = nn.Conv2d(
                in_chans, embed_dim,
                kernel_size=patch_size, stride=patch_size, bias=True
            )
        elif input_type == '1d':
            self.proj = nn.Conv1d(
                in_chans, embed_dim,
                kernel_size=patch_size, stride=patch_size, bias=True
            )

    def forward(self, x):
        # Spatial/temporal extents must divide evenly into patches.
        if self.input_type == '2d':
            _, _, height, width = x.shape
            assert height % self.patch_size == 0
            assert width % self.patch_size == 0
        elif self.input_type == '1d':
            _, _, length = x.shape
            assert length % self.patch_size == 0

        # [B, D, ...spatial...] -> [B, num_patches, D]
        return self.proj(x).flatten(2).transpose(1, 2)
175
+
176
+
177
class PositionalConvEmbedding(nn.Module):
    """Convolutional positional embedding used in F5-TTS.

    Two grouped Conv1d + Mish layers; length-preserving thanks to symmetric
    padding with an odd kernel.
    """
    def __init__(self, dim=768, kernel_size=31, groups=16):
        super().__init__()
        # An odd kernel keeps the sequence length unchanged with this padding.
        assert kernel_size % 2 != 0
        pad = kernel_size // 2
        self.conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=pad),
            nn.Mish(),
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=pad),
            nn.Mish(),
        )

    def forward(self, x):
        # x: [B, T, C] -> convolve over T -> back to [B, T, C]
        return self.conv1d(x.transpose(1, 2)).transpose(1, 2)
200
+
201
+
202
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sine/cosine positional encoding added to the input.

    The table is precomputed for ``length`` positions and registered as a
    (non-trainable) buffer; the forward pass adds the first ``T`` rows.
    """
    def __init__(self, dim, length):
        super(SinusoidalPositionalEncoding, self).__init__()
        self.length = length
        self.dim = dim
        self.register_buffer(
            'pe', self._generate_positional_encoding(length, dim)
        )

    def _generate_positional_encoding(self, length, dim):
        positions = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)
        )
        angles = positions * inv_freq

        table = torch.zeros(length, dim)
        table[:, 0::2] = torch.sin(angles)   # even channels: sine
        table[:, 1::2] = torch.cos(angles)   # odd channels: cosine

        # Leading batch dim so it broadcasts over [B, T, C] inputs.
        return table.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
227
+
228
+
229
class PE_wrapper(nn.Module):
    """Unified positional-encoding front-end, selected by ``method``.

    method:
        'abs'  - learnable absolute embeddings (UViT-style truncated-normal
                 init), requires a fixed maximum ``length``
        'conv' - convolutional PE (PositionalConvEmbedding), added residually
        'sinu' - fixed sinusoidal encoding (SinusoidalPositionalEncoding)
        'none' - identity; no positional information added
    """
    def __init__(self, dim=768, method='abs', length=None, **kwargs):
        super().__init__()
        self.method = method
        if method == 'abs':
            # init absolute pe like UViT
            self.length = length
            self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
            trunc_normal_(self.abs_pe, mean=0.0, std=.02, a=-.04, b=.04)
        elif method == 'conv':
            # extra kwargs (kernel_size, groups) are forwarded to the conv PE
            self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs)
        elif method == 'sinu':
            self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length)
        elif method == 'none':
            # skip pe
            self.id = nn.Identity()
        else:
            raise NotImplementedError

    def forward(self, x):
        # x: (B, L, dim); for 'abs' the sequence must fit within `length`.
        if self.method == 'abs':
            _, L, _ = x.shape
            assert L <= self.length
            x = x + self.abs_pe[:, :L, :]
        elif self.method == 'conv':
            # residual addition of the convolutional PE
            x = x + self.conv_pe(x)
        elif self.method == 'sinu':
            # sinu_pe already performs the addition internally
            x = self.sinu_pe(x)
        elif self.method == 'none':
            x = self.id(x)
        else:
            raise NotImplementedError
        return x
262
+
263
+
264
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization (LLaMA-style, no mean centering).

    Normalizes the last dimension by its RMS (computed in float32 for
    stability) and applies a learnable per-channel scale.

    Args:
        dim: size of the normalized (last) dimension.
        eps: small constant added inside the square root for stability.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by its RMS over the last dimension."""
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Compute in float32, then cast back to the input dtype.
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
308
+
309
+
310
class GELU(nn.Module):
    """Linear projection followed by a GELU activation.

    Args:
        dim_in: input feature size.
        dim_out: output feature size.
        approximate: passed to ``F.gelu`` ("none" or "tanh").
        bias: whether the projection has a bias term.
    """
    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        approximate: str = "none",
        bias: bool = True
    ):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        # mps: gelu is not implemented for float16, so round-trip via float32
        if gate.device.type == "mps":
            return F.gelu(
                gate.to(dtype=torch.float32), approximate=self.approximate
            ).to(dtype=gate.dtype)
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        return self.gelu(self.proj(hidden_states))
334
+
335
+
336
class GEGLU(nn.Module):
    """Gated GELU: project to twice the width, then gate one half with the
    GELU of the other half."""
    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        # mps: gelu is not implemented for float16, so round-trip via float32
        if gate.device.type == "mps":
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states):
        value, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return value * self.gelu(gate)
351
+
352
+
353
class ApproximateGELU(nn.Module):
    """Linear projection followed by the sigmoid approximation of GELU:
    ``x * sigmoid(1.702 * x)``."""
    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        return projected * torch.sigmoid(1.702 * projected)
361
+
362
+
363
# disable in checkpoint mode
# @torch.jit.script
def snake_beta(x, alpha, beta):
    """Snake activation: ``x + beta * sin(alpha * x)^2`` (alpha sets the
    frequency, beta the magnitude of the periodic term)."""
    return x + beta * torch.square(torch.sin(alpha * x))
367
+
368
+
369
class Snake(nn.Module):
    """Linear projection followed by the snake-beta activation with
    learnable (optionally frozen) per-channel alpha/beta."""
    def __init__(self, dim_in, dim_out, bias, alpha_trainable=True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
        self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
        # One flag controls both activation parameters.
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

    def forward(self, x):
        return snake_beta(self.proj(x), self.alpha, self.beta)
382
+
383
+
384
class GESnake(nn.Module):
    """Gated snake: project to twice the width, then gate one half with the
    snake-beta activation of the other half."""
    def __init__(self, dim_in, dim_out, bias, alpha_trainable=True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
        self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
        self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
        # One flag controls both activation parameters.
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * snake_beta(gate, self.alpha, self.beta)
397
+
398
+
399
class FeedForward(nn.Module):
    """Transformer feed-forward block with a configurable activation.

    Structure: activation-projection (possibly gated) -> dropout -> linear
    projection back to ``dim_out`` -> optional final dropout.

    Args:
        dim: input feature size.
        dim_out: output size (defaults to ``dim``).
        mult: hidden-width multiplier when ``inner_dim`` is not given.
        dropout: dropout probability after the activation.
        activation_fn: one of "gelu", "gelu-approximate", "geglu",
            "geglu-approximate", "snake", "gesnake".
        final_dropout: append a second dropout after the output projection.
        inner_dim: explicit hidden width (overrides ``mult``).
        bias: bias for all linear layers.
    """
    def __init__(
        self,
        dim,
        dim_out=None,
        mult=4,
        dropout=0.0,
        activation_fn="geglu",
        final_dropout=False,
        inner_dim=None,
        bias=True,
    ):
        super().__init__()
        inner_dim = int(dim * mult) if inner_dim is None else inner_dim
        dim_out = dim if dim_out is None else dim_out

        # Lazy builders so only the requested activation is instantiated.
        builders = {
            "gelu": lambda: GELU(dim, inner_dim, bias=bias),
            "gelu-approximate":
                lambda: GELU(dim, inner_dim, approximate="tanh", bias=bias),
            "geglu": lambda: GEGLU(dim, inner_dim, bias=bias),
            "geglu-approximate":
                lambda: ApproximateGELU(dim, inner_dim, bias=bias),
            "snake": lambda: Snake(dim, inner_dim, bias=bias),
            "gesnake": lambda: GESnake(dim, inner_dim, bias=bias),
        }
        if activation_fn not in builders:
            raise NotImplementedError
        act_fn = builders[activation_fn]()

        layers = [
            act_fn,                                   # project in + activate
            nn.Dropout(dropout),                      # hidden dropout
            nn.Linear(inner_dim, dim_out, bias=bias), # project out
        ]
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            layers.append(nn.Dropout(dropout))
        self.net = nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.net:
            hidden_states = layer(hidden_states)
        return hidden_states
models/dit/rotary.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ "this rope is faster than llama rope with jit script"
3
+
4
+
5
def rotate_half(x):
    """Rotate the last dimension by half: ``(x1, x2) -> (-x2, x1)``."""
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)


# disable in checkpoint mode
# @torch.jit.script
def apply_rotary_pos_emb(x, cos, sin):
    # NOTE: This could probably be moved to Triton
    # Trim the tables to x's sequence length (q and k lengths may differ).
    seq_len = x.shape[-2]
    cos = cos[:, :, :seq_len, :]
    sin = sin[:, :, :seq_len, :]
    return x * cos + rotate_half(x) * sin
18
+
19
+
20
class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox


    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
    (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
    """
    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Lazily-built cos/sin tables, rebuilt when length/device/dtype change.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=-2):
        # expect input: B, H, L, D
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        # also make sure dtype wont change
        # NOTE: on the very first call _cos_cached is None; the length check
        # (seq_len != None) short-circuits the `or`, so the attribute
        # accesses below are never reached on an empty cache.
        if (
            seq_len != self._seq_len_cached or
            self._cos_cached.device != x.device or
            self._cos_cached.dtype != x.dtype
        ):
            self._seq_len_cached = seq_len
            t = torch.arange(
                x.shape[seq_dimension], device=x.device, dtype=torch.float32
            )
            # Outer product position x inverse-frequency -> rotation angles.
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
            # Duplicate so the table covers the full head dimension.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            # Shape [1, 1, L, D] so tables broadcast over batch and heads.
            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(self, q, k):
        # Tables are computed from q (in float32 for precision); k may be
        # shorter — apply_rotary_pos_emb trims the tables per input.
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            q.float(), seq_dimension=-2
        )
        if k is not None:
            return (
                apply_rotary_pos_emb(
                    q.float(), self._cos_cached, self._sin_cached
                ).type_as(q),
                apply_rotary_pos_emb(
                    k.float(), self._cos_cached, self._sin_cached
                ).type_as(k),
            )
        else:
            # k is None: only rotate q and pass None through.
            return (
                apply_rotary_pos_emb(
                    q.float(), self._cos_cached, self._sin_cached
                ).type_as(q), None
            )
models/dit/span_mask.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from typing import Optional, Tuple
4
+
5
+
6
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
) -> torch.Tensor:
    """
    Computes random mask spans for a given shape.

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True).
            May also be a sequence with one probability per batch element.
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from poisson distribution with lambda = mask length
        min_masks: minimum number of masked spans
        no_overlap: if True, uses an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans

    Returns:
        Boolean torch.Tensor of shape ``(batch, timesteps)``; True = masked.
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    # Accept a scalar or a per-batch-element sequence of probabilities.
    mask_prob = np.array(mask_prob)

    # Expected number of spans per element, with probabilistic rounding.
    all_num_mask = np.floor(
        mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)
    ).astype(int)

    # Enforce the minimum number of spans.
    all_num_mask = np.maximum(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            # Only the unpadded prefix is eligible for masking.
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length) + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask[i]

        # Sample the length of each span.
        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(
                mask_other, mask_length*2 + 1, size=num_mask
            )
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            # Guarantee at least one non-empty span.
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # Place one span inside [s, e) and return the remaining
                # sub-intervals still large enough to host another span.
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - keep_length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (
                        e - s if e - s >= length + min_space else 0
                        for s, e in parts
                    ),
                    int,  # fix: `np.int` alias was removed in NumPy >= 1.24
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    # No interval can host this span; stop placing.
                    break
                # Choose an interval with probability proportional to size.
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                # Shrink so every sampled start can expand into a span.
                min_len = sz - num_mask - 1

            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            # Expand each start index into its full span.
            mask_idc = np.asarray([
                mask_idc[j] + offset for j in range(len(mask_idc))
                for offset in range(lengths[j])
            ])

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
    # min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        # if len(mask_idc) > min_len:
        #     mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        mask[i, mask_idc] = True

    return torch.tensor(mask)
134
+
135
+
136
if __name__ == '__main__':
    # Smoke test: per-example mask probabilities over a 4 x 500 grid.
    demo_mask = compute_mask_indices(
        shape=[4, 500],
        padding_mask=None,
        mask_prob=[0.65, 0.5, 0.65, 0.65],
        mask_length=10,
        mask_type="static",
        mask_other=0.0,
        min_masks=1,
        no_overlap=False,
        min_space=0,
    )
    print(demo_mask)
    print(demo_mask.sum(dim=1))
models/flow_matching.py ADDED
@@ -0,0 +1,1082 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Union, List, Sequence
2
+
3
+ import inspect
4
+ import random
5
+
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import copy
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from diffusers.utils.torch_utils import randn_tensor
14
+ from diffusers import FlowMatchEulerDiscreteScheduler
15
+ from diffusers.training_utils import compute_density_for_timestep_sampling
16
+
17
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
18
+ from models.content_encoder.content_encoder import ContentEncoder
19
+ from models.content_adapter import ContentAdapterBase
20
+ from models.common import LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase
21
+ from utils.torch_utilities import (
22
+ create_alignment_path, create_mask_from_length, loss_with_mask,
23
+ trim_or_pad_length
24
+ )
25
+ from constants import SAME_LENGTH_TASKS
26
+
27
+
28
class FlowMatchingMixin:
    """Mixin providing flow-matching training targets and inference timestep
    utilities on top of diffusers' FlowMatchEulerDiscreteScheduler."""
    def __init__(
        self,
        cfg_drop_ratio: float = 0.2,
        sample_strategy: str = 'normal',
        num_train_steps: int = 1000
    ) -> None:
        r"""
        Args:
            cfg_drop_ratio (float): Probability of dropping the condition
                during training (enables classifier-free guidance when > 0).
            sample_strategy (str): Sampling strategy for timesteps during training
                ('normal' = logit-normal density, 'uniform').
            num_train_steps (int): Number of training steps for the noise scheduler.
        """
        self.sample_strategy = sample_strategy
        self.infer_noise_scheduler = FlowMatchEulerDiscreteScheduler(
            num_train_timesteps=num_train_steps
        )
        # Separate copy for training so inference-time set_timesteps calls
        # don't mutate the training schedule.
        self.train_noise_scheduler = copy.deepcopy(self.infer_noise_scheduler)

        self.classifier_free_guidance = cfg_drop_ratio > 0.0
        self.cfg_drop_ratio = cfg_drop_ratio

    def get_input_target_and_timesteps(
        self,
        latent: torch.Tensor,
        training: bool,
    ):
        """Build the noisy input, regression target and timesteps for one step.

        Returns:
            (noisy_latent, target, timesteps) where target = noise - latent
            (the flow-matching velocity) and noisy_latent interpolates
            latent and noise according to sigma(t).
        """
        batch_size = latent.shape[0]
        noise = torch.randn_like(latent)

        if training:
            if self.sample_strategy == 'normal':
                # Logit-normal density over u in (0, 1) (SD3-style sampling).
                u = compute_density_for_timestep_sampling(
                    weighting_scheme="logit_normal",
                    batch_size=batch_size,
                    logit_mean=0,
                    logit_std=1,
                    mode_scale=None,
                )
            elif self.sample_strategy == 'uniform':
                u = torch.rand(batch_size, )
            else:
                # NOTE(review): "samlping" typo in this runtime message.
                raise NotImplementedError(
                    f"{self.sample_strategy} samlping for timesteps is not supported now"
                )

            indices = (
                u * self.train_noise_scheduler.config.num_train_timesteps
            ).long()
        else:
            # Evaluation: deterministic mid-schedule timestep for all samples.
            indices = (
                self.train_noise_scheduler.config.num_train_timesteps // 2
            ) * torch.ones((batch_size, )).long()

        # train_noise_scheduler.timesteps: a list from 1 ~ num_trainsteps with 1 as interval
        timesteps = self.train_noise_scheduler.timesteps[indices].to(
            device=latent.device
        )
        sigmas = self.get_sigmas(
            timesteps, n_dim=latent.ndim, dtype=latent.dtype
        )

        # Linear interpolation between data (sigma=0) and noise (sigma=1).
        noisy_latent = (1.0 - sigmas) * latent + sigmas * noise

        target = noise - latent

        return noisy_latent, target, timesteps

    def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32):
        """Look up sigma values for the given timesteps, reshaped so they
        broadcast against an ``n_dim``-dimensional latent."""
        device = timesteps.device

        # a list from 1 declining to 1/num_train_steps
        sigmas = self.train_noise_scheduler.sigmas.to(
            device=device, dtype=dtype
        )

        schedule_timesteps = self.train_noise_scheduler.timesteps.to(device)
        timesteps = timesteps.to(device)
        # Map each timestep to its position in the schedule.
        step_indices = [(schedule_timesteps == t).nonzero().item()
                        for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # Append singleton dims until sigma broadcasts against the latent.
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def retrieve_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        timesteps: Optional[List[int]] = None,
        sigmas: Optional[List[float]] = None,
        **kwargs,
    ):
        """Configure the inference scheduler and return its timesteps.

        Exactly one of ``timesteps`` / ``sigmas`` may be given to impose a
        custom schedule; otherwise ``num_inference_steps`` is used.

        Returns:
            (timesteps, num_inference_steps)

        Raises:
            ValueError: if both ``timesteps`` and ``sigmas`` are given, or the
                scheduler does not support the requested custom schedule.
        """
        # used in inference, retrieve new timesteps on given inference timesteps
        scheduler = self.infer_noise_scheduler

        if timesteps is not None and sigmas is not None:
            raise ValueError(
                "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
            )
        if timesteps is not None:
            # Custom timesteps require scheduler support (signature check).
            accepts_timesteps = "timesteps" in set(
                inspect.signature(scheduler.set_timesteps).parameters.keys()
            )
            if not accepts_timesteps:
                raise ValueError(
                    f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                    f" timestep schedules. Please check whether you are using the correct scheduler."
                )
            scheduler.set_timesteps(
                timesteps=timesteps, device=device, **kwargs
            )
            timesteps = scheduler.timesteps
            num_inference_steps = len(timesteps)
        elif sigmas is not None:
            # Custom sigmas likewise require scheduler support.
            accept_sigmas = "sigmas" in set(
                inspect.signature(scheduler.set_timesteps).parameters.keys()
            )
            if not accept_sigmas:
                raise ValueError(
                    f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                    f" sigmas schedules. Please check whether you are using the correct scheduler."
                )
            scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
            timesteps = scheduler.timesteps
            num_inference_steps = len(timesteps)
        else:
            # Default: evenly derived schedule from the step count.
            scheduler.set_timesteps(
                num_inference_steps, device=device, **kwargs
            )
            timesteps = scheduler.timesteps
        return timesteps, num_inference_steps
161
+
162
+
163
class ContentEncoderAdapterMixin:
    """Bundles a content encoder with an optional instruction-conditioned
    content adapter and exposes a single ``encode_content`` entry point."""
    def __init__(
        self,
        content_encoder: ContentEncoder,
        content_adapter: ContentAdapterBase | None = None
    ):
        self.content_encoder = content_encoder
        self.content_adapter = content_adapter

    def encode_content(
        self,
        content: list[Any],
        task: list[str],
        device: str | torch.device,
        instruction: torch.Tensor | None = None,
        instruction_lengths: torch.Tensor | None = None
    ):
        """Encode raw content; when an instruction is given, also run the
        adapter and include its duration predictions in the result."""
        encoded: dict[str, torch.Tensor] = self.content_encoder.encode_content(
            content, task, device=device
        )
        content_emb = encoded["content"]
        content_mask = encoded["content_mask"]

        use_adapter = instruction is not None
        if use_adapter:
            instruction_mask = create_mask_from_length(instruction_lengths)
            (
                content_emb,
                content_mask,
                global_duration_pred,
                local_duration_pred,
            ) = self.content_adapter(
                content_emb, content_mask, instruction, instruction_mask
            )

        result = {
            "content": content_emb,
            "content_mask": content_mask,
            "length_aligned_content": encoded["length_aligned_content"],
        }
        if use_adapter:
            result["global_duration_pred"] = global_duration_pred
            result["local_duration_pred"] = local_duration_pred

        return result
208
+
209
+
210
+ class SingleTaskCrossAttentionAudioFlowMatching(
211
+ LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
212
+ FlowMatchingMixin, ContentEncoderAdapterMixin
213
+ ):
214
    def __init__(
        self,
        autoencoder: nn.Module,
        content_encoder: ContentEncoder,
        backbone: nn.Module,
        cfg_drop_ratio: float = 0.2,
        sample_strategy: str = 'normal',
        num_train_steps: int = 1000,
    ):
        """
        Args:
            autoencoder: waveform autoencoder; frozen here and used only for
                latent encoding/decoding.
            content_encoder: encodes raw content (text/audio/...) to embeddings.
            backbone: the denoising network (DiT-style) predicting velocity.
            cfg_drop_ratio: condition-drop probability for classifier-free
                guidance training.
            sample_strategy: timestep sampling strategy ('normal' / 'uniform').
            num_train_steps: scheduler's number of training timesteps.
        """
        nn.Module.__init__(self)
        FlowMatchingMixin.__init__(
            self, cfg_drop_ratio, sample_strategy, num_train_steps
        )
        ContentEncoderAdapterMixin.__init__(
            self, content_encoder=content_encoder
        )

        # The autoencoder is frozen: only the backbone (and content modules)
        # are trained.
        self.autoencoder = autoencoder
        for param in self.autoencoder.parameters():
            param.requires_grad = False

        # Share the frozen autoencoder with the content encoder's audio
        # branch (if present) so audio content is encoded in the same space.
        if hasattr(
            self.content_encoder, "audio_encoder"
        ) and self.content_encoder.audio_encoder is not None:
            self.content_encoder.audio_encoder.model = self.autoencoder

        self.backbone = backbone
        # Zero-size parameter used only to discover the module's device.
        self.dummy_param = nn.Parameter(torch.empty(0))
242
+
243
    def forward(
        self, content: list[Any], condition: list[Any], task: list[str],
        waveform: torch.Tensor, waveform_lengths: torch.Tensor, **kwargs
    ):
        """Compute the masked flow-matching training loss for one batch.

        Args:
            content: raw per-sample content passed to the content encoder.
            condition: unused in this single-task variant (kept for a uniform
                interface across model variants).
            task: per-sample task identifiers for the content encoder.
            waveform: [B, T] audio, encoded to latents by the frozen autoencoder.
            waveform_lengths: valid lengths of each waveform.

        Returns:
            Scalar masked MSE loss between predicted and target velocity.
        """
        device = self.dummy_param.device

        # Frozen autoencoder: always eval mode, no gradients.
        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )

        content_dict = self.encode_content(content, task, device)
        content, content_mask = content_dict["content"], content_dict[
            "content_mask"]

        # Classifier-free guidance training: zero out the condition for a
        # random subset of samples.
        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                content[mask_indices] = 0

        noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
            latent, training=self.training
        )

        # Velocity prediction conditioned on content via cross-attention.
        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            timesteps=timesteps,
            context=content,
            x_mask=latent_mask,
            context_mask=content_mask
        )

        # MSE restricted to valid (unpadded) latent positions.
        loss = F.mse_loss(pred.float(), target.float(), reduction="none")
        loss = loss_with_mask(loss, latent_mask)

        return loss
283
+
284
    def iterative_denoise(
        self, latent: torch.Tensor, timesteps: list[int], num_steps: int,
        verbose: bool, cfg: bool, cfg_scale: float, backbone_input: dict
    ):
        """Run the Euler sampling loop over ``timesteps``.

        Args:
            latent: initial Gaussian latent [B, ...].
            timesteps: schedule produced by ``retrieve_timesteps``.
            num_steps: number of steps (progress bar length only).
            verbose: show a progress bar.
            cfg: apply classifier-free guidance; ``backbone_input`` is then
                assumed to contain a doubled (uncond + cond) context, with the
                unconditional half FIRST — TODO confirm against callers.
            cfg_scale: guidance strength.
            backbone_input: extra kwargs forwarded to the backbone.

        Returns:
            The denoised latent.
        """
        progress_bar = tqdm(range(num_steps), disable=not verbose)

        for i, timestep in enumerate(timesteps):
            # expand the latent if we are doing classifier free guidance
            if cfg:
                latent_input = torch.cat([latent, latent])
            else:
                latent_input = latent

            noise_pred: torch.Tensor = self.backbone(
                x=latent_input, timesteps=timestep, **backbone_input
            )

            # perform guidance
            if cfg:
                noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + cfg_scale * (
                    noise_pred_content - noise_pred_uncond
                )

            # One Euler step of the flow-matching ODE.
            latent = self.infer_noise_scheduler.step(
                noise_pred, timestep, latent
            ).prev_sample

            progress_bar.update(1)

        progress_bar.close()

        return latent
317
+
318
+ @torch.no_grad()
319
+ def inference(
320
+ self,
321
+ content: list[Any],
322
+ condition: list[Any],
323
+ task: list[str],
324
+ latent_shape: Sequence[int],
325
+ num_steps: int = 50,
326
+ sway_sampling_coef: float | None = -1.0,
327
+ guidance_scale: float = 3.0,
328
+ num_samples_per_content: int = 1,
329
+ disable_progress: bool = True,
330
+ **kwargs
331
+ ):
332
+ device = self.dummy_param.device
333
+ classifier_free_guidance = guidance_scale > 1.0
334
+ batch_size = len(content) * num_samples_per_content
335
+
336
+ if classifier_free_guidance:
337
+ content, content_mask = self.encode_content_classifier_free(
338
+ content, task, num_samples_per_content
339
+ )
340
+ else:
341
+ content_output: dict[
342
+ str, torch.Tensor] = self.content_encoder.encode_content(
343
+ content, task
344
+ )
345
+ content, content_mask = content_output["content"], content_output[
346
+ "content_mask"]
347
+ content = content.repeat_interleave(num_samples_per_content, 0)
348
+ content_mask = content_mask.repeat_interleave(
349
+ num_samples_per_content, 0
350
+ )
351
+
352
+ latent = self.prepare_latent(
353
+ batch_size, latent_shape, content.dtype, device
354
+ )
355
+
356
+ if not sway_sampling_coef:
357
+ sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
358
+ else:
359
+ t = torch.linspace(0, 1, num_steps + 1)
360
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
361
+ sigmas = 1 - t
362
+ timesteps, num_steps = self.retrieve_timesteps(
363
+ num_steps, device, timesteps=None, sigmas=sigmas
364
+ )
365
+
366
+ latent = self.iterative_denoise(
367
+ latent=latent,
368
+ timesteps=timesteps,
369
+ num_steps=num_steps,
370
+ verbose=not disable_progress,
371
+ cfg=classifier_free_guidance,
372
+ cfg_scale=guidance_scale,
373
+ backbone_input={
374
+ "context": content,
375
+ "context_mask": content_mask,
376
+ },
377
+ )
378
+
379
+ waveform = self.autoencoder.decode(latent)
380
+
381
+ return waveform
382
+
383
+ def prepare_latent(
384
+ self, batch_size: int, latent_shape: Sequence[int], dtype: torch.dtype,
385
+ device: str
386
+ ):
387
+ shape = (batch_size, *latent_shape)
388
+ latent = randn_tensor(
389
+ shape, generator=None, device=device, dtype=dtype
390
+ )
391
+ return latent
392
+
393
+ def encode_content_classifier_free(
394
+ self,
395
+ content: list[Any],
396
+ task: list[str],
397
+ device,
398
+ num_samples_per_content: int = 1
399
+ ):
400
+ content_dict = self.content_encoder.encode_content(
401
+ content, task, device=device
402
+ )
403
+ content, content_mask = content_dict["content"], content_dict[
404
+ "content_mask"]
405
+
406
+ content = content.repeat_interleave(num_samples_per_content, 0)
407
+ content_mask = content_mask.repeat_interleave(
408
+ num_samples_per_content, 0
409
+ )
410
+
411
+ # get unconditional embeddings for classifier free guidance
412
+ uncond_content = torch.zeros_like(content)
413
+ uncond_content_mask = content_mask.detach().clone()
414
+
415
+ uncond_content = uncond_content.repeat_interleave(
416
+ num_samples_per_content, 0
417
+ )
418
+ uncond_content_mask = uncond_content_mask.repeat_interleave(
419
+ num_samples_per_content, 0
420
+ )
421
+
422
+ # For classifier free guidance, we need to do two forward passes.
423
+ # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
424
+ content = torch.cat([uncond_content, content])
425
+ content_mask = torch.cat([uncond_content_mask, content_mask])
426
+
427
+ return content, content_mask
428
+
429
+
430
class DurationAdapterMixin:
    """Mixin that adds duration prediction utilities on top of a latent
    generation model: losses for global (total length) and local (per-token)
    duration, conversion of log-domain predictions back to seconds, and
    duration-based expansion of token sequences to frame sequences.
    """
    def __init__(
        self,
        latent_token_rate: int,
        offset: float = 1.0,
        frame_resolution: float | None = None
    ):
        # latent_token_rate: latent frames per second of audio
        # offset: additive offset used inside log/exp to keep log(0) away
        # frame_resolution: seconds per duration frame (None when the model
        # does not use local durations)
        self.latent_token_rate = latent_token_rate
        self.offset = offset
        self.frame_resolution = frame_resolution

    def get_global_duration_loss(
        self,
        pred: torch.Tensor,
        latent_mask: torch.Tensor,
        reduce: bool = True,
    ):
        """MSE between predicted and true log total duration (seconds).

        The target is log(number_of_valid_latent_frames / token_rate + offset),
        i.e. durations are regressed in the log domain.
        """
        target = torch.log(
            latent_mask.sum(1) / self.latent_token_rate + self.offset
        )
        loss = F.mse_loss(target, pred, reduction="mean" if reduce else "none")
        return loss

    def get_local_duration_loss(
        self, ground_truth: torch.Tensor, pred: torch.Tensor,
        mask: torch.Tensor, is_time_aligned: Sequence[bool], reduce: bool
    ):
        """Masked MSE on per-token log frame counts, restricted to
        time-aligned samples (non-time-aligned rows are zeroed out).

        ground_truth is in seconds; it is converted to frame counts using
        frame_resolution before the log transform.
        """
        n_frames = torch.round(ground_truth / self.frame_resolution)
        target = torch.log(n_frames + self.offset)
        loss = loss_with_mask(
            (target - pred)**2,
            mask,
            reduce=False,
        )
        # zero the loss for samples that carry no time-aligned duration info
        loss *= is_time_aligned
        if reduce:
            if is_time_aligned.sum().item() == 0:
                # no time-aligned sample in the batch: keep the graph alive
                # but make the loss exactly zero
                loss *= 0.0
                loss = loss.mean()
            else:
                # average only over time-aligned samples
                loss = loss.sum() / is_time_aligned.sum()

        return loss

    def prepare_local_duration(self, pred: torch.Tensor, mask: torch.Tensor):
        """Convert log-domain per-token predictions to seconds.

        exp() undoes the log target, ceil/offset recover an integer frame
        count, and frame_resolution converts frames to seconds.
        """
        pred = torch.exp(pred) * mask
        pred = torch.ceil(pred) - self.offset
        pred *= self.frame_resolution
        return pred

    def prepare_global_duration(
        self,
        global_pred: torch.Tensor,
        local_pred: torch.Tensor,
        is_time_aligned: Sequence[bool],
        use_local: bool = True,
    ):
        """
        global_pred: predicted duration value, processed by logarithmic and offset
        local_pred: predicted latent length

        For time-aligned samples the summed local durations replace the
        global estimate (when use_local), since they are more precise.
        """
        global_pred = torch.exp(global_pred) - self.offset
        result = global_pred
        # avoid error accumulation for each frame
        if use_local:
            pred_from_local = torch.round(local_pred * self.latent_token_rate)
            pred_from_local = pred_from_local.sum(1) / self.latent_token_rate
            result[is_time_aligned] = pred_from_local[is_time_aligned]

        return result

    def expand_by_duration(
        self,
        x: torch.Tensor,
        content_mask: torch.Tensor,
        local_duration: torch.Tensor,
        global_duration: torch.Tensor | None = None,
    ):
        """Expand token features x (B, T_tok, D) to frame rate by repeating
        each token according to its duration (in seconds).

        Returns (expanded_x, latent_mask) where expanded_x is (B, T_frame, D).
        When global_duration is given it fixes the total frame count;
        otherwise the per-token frame counts are summed.
        """
        n_latents = torch.round(local_duration * self.latent_token_rate)
        if global_duration is not None:
            latent_length = torch.round(
                global_duration * self.latent_token_rate
            )
        else:
            latent_length = n_latents.sum(1)
        latent_mask = create_mask_from_length(latent_length).to(
            content_mask.device
        )
        # attn_mask: (B, T_tok, T_frame) validity grid between tokens/frames
        attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
        # monotonic hard alignment path from token durations
        align_path = create_alignment_path(n_latents, attn_mask)
        expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x)
        return expanded_x, latent_mask
522
+
523
+
524
class CrossAttentionAudioFlowMatching(
    SingleTaskCrossAttentionAudioFlowMatching, DurationAdapterMixin
):
    """Flow-matching audio generator conditioned via cross attention on
    instruction-adapted content, with an auxiliary global duration head that
    predicts the total output length at inference time.
    """
    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        content_adapter: ContentAdapterBase,
        backbone: nn.Module,
        content_dim: int,
        frame_resolution: float,
        duration_offset: float = 1.0,
        cfg_drop_ratio: float = 0.2,
        sample_strategy: str = 'normal',
        num_train_steps: int = 1000
    ):
        super().__init__(
            autoencoder=autoencoder,
            content_encoder=content_encoder,
            backbone=backbone,
            cfg_drop_ratio=cfg_drop_ratio,
            sample_strategy=sample_strategy,
            num_train_steps=num_train_steps,
        )
        # NOTE(review): ContentEncoderAdapterMixin is not in this class's
        # base list; its __init__ is invoked explicitly to attach the
        # content encoder/adapter — presumably it only sets attributes.
        # Confirm this is intentional.
        ContentEncoderAdapterMixin.__init__(
            self,
            content_encoder=content_encoder,
            content_adapter=content_adapter
        )
        # frame_resolution is intentionally not forwarded here, so
        # DurationAdapterMixin.frame_resolution stays None for this class
        # (only the global duration head is used).
        DurationAdapterMixin.__init__(
            self,
            latent_token_rate=autoencoder.latent_token_rate,
            offset=duration_offset
        )

    def encode_content_with_instruction(
        self, content: list[Any], task: list[str], device,
        instruction: torch.Tensor, instruction_lengths: torch.Tensor
    ):
        """Encode content with instruction conditioning and unpack the
        adapter outputs into a fixed-order tuple:
        (content, content_mask, global_duration_pred, local_duration_pred,
        length_aligned_content).
        """
        content_dict = self.encode_content(
            content, task, device, instruction, instruction_lengths
        )
        return (
            content_dict["content"],
            content_dict["content_mask"],
            content_dict["global_duration_pred"],
            content_dict["local_duration_pred"],
            content_dict["length_aligned_content"],
        )

    def forward(
        self,
        content: list[Any],
        task: list[str],
        waveform: torch.Tensor,
        waveform_lengths: torch.Tensor,
        instruction: torch.Tensor,
        instruction_lengths: torch.Tensor,
        loss_reduce: bool = True,
        **kwargs
    ):
        """Training/validation step.

        Returns a dict with "diff_loss" (masked flow-matching MSE) and
        "global_duration_loss" (log-duration MSE).
        """
        device = self.dummy_param.device
        # equivalent to `self.training or loss_reduce`: always reduce during
        # training, honor the flag during evaluation
        loss_reduce = self.training or (loss_reduce and not self.training)

        # the autoencoder is frozen: eval mode + no_grad for the encode pass
        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )

        content, content_mask, global_duration_pred, _, _ = \
            self.encode_content_with_instruction(
                content, task, device, instruction, instruction_lengths
            )

        global_duration_loss = self.get_global_duration_loss(
            global_duration_pred, latent_mask, reduce=loss_reduce
        )

        # classifier-free guidance training: randomly zero the content
        # condition for a fraction of the batch
        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                content[mask_indices] = 0

        noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
            latent, training=self.training
        )

        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            timesteps=timesteps,
            context=content,
            x_mask=latent_mask,
            context_mask=content_mask,
        )
        # move the time axis to dim 1 so loss_with_mask can apply the
        # (B, T) latent mask
        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none")
        diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)

        return {
            "diff_loss": diff_loss,
            "global_duration_loss": global_duration_loss,
        }

    @torch.no_grad()
    def inference(
        self,
        content: list[Any],
        condition: list[Any],
        task: list[str],
        is_time_aligned: Sequence[bool],
        instruction: torch.Tensor,
        instruction_lengths: torch.Tensor,
        num_steps: int = 20,
        sway_sampling_coef: float | None = -1.0,
        guidance_scale: float = 3.0,
        disable_progress=True,
        use_gt_duration: bool = False,
        **kwargs
    ):
        """Generate waveforms; output length comes from the global duration
        head, except for same-length tasks whose latent length is copied
        from the content length.
        """
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0

        (
            content,
            content_mask,
            global_duration_pred,
            local_duration_pred,
            _,
        ) = self.encode_content_with_instruction(
            content, task, device, instruction, instruction_lengths
        )
        batch_size = content.size(0)

        if use_gt_duration:
            raise NotImplementedError(
                "Using ground truth global duration only is not implemented yet"
            )

        # prepare global duration (use_local=False: rely on the global head
        # only, since frame_resolution is unset for this class)
        global_duration = self.prepare_global_duration(
            global_duration_pred,
            local_duration_pred,
            is_time_aligned,
            use_local=False
        )
        # TODO: manually set duration for SE and AudioSR
        latent_length = torch.round(global_duration * self.latent_token_rate)
        # same-length tasks (e.g. enhancement) keep the input's length
        task_mask = torch.as_tensor([t in SAME_LENGTH_TASKS for t in task])
        latent_length[task_mask] = content[task_mask].size(1)
        latent_mask = create_mask_from_length(latent_length).to(device)
        max_latent_length = latent_mask.sum(1).max().item()

        # prepare latent and noise; uncond half first to match chunk(2)
        if classifier_free_guidance:
            uncond_context = torch.zeros_like(content)
            uncond_content_mask = content_mask.detach().clone()
            context = torch.cat([uncond_context, content])
            context_mask = torch.cat([uncond_content_mask, content_mask])
        else:
            context = content
            context_mask = content_mask

        # None entries in latent_shape mark the (variable) time axis
        latent_shape = tuple(
            max_latent_length if dim is None else dim
            for dim in self.autoencoder.latent_shape
        )
        shape = (batch_size, *latent_shape)
        latent = randn_tensor(
            shape, generator=None, device=device, dtype=content.dtype
        )
        # linear sigma schedule when sway sampling disabled, warped otherwise
        if not sway_sampling_coef:
            sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
        else:
            t = torch.linspace(0, 1, num_steps + 1)
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
            sigmas = 1 - t
        timesteps, num_steps = self.retrieve_timesteps(
            num_steps, device, timesteps=None, sigmas=sigmas
        )
        latent = self.iterative_denoise(
            latent=latent,
            timesteps=timesteps,
            num_steps=num_steps,
            verbose=not disable_progress,
            cfg=classifier_free_guidance,
            cfg_scale=guidance_scale,
            backbone_input={
                "x_mask": latent_mask,
                "context": context,
                "context_mask": context_mask,
            }
        )

        waveform = self.autoencoder.decode(latent)
        return waveform
724
+
725
+
726
class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
    """Flow matching with a duration adapter and two content streams:

    - a cross-attention ``context`` used for non-time-aligned samples (for
      time-aligned samples it is replaced by a learned dummy embedding), and
    - a frame-rate ``time_aligned_context`` built by expanding token content
      with predicted/ground-truth durations (non-time-aligned samples get a
      learned dummy embedding there instead).
    """
    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        content_adapter: ContentAdapterBase,
        backbone: nn.Module,
        content_dim: int,
        frame_resolution: float,
        duration_offset: float = 1.0,
        cfg_drop_ratio: float = 0.2,
        sample_strategy: str = 'normal',
        num_train_steps: int = 1000
    ):

        super().__init__(
            autoencoder=autoencoder,
            content_encoder=content_encoder,
            content_adapter=content_adapter,
            backbone=backbone,
            content_dim=content_dim,
            frame_resolution=frame_resolution,
            duration_offset=duration_offset,
            cfg_drop_ratio=cfg_drop_ratio,
            sample_strategy=sample_strategy,
            num_train_steps=num_train_steps
        )
        # re-run the mixin init: the parent left frame_resolution unset,
        # while this class needs it for local (per-token) durations
        DurationAdapterMixin.__init__(
            self,
            latent_token_rate=autoencoder.latent_token_rate,
            offset=duration_offset,
            frame_resolution=frame_resolution
        )
        # learned placeholders: nta for the cross-attention context of
        # time-aligned samples, ta for the time-aligned stream of
        # non-time-aligned samples
        self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim))
        self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim))

    def get_backbone_input(
        self, target_length: int, content: torch.Tensor,
        content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
        length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
    ):
        """Assemble (context, context_mask, time_aligned_context) for the
        backbone. NOTE: mutates ``content`` in place via the ``context``
        alias and writes into ``time_aligned_content``.
        """
        # TODO compatility for 2D spectrogram VAE
        time_aligned_content = trim_or_pad_length(
            time_aligned_content, target_length, 1
        )
        length_aligned_content = trim_or_pad_length(
            length_aligned_content, target_length, 1
        )
        # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme)
        # length_aligned_content: from aligned input (f0/energy)
        time_aligned_content = time_aligned_content + length_aligned_content
        # non-time-aligned samples carry no real frame-level content: use
        # the learned dummy embedding (broadcast over frames)
        time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
            time_aligned_content.dtype
        )

        context = content
        context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
        # only use the first dummy non time aligned embedding
        context_mask = content_mask.detach().clone()
        context_mask[is_time_aligned, 1:] = False

        # truncate dummy non time aligned context
        if is_time_aligned.sum().item() < content.size(0):
            trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
        else:
            trunc_nta_length = content.size(1)
        context = context[:, :trunc_nta_length]
        context_mask = context_mask[:, :trunc_nta_length]

        return context, context_mask, time_aligned_content

    def forward(
        self,
        content: list[Any],
        duration: Sequence[float],
        task: list[str],
        is_time_aligned: Sequence[bool],
        waveform: torch.Tensor,
        waveform_lengths: torch.Tensor,
        instruction: torch.Tensor,
        instruction_lengths: torch.Tensor,
        loss_reduce: bool = True,
        **kwargs
    ):
        """Training/validation step.

        Returns a dict with "diff_loss", "local_duration_loss" and
        "global_duration_loss".

        NOTE(review): ``duration`` is annotated Sequence[float] but is used
        as a 2-D tensor (``.size(1)``, ``F.pad``) — presumably the collator
        supplies a (B, T_tok) tensor of seconds; confirm upstream.
        """
        device = self.dummy_param.device
        # always reduce during training, honor the flag during evaluation
        loss_reduce = self.training or (loss_reduce and not self.training)

        # frozen autoencoder: encode target waveform to latent without grads
        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )

        (
            content, content_mask, global_duration_pred, local_duration_pred,
            length_aligned_content
        ) = self.encode_content_with_instruction(
            content, task, device, instruction, instruction_lengths
        )

        # truncate unused non time aligned duration prediction
        if is_time_aligned.sum() > 0:
            trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
        else:
            trunc_ta_length = content.size(1)

        # duration loss
        local_duration_pred = local_duration_pred[:, :trunc_ta_length]
        ta_content_mask = content_mask[:, :trunc_ta_length]
        local_duration_loss = self.get_local_duration_loss(
            duration,
            local_duration_pred,
            ta_content_mask,
            is_time_aligned,
            reduce=loss_reduce
        )

        global_duration_loss = self.get_global_duration_loss(
            global_duration_pred, latent_mask, reduce=loss_reduce
        )

        # --------------------------------------------------------------------
        # prepare latent and noise
        # --------------------------------------------------------------------
        noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
            latent, training=self.training
        )

        # --------------------------------------------------------------------
        # duration adapter (uses ground-truth durations during training)
        # --------------------------------------------------------------------
        if is_time_aligned.sum() == 0 and \
                duration.size(1) < content_mask.size(1):
            duration = F.pad(
                duration, (0, content_mask.size(1) - duration.size(1))
            )
        time_aligned_content, _ = self.expand_by_duration(
            x=content[:, :trunc_ta_length],
            content_mask=ta_content_mask,
            local_duration=duration,
        )

        # --------------------------------------------------------------------
        # prepare input to the backbone
        # --------------------------------------------------------------------
        # TODO compatility for 2D spectrogram VAE
        latent_length = noisy_latent.size(self.autoencoder.time_dim)
        context, context_mask, time_aligned_content = self.get_backbone_input(
            latent_length, content, content_mask, time_aligned_content,
            length_aligned_content, is_time_aligned
        )

        # --------------------------------------------------------------------
        # classifier free guidance: randomly drop both content streams
        # --------------------------------------------------------------------
        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                context[mask_indices] = 0
                time_aligned_content[mask_indices] = 0

        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            x_mask=latent_mask,
            timesteps=timesteps,
            context=context,
            context_mask=context_mask,
            time_aligned_context=time_aligned_content,
        )
        # time axis to dim 1 so the (B, T) mask applies
        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        diff_loss = F.mse_loss(pred, target, reduction="none")
        diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
        return {
            "diff_loss": diff_loss,
            "local_duration_loss": local_duration_loss,
            "global_duration_loss": global_duration_loss,
        }

    def inference(
        self,
        content: list[Any],
        task: list[str],
        is_time_aligned: Sequence[bool],
        instruction: torch.Tensor,
        instruction_lengths: Sequence[int],
        num_steps: int = 20,
        sway_sampling_coef: float | None = -1.0,
        guidance_scale: float = 3.0,
        disable_progress: bool = True,
        use_gt_duration: bool = False,
        **kwargs
    ):
        """Generate waveforms; output length is derived from predicted
        (or, optionally, ground-truth) durations.
        """
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0

        (
            content, content_mask, global_duration_pred, local_duration_pred,
            length_aligned_content
        ) = self.encode_content_with_instruction(
            content, task, device, instruction, instruction_lengths
        )
        batch_size = content.size(0)

        # truncate dummy time aligned duration prediction
        is_time_aligned = torch.as_tensor(is_time_aligned)
        if is_time_aligned.sum() > 0:
            trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
        else:
            trunc_ta_length = content.size(1)

        # prepare local duration (log domain -> seconds)
        local_duration = self.prepare_local_duration(
            local_duration_pred, content_mask
        )
        local_duration = local_duration[:, :trunc_ta_length]
        # use ground truth duration
        if use_gt_duration and "duration" in kwargs:
            local_duration = torch.as_tensor(kwargs["duration"]).to(device)

        # prepare global duration (time-aligned rows use summed local
        # durations, others the global head)
        global_duration = self.prepare_global_duration(
            global_duration_pred, local_duration, is_time_aligned
        )

        # --------------------------------------------------------------------
        # duration adapter: expand token content to frame rate
        # --------------------------------------------------------------------
        time_aligned_content, latent_mask = self.expand_by_duration(
            x=content[:, :trunc_ta_length],
            content_mask=content_mask[:, :trunc_ta_length],
            local_duration=local_duration,
            global_duration=global_duration,
        )

        context, context_mask, time_aligned_content = self.get_backbone_input(
            target_length=time_aligned_content.size(1),
            content=content,
            content_mask=content_mask,
            time_aligned_content=time_aligned_content,
            length_aligned_content=length_aligned_content,
            is_time_aligned=is_time_aligned
        )

        # --------------------------------------------------------------------
        # prepare unconditional input (uncond half first, matching chunk(2))
        # --------------------------------------------------------------------
        if classifier_free_guidance:
            uncond_time_aligned_content = torch.zeros_like(
                time_aligned_content
            )
            uncond_context = torch.zeros_like(context)
            uncond_context_mask = context_mask.detach().clone()
            time_aligned_content = torch.cat([
                uncond_time_aligned_content, time_aligned_content
            ])
            context = torch.cat([uncond_context, context])
            context_mask = torch.cat([uncond_context_mask, context_mask])
            latent_mask = torch.cat([
                latent_mask, latent_mask.detach().clone()
            ])

        # --------------------------------------------------------------------
        # prepare input to the backbone
        # --------------------------------------------------------------------
        latent_length = latent_mask.sum(1).max().item()
        # None entries in latent_shape mark the variable time axis
        latent_shape = tuple(
            latent_length if dim is None else dim
            for dim in self.autoencoder.latent_shape
        )
        shape = (batch_size, *latent_shape)
        latent = randn_tensor(
            shape, generator=None, device=device, dtype=content.dtype
        )

        # linear sigma schedule when sway sampling disabled, warped otherwise
        if not sway_sampling_coef:
            sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
        else:
            t = torch.linspace(0, 1, num_steps + 1)
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
            sigmas = 1 - t
        timesteps, num_steps = self.retrieve_timesteps(
            num_steps, device, timesteps=None, sigmas=sigmas
        )
        latent = self.iterative_denoise(
            latent=latent,
            timesteps=timesteps,
            num_steps=num_steps,
            verbose=not disable_progress,
            cfg=classifier_free_guidance,
            cfg_scale=guidance_scale,
            backbone_input={
                "x_mask": latent_mask,
                "context": context,
                "context_mask": context_mask,
                "time_aligned_context": time_aligned_content,
            }
        )

        waveform = self.autoencoder.decode(latent)
        return waveform
1031
+
1032
+
1033
class DoubleContentAudioFlowMatching(DummyContentAudioFlowMatching):
    """Variant that keeps the full content as cross-attention context for
    every sample and, for non-time-aligned samples, also copies the content
    into the time-aligned stream (instead of a learned dummy embedding).
    """
    def get_backbone_input(
        self, target_length: int, content: torch.Tensor,
        content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
        length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
    ):
        # TODO compatility for 2D spectrogram VAE
        time_aligned_content = trim_or_pad_length(
            time_aligned_content, target_length, 1
        )
        # for non-time-aligned samples, fill the time-aligned stream with
        # the raw content up to the shorter of the two sequence lengths
        # (in-place write into the trimmed/padded tensor)
        context_length = min(content.size(1), time_aligned_content.size(1))
        time_aligned_content[~is_time_aligned, :context_length] = content[
            ~is_time_aligned, :context_length]
        length_aligned_content = trim_or_pad_length(
            length_aligned_content, target_length, 1
        )
        # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme)
        # length_aligned_content: from aligned input (f0/energy)
        time_aligned_content = time_aligned_content + length_aligned_content

        # NOTE(review): `context` aliases `content` (no copy); unlike the
        # parent class there is no dummy replacement or truncation here
        context = content
        context_mask = content_mask.detach().clone()

        return context, context_mask, time_aligned_content
1057
+
1058
+
1059
class HybridContentAudioFlowMatching(DummyContentAudioFlowMatching):
    """Variant that keeps the full content as cross-attention context for
    every sample, while non-time-aligned samples still receive the learned
    dummy embedding in the time-aligned stream (as in the parent class).
    """
    def get_backbone_input(
        self, target_length: int, content: torch.Tensor,
        content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
        length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
    ):
        # TODO compatility for 2D spectrogram VAE
        time_aligned_content = trim_or_pad_length(
            time_aligned_content, target_length, 1
        )
        length_aligned_content = trim_or_pad_length(
            length_aligned_content, target_length, 1
        )
        # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme)
        # length_aligned_content: from aligned input (f0/energy)
        time_aligned_content = time_aligned_content + length_aligned_content
        # non-time-aligned samples carry no frame-level content: substitute
        # the learned dummy embedding (broadcast over frames)
        time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
            time_aligned_content.dtype
        )

        # NOTE(review): `context` aliases `content`; no dummy replacement or
        # truncation of the cross-attention context in this variant
        context = content
        context_mask = content_mask.detach().clone()

        return context, context_mask, time_aligned_content
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.26.0
2
+ # --- Core Framework (Pinned Versions) ---
3
+ torch==2.5.1
4
+ torchvision==0.20.1
5
+ torchaudio==2.5.1
6
+
7
+ # --- Deep Learning & Utilities ---
8
+ diffusers
9
+ transformers
10
+ accelerate
11
+ einops
12
+ alias_free_torch
13
+ tqdm
14
+ torchdata
15
+
16
+ # --- Config & Data ---
17
+ hydra-core
18
+ omegaconf
19
+ h5py
20
+
21
+ # --- Audio ---
22
+ librosa
23
+ soundfile
24
+
25
+ # --- Logging ---
26
+ wandb
27
+ tensorboard
28
+ swanlab
stabilityai/stable-diffusion-2-1/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.8.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "num_train_timesteps": 1000,
9
+ "prediction_type": "v_prediction",
10
+ "set_alpha_to_one": false,
11
+ "skip_prk_steps": true,
12
+ "steps_offset": 1,
13
+ "trained_betas": null
14
+ }
utils/__pycache__/config.cpython-310.pyc ADDED
Binary file (1.7 kB). View file
 
utils/__pycache__/torch_utilities.cpython-310.pyc ADDED
Binary file (8.33 kB). View file
 
utils/accelerate_utilities.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from accelerate import Accelerator
2
+
3
+
4
class AcceleratorSaveTrainableParams(Accelerator):
    """Accelerator whose checkpoints contain only the parameters listed in
    the model's ``param_names_to_save`` attribute (when present); models
    without that attribute are saved in full.
    """
    def get_state_dict(self, model, unwrap=True):
        full_state = super().get_state_dict(model, unwrap)
        if not hasattr(model, "param_names_to_save"):
            return full_state
        keep = model.param_names_to_save
        return {name: tensor for name, tensor in full_state.items() if name in keep}
utils/audio.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchaudio
4
+
5
+
6
class PadCrop(nn.Module):
    """Return a fixed-length (channels, n_samples) view of a signal:
    shorter inputs are zero-padded at the end, longer ones are cropped
    (from a random start when ``randomize`` is True, else from 0).
    """
    def __init__(self, n_samples, randomize=True):
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    def __call__(self, signal):
        n_channels, length = signal.shape
        if self.randomize:
            # random crop start, clamped so the window stays in bounds
            max_start = max(0, length - self.n_samples)
            start = torch.randint(0, max_start + 1, []).item()
        else:
            start = 0
        segment = signal[:, start:start + self.n_samples]
        output = signal.new_zeros([n_channels, self.n_samples])
        # copy whatever fits; the remainder stays zero (padding)
        output[:, :min(length, self.n_samples)] = segment
        return output
+
23
+
24
+ def set_audio_channels(audio, target_channels):
25
+ if target_channels == 1:
26
+ # Convert to mono
27
+ audio = audio.mean(1, keepdim=True)
28
+ elif target_channels == 2:
29
+ # Convert to stereo
30
+ if audio.shape[1] == 1:
31
+ audio = audio.repeat(1, 2, 1)
32
+ elif audio.shape[1] > 2:
33
+ audio = audio[:, :2, :]
34
+ return audio
35
+
36
+
37
+ def prepare_audio(
38
+ audio, in_sr, target_sr, target_length, target_channels, device
39
+ ):
40
+
41
+ audio = audio.to(device)
42
+
43
+ if in_sr != target_sr:
44
+ resample_tf = torchaudio.transforms.Resample(in_sr,
45
+ target_sr).to(device)
46
+ audio = resample_tf(audio)
47
+
48
+ audio = PadCrop(target_length, randomize=False)(audio)
49
+
50
+ # Add batch dimension
51
+ if audio.dim() == 1:
52
+ audio = audio.unsqueeze(0).unsqueeze(0)
53
+ elif audio.dim() == 2:
54
+ audio = audio.unsqueeze(0)
55
+
56
+ audio = set_audio_channels(audio, target_channels)
57
+
58
+ return audio
utils/config.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import sys
3
+ import os
4
+
5
+ import hydra
6
+ import omegaconf
7
+ from omegaconf import OmegaConf
8
+
9
+
10
def multiply(*args):
    """Return the product of all arguments (1 when called with none)."""
    product = 1
    for factor in args:
        product *= factor
    return product
15
+
16
+
17
def get_pitch_downsample_ratio(
    autoencoder_config: dict, pitch_frame_resolution: float
):
    """Return how many pitch frames correspond to one latent frame.

    The latent frame resolution (seconds per latent frame) is the
    autoencoder's downsampling ratio divided by its sample rate.
    """
    seconds_per_latent_frame = (
        autoencoder_config["downsampling_ratio"] /
        autoencoder_config["sample_rate"]
    )
    return round(seconds_per_latent_frame / pitch_frame_resolution)
23
+
24
+
25
def register_omegaconf_resolvers() -> None:
    """
    Register custom resolver for hydra configs, which can be used in YAML
    files for dynamically setting values
    """
    OmegaConf.clear_resolvers()
    for resolver_name, fn in (
        ("len", len),
        ("multiply", multiply),
        ("get_pitch_downsample_ratio", get_pitch_downsample_ratio),
    ):
        OmegaConf.register_new_resolver(resolver_name, fn, replace=True)
36
+
37
+
38
def generate_config_from_command_line_overrides(
    config_file: str | Path
) -> omegaconf.DictConfig:
    """Compose a hydra config from ``config_file``, applying any override
    strings found in ``sys.argv[1:]``, and return it fully resolved.
    """
    register_omegaconf_resolvers()

    resolved_file = Path(config_file).resolve()
    config_name = str(resolved_file.name)
    # hydra.initialize wants the config dir relative to this module
    config_dir = os.path.relpath(
        str(resolved_file.parent), Path(__file__).resolve().parent
    )

    overrides = sys.argv[1:]
    with hydra.initialize(version_base=None, config_path=config_dir):
        config = hydra.compose(config_name=config_name, overrides=overrides)
        omegaconf.OmegaConf.resolve(config)

    return config
utils/diffsinger_utilities.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import six
2
+ from pathlib import Path
3
+ import re
4
+ import json
5
+ from collections import OrderedDict
6
+ from typing import Union
7
+
8
+ import numpy as np
9
+ import librosa
10
+ import torch
11
+
12
# Special vocabulary tokens shared by the text encoders below.
PAD = "<pad>"
EOS = "<EOS>"
UNK = "<UNK>"
SEG = "|"  # segment separator; encoders fall back to EOS when absent
RESERVED_TOKENS = [PAD, EOS, UNK]
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
PAD_ID = RESERVED_TOKENS.index(PAD)  # Normally 0
EOS_ID = RESERVED_TOKENS.index(EOS)  # Normally 1
UNK_ID = RESERVED_TOKENS.index(UNK)  # Normally 2

# F0 (fundamental frequency) quantization settings: f0 values in Hz are
# converted to the mel scale and bucketed into F0_BIN coarse bins
# (see `f0_to_coarse`).
F0_BIN = 256
F0_MAX = 1100.0  # Hz
F0_MIN = 50.0  # Hz
F0_MEL_MIN = 1127 * np.log(1 + F0_MIN / 700)  # mel value of F0_MIN
F0_MEL_MAX = 1127 * np.log(1 + F0_MAX / 700)  # mel value of F0_MAX
28
+
29
def f0_to_coarse(f0):
    """
    Quantize f0 (Hz) into coarse integer bins in [1, F0_BIN - 1].

    f0 is mapped to the mel scale, positive mel values are linearly
    rescaled into (1, F0_BIN - 1], and the result is rounded. Unvoiced
    frames (f0 == 0) end up in bin 1. Accepts either a `torch.Tensor` or
    an `np.ndarray` and returns the matching kind.
    """
    is_torch = isinstance(f0, torch.Tensor)
    # mel conversion: 1127 * ln(1 + f / 700)
    f0_mel = 1127 * (1 + f0 / 700).log() if is_torch \
        else 1127 * np.log(1 + f0 / 700)
    # Rescale voiced frames so that [F0_MEL_MIN, F0_MEL_MAX] maps onto
    # roughly [1, F0_BIN - 1]; unvoiced frames (mel == 0) are untouched.
    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - F0_MEL_MIN) \
        * (F0_BIN - 2) / (F0_MEL_MAX - F0_MEL_MIN) + 1

    # Clamp into the valid bin range before rounding.
    f0_mel[f0_mel <= 1] = 1
    f0_mel[f0_mel > F0_BIN - 1] = F0_BIN - 1
    # NOTE(review): the torch branch rounds half-up (+0.5 then truncate)
    # while the numpy branch uses banker's rounding (np.rint) — confirm if
    # exact parity between the two backends matters.
    f0_coarse = (f0_mel + 0.5).long() if is_torch \
        else np.rint(f0_mel).astype(int)
    assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (
        f0_coarse.max(), f0_coarse.min()
    )
    return f0_coarse
45
+
46
+
47
def norm_f0(
    f0: Union[np.ndarray, torch.Tensor],
    uv: Union[None, np.ndarray],
    f0_mean: float,
    f0_std: float,
    pitch_norm: str = "log",
    use_uv: bool = True
):
    """
    Normalize f0 with either z-score ('standard') or base-2 log ('log')
    normalization, then zero out unvoiced frames when a `uv` mask is
    provided and `use_uv` is set.
    """
    if pitch_norm == 'standard':
        f0 = (f0 - f0_mean) / f0_std
    if pitch_norm == 'log':
        log2_fn = torch.log2 if isinstance(f0, torch.Tensor) else np.log2
        f0 = log2_fn(f0)
    if use_uv and uv is not None:
        f0[uv > 0] = 0
    return f0
63
+
64
+
65
def norm_interp_f0(
    f0: Union[np.ndarray, torch.Tensor],
    f0_mean: float,
    f0_std: float,
    pitch_norm: str = "log",
    use_uv: bool = True
):
    """
    Normalize f0 and fill unvoiced gaps by linear interpolation.

    Returns:
        f0: normalized f0 with unvoiced frames interpolated from the
            surrounding voiced values (all zeros when fully unvoiced);
            returned on the input's device when given a tensor.
        uv: float array/tensor marking unvoiced frames (1.0 = unvoiced).
    """
    is_torch = isinstance(f0, torch.Tensor)
    if is_torch:
        device = f0.device
        # interpolation below is done in numpy
        f0 = f0.data.cpu().numpy()
    uv = f0 == 0  # unvoiced mask: frames with exactly 0 Hz
    f0 = norm_f0(f0, uv, f0_mean, f0_std, pitch_norm, use_uv)
    if sum(uv) == len(f0):
        # fully unvoiced: nothing to interpolate from
        f0[uv] = 0
    elif sum(uv) > 0:
        # fill unvoiced positions from the neighboring voiced values
        f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
    uv = torch.as_tensor(uv).float()
    f0 = torch.as_tensor(f0).float()
    if is_torch:
        f0 = f0.to(device)
    return f0, uv
87
+
88
+
89
def denorm_f0(
    f0,
    uv,
    pitch_norm="log",
    f0_mean=None,
    f0_std=None,
    pitch_padding=None,
    min=None,
    max=None,
    use_uv=True
):
    """
    Invert `norm_f0`: map normalized f0 back to Hz.

    Args:
        f0: normalized f0 tensor.
        uv: optional unvoiced mask; frames with uv > 0 are zeroed when
            `use_uv` is set.
        pitch_norm: 'standard' (z-score) or 'log' (base-2 log), matching
            the scheme used during normalization.
        f0_mean, f0_std: statistics used by the 'standard' scheme.
        pitch_padding: optional boolean mask of padded frames to zero out.
        min, max: optional clamp bounds in Hz (parameter names shadow the
            builtins but are kept for caller compatibility).
    """
    if pitch_norm == 'standard':
        f0 = f0 * f0_std + f0_mean
    if pitch_norm == 'log':
        f0 = 2**f0
    if min is not None:
        f0 = f0.clamp(min=min)
    if max is not None:
        f0 = f0.clamp(max=max)
    if uv is not None and use_uv:
        f0[uv > 0] = 0
    if pitch_padding is not None:
        f0[pitch_padding] = 0
    return f0
113
+
114
+
115
def librosa_pad_lr(x, fshift, pad_sides=1):
    """
    Compute padding so that len(x) becomes the next multiple of `fshift`
    (always at least one extra hop).

    Returns (left, right) padding: right-only when `pad_sides == 1`,
    split across both sides (extra sample on the right) when 2.
    """
    assert pad_sides in (1, 2)
    total = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
    if pad_sides == 1:
        return 0, total
    left = total // 2
    return left, total - left
125
+
126
+
127
def get_pitch(
    wav_file: Union[str, Path], sample_rate: int, frame_shift: float
):
    """
    Extract frame-level f0 from an audio file with Praat (parselmouth).

    Args:
        wav_file: path to the audio file.
        sample_rate: sample rate to load the audio at.
        frame_shift: hop between analysis frames, in seconds.

    Returns:
        f0: np.ndarray of f0 values in Hz (0 for unvoiced frames), padded
            to `len(wav) // hop_size` frames.
        pitch_coarse: the coarse quantized bins from `f0_to_coarse`.
    """
    # local import: parselmouth is only needed for this helper
    import parselmouth

    hop_size = int(frame_shift * sample_rate)
    wav, _ = librosa.core.load(wav_file, sr=sample_rate)

    expected_frames = wav.shape[0] // hop_size

    f0 = parselmouth.Sound(wav, sample_rate).to_pitch_ac(
        time_step=frame_shift,
        voicing_threshold=0.6,
        pitch_floor=80,
        pitch_ceiling=750,
    ).selected_array['frequency']
    # Praat may emit slightly fewer frames than expected; pad by repeating
    # the last value so downstream frame counts line up.
    shortfall = expected_frames - len(f0)
    if shortfall > 0:
        f0 = np.concatenate([f0, [f0[-1]] * shortfall], 0)
    pitch_coarse = f0_to_coarse(f0)
    return f0, pitch_coarse
152
+
153
+
154
def remove_empty_lines(text):
    """
    Strip every line and drop all empty lines.

    Args:
        text: non-empty list of strings.

    Returns:
        list[str]: stripped lines with every empty entry removed.
    """
    assert (len(text) > 0)
    assert (isinstance(text, list))
    # Fix: the previous implementation used `list.remove("")`, which drops
    # only the FIRST empty line; filter them all as the docstring promises.
    return [stripped for stripped in (t.strip() for t in text) if stripped]
162
+
163
+
164
def is_sil_phoneme(p):
    """A phoneme counts as silence when its first character is not alphabetic."""
    first_char = p[0]
    return not first_char.isalpha()
166
+
167
+
168
def strip_ids(ids, ids_to_strip):
    """Return a copy of `ids` with trailing elements found in `ids_to_strip` removed."""
    trimmed = list(ids)
    while trimmed and trimmed[-1] in ids_to_strip:
        trimmed.pop()
    return trimmed
174
+
175
+
176
class TextEncoder(object):
    """Base class for converting from ints to/from human readable strings."""
    def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
        self._num_reserved_ids = num_reserved_ids

    @property
    def num_reserved_ids(self):
        """Number of ids reserved for special tokens (PAD/EOS/UNK)."""
        return self._num_reserved_ids

    def encode(self, s):
        """Transform a human-readable string into a sequence of int ids.

        The ids should be in the range [num_reserved_ids, vocab_size); ids
        in [0, num_reserved_ids) are reserved. EOS is not appended.

        Args:
            s: human-readable string to be converted.

        Returns:
            ids: list of integers
        """
        return [int(word) + self._num_reserved_ids for word in s.split()]

    def decode(self, ids, strip_extraneous=False):
        """Transform a sequence of int ids into a human-readable string.

        EOS is not expected in ids.

        Args:
            ids: list of integers to be converted.
            strip_extraneous: bool, whether to strip off extraneous tokens
                (EOS and PAD).

        Returns:
            s: human-readable string.
        """
        if strip_extraneous:
            ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
        return " ".join(self.decode_list(ids))

    def decode_list(self, ids):
        """Transform a sequence of int ids into their string versions.

        Reserved ids map to their token text; all other ids are shifted
        back down by `num_reserved_ids` and stringified.

        Args:
            ids: list of integers to be converted.

        Returns:
            strs: list of human-readable strings.
        """
        pieces = []
        for raw_id in ids:
            if 0 <= raw_id < self._num_reserved_ids:
                pieces.append(RESERVED_TOKENS[int(raw_id)])
            else:
                pieces.append(raw_id - self._num_reserved_ids)
        return [str(piece) for piece in pieces]

    @property
    def vocab_size(self):
        raise NotImplementedError()
242
+
243
+
244
class TokenTextEncoder(TextEncoder):
    """Encoder based on a user-supplied vocabulary (file or list)."""
    def __init__(
        self,
        vocab_filename,
        reverse=False,
        vocab_list=None,
        replace_oov=None,
        num_reserved_ids=NUM_RESERVED_TOKENS
    ):
        """Initialize from a file or list, one token per line.

        Handling of reserved tokens works as follows:
        - When initializing from a list, we add reserved tokens to the vocab.
        - When initializing from a file, we do not add reserved tokens to the vocab.
        - When saving vocab files, we save reserved tokens to the file.

        Args:
          vocab_filename: If not None, the full filename to read vocab from. If this
             is not None, then vocab_list should be None.
          reverse: Boolean indicating if tokens should be reversed during encoding
             and decoding.
          vocab_list: If not None, a list of elements of the vocabulary. If this is
             not None, then vocab_filename should be None.
          replace_oov: If not None, every out-of-vocabulary token seen when
             encoding will be replaced by this string (which must be in vocab).
          num_reserved_ids: Number of IDs to save for reserved tokens like <EOS>.
        """
        super(TokenTextEncoder,
              self).__init__(num_reserved_ids=num_reserved_ids)
        self._reverse = reverse
        self._replace_oov = replace_oov
        if vocab_filename:
            self._init_vocab_from_file(vocab_filename)
        else:
            assert vocab_list is not None
            self._init_vocab_from_list(vocab_list)
        # Cache special-token ids for fast access.
        self.pad_index = self._token_to_id[PAD]
        self.eos_index = self._token_to_id[EOS]
        self.unk_index = self._token_to_id[UNK]
        # The segment separator falls back to EOS when "|" is not in vocab.
        self.seg_index = self._token_to_id[
            SEG] if SEG in self._token_to_id else self.eos_index

    def encode(self, s):
        """Converts a space-separated string of tokens to a list of ids."""
        sentence = s
        tokens = sentence.strip().split()
        if self._replace_oov is not None:
            # Substitute the designated replacement for any OOV token.
            tokens = [
                t if t in self._token_to_id else self._replace_oov
                for t in tokens
            ]
        ret = [self._token_to_id[tok] for tok in tokens]
        return ret[::-1] if self._reverse else ret

    def decode(self, ids, strip_eos=False, strip_padding=False):
        """Convert ids back to a space-separated token string, optionally
        truncating at the first PAD and/or EOS id."""
        if strip_padding and self.pad() in list(ids):
            pad_pos = list(ids).index(self.pad())
            ids = ids[:pad_pos]
        if strip_eos and self.eos() in list(ids):
            eos_pos = list(ids).index(self.eos())
            ids = ids[:eos_pos]
        return " ".join(self.decode_list(ids))

    def decode_list(self, ids):
        """Map each id to its token string (restoring order if reversed)."""
        seq = reversed(ids) if self._reverse else ids
        return [self._safe_id_to_token(i) for i in seq]

    @property
    def vocab_size(self):
        return len(self._id_to_token)

    def __len__(self):
        return self.vocab_size

    def _safe_id_to_token(self, idx):
        # Unknown ids decode to "ID_<n>" instead of raising.
        return self._id_to_token.get(idx, "ID_%d" % idx)

    def _init_vocab_from_file(self, filename):
        """Load vocab from a file.

        Args:
          filename: The file to load vocabulary from.
        """
        with open(filename) as f:
            tokens = [token.strip() for token in f.readlines()]

        def token_gen():
            for token in tokens:
                yield token

        # File-based vocabs are expected to already contain reserved tokens.
        self._init_vocab(token_gen(), add_reserved_tokens=False)

    def _init_vocab_from_list(self, vocab_list):
        """Initialize tokens from a list of tokens.

        It is ok if reserved tokens appear in the vocab list. They will be
        removed. The set of tokens in vocab_list should be unique.

        Args:
          vocab_list: A list of tokens.
        """
        def token_gen():
            for token in vocab_list:
                if token not in RESERVED_TOKENS:
                    yield token

        self._init_vocab(token_gen())

    def _init_vocab(self, token_generator, add_reserved_tokens=True):
        """Initialize vocabulary with tokens from token_generator."""

        self._id_to_token = {}
        non_reserved_start_index = 0

        if add_reserved_tokens:
            # Reserved tokens occupy the lowest ids.
            self._id_to_token.update(enumerate(RESERVED_TOKENS))
            non_reserved_start_index = len(RESERVED_TOKENS)

        self._id_to_token.update(
            enumerate(token_generator, start=non_reserved_start_index)
        )

        # _token_to_id is the reverse of _id_to_token
        self._token_to_id = dict(
            (v, k) for k, v in six.iteritems(self._id_to_token)
        )

    def pad(self):
        """Id of the PAD token."""
        return self.pad_index

    def eos(self):
        """Id of the EOS token."""
        return self.eos_index

    def unk(self):
        """Id of the UNK token."""
        return self.unk_index

    def seg(self):
        """Id of the segment separator (EOS when '|' is not in vocab)."""
        return self.seg_index

    def store_to_file(self, filename):
        """Write vocab file to disk.

        Vocab files have one token per line. The file ends in a newline. Reserved
        tokens are written to the vocab file as well.

        Args:
          filename: Full path of the file to store the vocab to.
        """
        with open(filename, "w") as f:
            for i in range(len(self._id_to_token)):
                f.write(self._id_to_token[i] + "\n")

    def sil_phonemes(self):
        """All vocab tokens whose first character is non-alphabetic (silences)."""
        return [p for p in self._id_to_token.values() if not p[0].isalpha()]
399
+
400
+
401
class TextGrid(object):
    """
    Minimal parser for Praat TextGrid files (long text format).

    `text` is the file content as a list of lines; parsing is driven by
    `self.line_count`, a cursor advanced by `_extract_pattern`. Only
    IntervalTier tiers are supported; parsed tiers land in `self.tier_list`.
    """
    def __init__(self, text):
        text = remove_empty_lines(text)
        self.text = text
        self.line_count = 0  # cursor into `self.text`
        self._get_type()
        self._get_time_intval()
        self._get_size()
        self.tier_list = []
        self._get_item_list()

    def _extract_pattern(self, pattern, inc):
        """
        Parameters
        ----------
        pattern : regex to extract pattern
        inc : increment of line count after extraction
        Returns
        -------
        group : extracted info
        """
        try:
            group = re.match(pattern, self.text[self.line_count]).group(1)
            self.line_count += inc
        except AttributeError:
            # re.match returned None: the current line does not match.
            raise ValueError(
                "File format error at line %d:%s" %
                (self.line_count, self.text[self.line_count])
            )
        return group

    def _get_type(self):
        self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2)

    def _get_time_intval(self):
        # Global start/end times of the TextGrid (kept as strings).
        self.xmin = self._extract_pattern(r"xmin = (.*)", 1)
        self.xmax = self._extract_pattern(r"xmax = (.*)", 2)

    def _get_size(self):
        # Number of tiers in the file.
        self.size = int(self._extract_pattern(r"size = (.*)", 2))

    def _get_item_list(self):
        """Only supports IntervalTier currently"""
        for itemIdx in range(1, self.size + 1):
            tier = OrderedDict()
            item_list = []
            tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1)
            tier_class = self._extract_pattern(r"class = \"(.*)\"", 1)
            if tier_class != "IntervalTier":
                raise NotImplementedError(
                    "Only IntervalTier class is supported currently"
                )
            tier_name = self._extract_pattern(r"name = \"(.*)\"", 1)
            tier_xmin = self._extract_pattern(r"xmin = (.*)", 1)
            tier_xmax = self._extract_pattern(r"xmax = (.*)", 1)
            tier_size = self._extract_pattern(r"intervals: size = (.*)", 1)
            # Each interval carries its index, time bounds, and label.
            for i in range(int(tier_size)):
                item = OrderedDict()
                item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1)
                item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1)
                item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1)
                item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1)
                item_list.append(item)
            tier["idx"] = tier_idx
            tier["class"] = tier_class
            tier["name"] = tier_name
            tier["xmin"] = tier_xmin
            tier["xmax"] = tier_xmax
            tier["size"] = tier_size
            tier["items"] = item_list
            self.tier_list.append(tier)

    def toJson(self):
        """Serialize the parsed TextGrid to a JSON string."""
        _json = OrderedDict()
        _json["file_type"] = self.file_type
        _json["xmin"] = self.xmin
        _json["xmax"] = self.xmax
        _json["size"] = self.size
        _json["tiers"] = self.tier_list
        return json.dumps(_json, ensure_ascii=False, indent=2)
481
+
482
+
483
def read_duration_from_textgrid(
    textgrid_path: Union[str, Path],
    phoneme: str,
    utterance_duration: float,
):
    """
    Align a phoneme sequence with a Praat TextGrid and return per-phoneme
    durations (seconds).

    Args:
        textgrid_path: path to the TextGrid file; the LAST tier is used.
        phoneme: space-separated phoneme sequence to align.
        utterance_duration: total utterance length, used as the final
            boundary.

    Returns:
        np.ndarray of shape (len(phonemes),): each phoneme's duration.
    """
    ph_list = phoneme.split(" ")
    with open(textgrid_path, "r") as f:
        textgrid = f.readlines()
    textgrid = remove_empty_lines(textgrid)
    textgrid = TextGrid(textgrid)
    textgrid = json.loads(textgrid.toJson())

    # split[i] holds the start time of phoneme i; -1 marks "not yet set".
    split = np.ones(len(ph_list) + 1, np.float32) * -1
    tg_idx = 0
    ph_idx = 0
    tg_align = [x for x in textgrid['tiers'][-1]['items']]
    tg_align_ = []
    # Normalize intervals: map all silence labels to '' and merge
    # consecutive silence intervals into one.
    for x in tg_align:
        x['xmin'] = float(x['xmin'])
        x['xmax'] = float(x['xmax'])
        if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC', '<SP>', '<AP>']:
            x['text'] = ''
            if len(tg_align_) > 0 and tg_align_[-1]['text'] == '':
                tg_align_[-1]['xmax'] = x['xmax']
                continue
        tg_align_.append(x)
    tg_align = tg_align_
    # Non-silence interval count must match non-silence phoneme count.
    tg_len = len([x for x in tg_align if x['text'] != ''])
    ph_len = len([x for x in ph_list if not is_sil_phoneme(x)])
    assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, textgrid_path)
    # Walk both sequences in lockstep, recording each phoneme's start time.
    while tg_idx < len(tg_align) or ph_idx < len(ph_list):
        if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]):
            # Trailing silence with no interval: sentinel, clipped by the
            # final boundary assignment below.
            split[ph_idx] = 1e8
            ph_idx += 1
            continue
        x = tg_align[tg_idx]
        if x['text'] == '' and ph_idx == len(ph_list):
            # Leftover silence interval after all phonemes consumed.
            tg_idx += 1
            continue
        assert ph_idx < len(ph_list), (
            tg_len, ph_len, tg_align, ph_list, textgrid_path
        )

        ph = ph_list[ph_idx]
        if x['text'] == '' and not is_sil_phoneme(ph):
            # A silence interval must not align to a voiced phoneme.
            assert False, (ph_list, tg_align)
        if x['text'] != '' and is_sil_phoneme(ph):
            # Skip a silence phoneme with no matching interval.
            ph_idx += 1
        else:
            assert (x['text'] == '' and is_sil_phoneme(ph)) \
                or x['text'].lower() == ph.lower() \
                or x['text'].lower() == 'sil', (x['text'], ph)
            split[ph_idx] = x['xmin']
            # Backfill a preceding silence phoneme that got no interval.
            if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(
                ph_list[ph_idx - 1]
            ):
                split[ph_idx - 1] = split[ph_idx]
            ph_idx += 1
            tg_idx += 1
    assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align])
    assert ph_idx >= len(ph_list) - 1, (
        ph_idx, ph_list, len(ph_list), [x['text']
                                        for x in tg_align], textgrid_path
    )

    # Pin the outer boundaries, then take consecutive differences.
    split[0] = 0
    split[-1] = utterance_duration
    duration = np.diff(split)
    return duration
utils/general.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from typing import Union, Dict
4
+ from pathlib import Path
5
+ import os
6
+
7
+ MAX_FILE_NAME_LENGTH = 100
8
+
9
+
10
+ def read_jsonl_to_mapping(
11
+ jsonl_file: Union[str, Path],
12
+ key_col: str,
13
+ value_col: str,
14
+ base_path=None
15
+ ) -> Dict[str, str]:
16
+ """
17
+ Read two columns, indicated by `key_col` and `value_col`, from the
18
+ given jsonl file to return the mapping dict
19
+ TODO handle duplicate keys
20
+ """
21
+ mapping = {}
22
+ with open(jsonl_file, 'r') as file:
23
+ for line in file.readlines():
24
+ data = json.loads(line.strip())
25
+ key = data[key_col]
26
+ value = data[value_col]
27
+ if base_path:
28
+ value = os.path.join(base_path, value)
29
+ mapping[key] = value
30
+ return mapping
31
+
32
+
33
+ def sanitize_filename(name: str, max_len: int = MAX_FILE_NAME_LENGTH) -> str:
34
+ """
35
+ Clean and truncate a string to make it a valid and safe filename.
36
+ """
37
+ name = re.sub(r'[\\/*?:"<>|]', '_', name)
38
+ name = name.replace('/', '_')
39
+ max_len = min(len(name), max_len)
40
+ return name[:max_len]
41
+
42
+
43
def transform_gen_fn_to_id(audio_file: Path, task: str) -> str:
    """
    Recover the original sample id from a generated file's name, using the
    naming convention of the given task.
    """
    stem = audio_file.stem
    if task == "svs":
        # the id precedes the first underscore
        return stem.split("_")[0]
    if task == "sr":
        return stem
    if task in ("tta", "ttm"):
        # ids are fixed-length 11-character prefixes
        return stem[:11]
    if task == "v2a":
        # drop the trailing "_<suffix>" and restore the video extension
        return stem.rsplit("_", 1)[0] + ".mp4"
    return stem
59
+
60
+
61
def audio_dir_to_mapping(audio_dir: str | Path, task: str) -> dict:
    """
    Map recovered sample ids (via `transform_gen_fn_to_id`) to absolute
    file paths for every entry in `audio_dir`, in sorted filename order.
    """
    return {
        transform_gen_fn_to_id(audio_file, task): str(audio_file.resolve())
        for audio_file in sorted(Path(audio_dir).iterdir())
    }
utils/logging.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from dataclasses import dataclass
3
+ import logging
4
+
5
+
6
+ @dataclass
7
+ class LoggingLogger:
8
+
9
+ filename: str | Path
10
+ level: str = "INFO"
11
+
12
+ def create_instance(self, ):
13
+ filename = self.filename.__str__()
14
+ formatter = logging.Formatter("[%(asctime)s] - %(message)s")
15
+
16
+ logger = logging.getLogger(__name__ + "." + filename)
17
+ logger.setLevel(getattr(logging, self.level))
18
+
19
+ file_handler = logging.FileHandler(filename)
20
+ file_handler.setFormatter(formatter)
21
+ logger.addHandler(file_handler)
22
+
23
+ return logger
utils/lr_scheduler_utilities.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import math
3
+ import copy
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
+ def get_warmup_steps(
8
+ dataloader_one_pass_outside_steps: int,
9
+ warmup_steps: int | None = None,
10
+ warmup_epochs: float | None = None,
11
+ epoch_length: int | None = None,
12
+ ) -> int:
13
+ """
14
+ Derive warmup steps according to step number or epoch number.
15
+ If `warmup_steps` is provided, then just return it. Otherwise, derive
16
+ the warmup steps by epoch length and warmup epoch number.
17
+ """
18
+ if warmup_steps is not None:
19
+ return warmup_steps
20
+ else:
21
+ if epoch_length is None:
22
+ epoch_length = dataloader_one_pass_outside_steps
23
+ assert warmup_epochs is not None, "warmup_steps and warmup_epochs cannot be both None"
24
+ return int(epoch_length * warmup_epochs)
25
+
26
+
27
def get_dataloader_one_pass_outside_steps(
    train_dataloader: DataLoader,
    num_processes: int = 1,
):
    """
    Dataloader length after DDP sharding: close to
    `original_length / gpu_number`, rounded up.
    """
    steps_per_process = math.ceil(len(train_dataloader) / num_processes)
    return steps_per_process
35
+
36
+
37
+ def get_total_training_steps(
38
+ train_dataloader: DataLoader,
39
+ epochs: int,
40
+ num_processes: int = 1,
41
+ epoch_length: int | None = None
42
+ ):
43
+ """
44
+ Calculate the total number of "visible" training steps.
45
+
46
+ If `epoch_length` is provided, it is used as the fixed length for each epoch.
47
+ Otherwise, the function will determine the epoch length from `train_dataloader`.
48
+
49
+ Args:
50
+ train_dataloader:
51
+ Training dataloader object.
52
+ epochs:
53
+ The total number of epochs to run.
54
+ num_processes:
55
+ The number of parallel processes used for distributed training.
56
+ epoch_length:
57
+ A fixed number of training steps for each epoch. Defaults to None.
58
+
59
+ Returns:
60
+ int: The total number of training steps (i.e., `epochs * epoch_length`).
61
+ """
62
+ # `epoch_length` is not None: fixed length for each epoch
63
+ if epoch_length is None:
64
+ # `epoch_length` is the length of DDP-wrapped `train_dataloader`
65
+ epoch_length = get_dataloader_one_pass_outside_steps(
66
+ train_dataloader, num_processes
67
+ )
68
+ return epochs * epoch_length
69
+
70
+
71
def get_dataloader_one_pass_steps_inside_accelerator(
    dataloader_one_pass_steps: int, gradient_accumulation_steps: int,
    num_processes: int
):
    """
    Number of "visible" training steps for a single dataloader pass inside
    an accelerator, accounting for gradient accumulation and distributed
    training.

    Args:
        dataloader_one_pass_steps:
            The number of steps (batches) in one pass over the dataset.
        gradient_accumulation_steps:
            Steps to accumulate gradients before a parameter update.
        num_processes:
            The number of parallel processes used for distributed training.

    Returns:
        int: update count for one dataset pass, multiplied by the number
        of processes.
    """
    optimizer_updates = math.ceil(
        dataloader_one_pass_steps / gradient_accumulation_steps
    )
    return optimizer_updates * num_processes
95
+
96
+
97
def get_steps_inside_accelerator_from_outside_steps(
    outside_steps: int, dataloader_one_pass_outside_steps: int,
    dataloader_one_pass_steps_inside_accelerator: int,
    gradient_accumulation_steps: int, num_processes: int
):
    """
    Convert "outside" steps (as seen by e.g. a wandb logger) into the
    corresponding number of "inside" steps (`lr_scheduler.step()` calls
    made by accelerate).

    accelerate's scheduler wrapper calls `step()` `num_processes` times for
    every `gradient_accumulation_steps` outside steps:
    https://github.com/huggingface/accelerate/blob/main/src/accelerate/scheduler.py#L76

    Args:
        outside_steps:
            Total steps counted outside the accelerate context.
        dataloader_one_pass_outside_steps:
            Outside steps needed for one pass over the dataloader.
        dataloader_one_pass_steps_inside_accelerator:
            Inside `step()` calls per dataloader pass, from
            `get_dataloader_one_pass_steps_inside_accelerator`.
        gradient_accumulation_steps:
            Steps to accumulate gradients.
        num_processes:
            Number of parallel processes (GPUs).

    Returns:
        int: total inside `lr_scheduler.step()` calls corresponding to
        `outside_steps`.
    """
    full_passes, leftover_outside = divmod(
        outside_steps, dataloader_one_pass_outside_steps
    )
    leftover_inside = (
        leftover_outside // gradient_accumulation_steps * num_processes
    )
    return (
        full_passes * dataloader_one_pass_steps_inside_accelerator +
        leftover_inside
    )
141
+
142
+
143
def lr_scheduler_param_adapter(
    config_dict: dict[str, Any], num_training_steps: int, num_warmup_steps: int
) -> dict[str, Any]:
    """
    Deep-copy a hydra-style scheduler config and, when it targets
    `transformers.get_scheduler`, inject the training/warmup step counts
    it requires. The input dict is never mutated.
    """
    adapted = copy.deepcopy(config_dict)
    if adapted["_target_"] == "transformers.get_scheduler":
        adapted["num_training_steps"] = num_training_steps
        adapted["num_warmup_steps"] = num_warmup_steps
    return adapted
utils/torch_utilities.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from typing import Callable
4
+ from pathlib import Path
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+
9
# Use getLogger so this module's logger participates in the standard logging
# hierarchy (level/handler configuration done by the application propagates);
# instantiating logging.Logger directly creates a detached, unmanaged logger.
logger = logging.getLogger(__name__)
10
+
11
+
12
def remove_key_prefix_factory(prefix: str = "module."):
    """Build a state-dict processing function that strips `prefix` from keys.

    The returned callable matches the `state_dict_process_fn` signature used by
    `load_pretrained_model`: it takes `(model_dict, state_dict)` and returns a
    new state dict. Keys that do not start with `prefix` are dropped entirely.
    """

    def func(
        model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        offset = len(prefix)
        return {
            key[offset:]: tensor
            for key, tensor in state_dict.items()
            if key.startswith(prefix)
        }

    return func
25
+
26
+
27
def merge_matched_keys(
    model_dict: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Merge pretrained parameters into a model state dict where they fit.

    Args:
        model_dict:
            The state dict of the current model that will receive parameters.
        state_dict:
            Parameters from a pre-trained model.

    Returns:
        dict[str, torch.Tensor]:
            `model_dict` updated in place: entries whose key exists in both
            dicts with an identical shape are overwritten with the pretrained
            value; everything else is left untouched.
    """
    matched = {}
    mismatched_keys = []
    for name, tensor in state_dict.items():
        is_match = name in model_dict and model_dict[name].shape == tensor.shape
        if is_match:
            matched[name] = tensor
        else:
            mismatched_keys.append(name)
    # Keep a record of what was skipped so silent shape drift is visible.
    logger.info(
        f"Loading pre-trained model, with mismatched keys {mismatched_keys}"
    )
    model_dict.update(matched)
    return model_dict
54
+
55
+
56
def load_pretrained_model(
    model: nn.Module,
    ckpt_or_state_dict: str | Path | dict[str, torch.Tensor],
    state_dict_process_fn: Callable = merge_matched_keys
) -> None:
    """Load pretrained weights into `model`, tolerating partial matches.

    Args:
        model: The model to receive the parameters.
        ckpt_or_state_dict: Either an in-memory state dict or a checkpoint
            path to load (always onto CPU).
        state_dict_process_fn: Hook that reconciles the checkpoint's state
            dict with the model's own; defaults to keeping only key/shape
            matches via `merge_matched_keys`.
    """
    if isinstance(ckpt_or_state_dict, dict):
        state_dict = ckpt_or_state_dict
    else:
        state_dict = torch.load(ckpt_or_state_dict, "cpu")

    processed = state_dict_process_fn(model.state_dict(), state_dict)
    # strict=False: the processing hook may legitimately drop keys.
    model.load_state_dict(processed, strict=False)
68
+
69
+
70
+ def create_mask_from_length(
71
+ lengths: torch.Tensor, max_length: int | None = None
72
+ ):
73
+ if max_length is None:
74
+ max_length = max(lengths)
75
+ idxs = torch.arange(max_length).reshape(1, -1) # (1, max_length)
76
+ mask = idxs.to(lengths.device) < lengths.view(-1, 1)
77
+ # (1, max_length) < (batch_size, 1) -> (batch_size, max_length)
78
+ return mask
79
+
80
+
81
def loss_with_mask(
    loss: torch.Tensor,
    mask: torch.Tensor,
    reduce: bool = True
) -> torch.Tensor:
    """
    Average a loss tensor over the valid positions given by `mask`.

    Args:
        loss: Tensor of shape (b, t, ...) with per-element loss values.
        mask: Tensor of shape (b, t); 1 marks valid positions, 0 masked ones.
        reduce: If True, return a scalar mean over the batch; otherwise a
            per-sample tensor of shape (b,).

    Returns:
        torch.Tensor: Scalar if `reduce` is True, else shape (b,).
    """
    # Append singleton dims so the (b, t) mask broadcasts over trailing dims.
    full_mask = mask[(..., ) + (None, ) * (loss.ndim - mask.ndim)]
    full_mask = full_mask.expand_as(loss)

    non_batch_dims = tuple(range(1, loss.ndim))
    per_sample = (loss * full_mask).sum(dim=non_batch_dims) \
        / full_mask.sum(dim=non_batch_dims)

    return per_sample.mean() if reduce else per_sample
110
+
111
+
112
def convert_pad_shape(pad_shape: list[list[int]]):
    """Flatten a per-dimension pad spec into the reversed flat list that
    `torch.nn.functional.pad` expects (last dimension's pair first)."""
    flattened = []
    for pair in reversed(pad_shape):
        flattened.extend(pair)
    return flattened
116
+
117
+
118
def create_alignment_path(duration: torch.Tensor, mask: torch.Tensor):
    """Expand per-token durations into a hard monotonic alignment path.

    Args:
        duration: (b, t_x) integer durations, one per source token.
        mask: (b, t_x, t_y) mask applied to the resulting path.

    Returns:
        torch.Tensor: (b, t_x, t_y) binary path where entry (i, j) is 1 iff
        target frame j falls inside source token i's duration span.
    """
    # NOTE(review): removed unused local `device = duration.device`.
    b, t_x, t_y = mask.shape
    # Token i covers frames [cum_duration[i] - duration[i], cum_duration[i]).
    cum_duration = torch.cumsum(duration, 1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = create_mask_from_length(cum_duration_flat, t_y).float()
    path = path.view(b, t_x, t_y)
    # Differencing along the t_x axis turns cumulative coverage masks into
    # disjoint per-token spans.
    path = path - torch.nn.functional.pad(
        path, convert_pad_shape([[0, 0], [1, 0], [0, 0]])
    )[:, :-1]
    path = path * mask
    return path
133
+
134
+
135
def trim_or_pad_length(x: torch.Tensor, target_length: int, length_dim: int):
    """
    Adjusts the size of the specified dimension of tensor x to match `target_length`.

    Args:
        x:
            Input tensor.
        target_length:
            Desired size of the specified dimension.
        length_dim:
            The dimension to modify.

    Returns:
        torch.Tensor: The adjusted tensor (truncated, zero-padded on the
        right, or returned unchanged when already the right size).
    """
    current_length = x.shape[length_dim]

    if current_length > target_length:
        # Keep only the leading `target_length` entries along `length_dim`.
        return x.narrow(length_dim, 0, target_length)

    if current_length < target_length:
        # Right-pad with zeros matching x's dtype and device.
        pad_shape = list(x.shape)
        pad_shape[length_dim] = target_length - current_length
        return torch.cat([x, x.new_zeros(pad_shape)], dim=length_dim)

    return x
169
+
170
+
171
def concat_non_padding(
    seq1: torch.Tensor, mask1: torch.BoolTensor, seq2: torch.Tensor,
    mask2: torch.BoolTensor
):
    """
    Concatenate two padded sequences, packing valid tokens to the left.

    Args
        seq1 : Tensor (B, L1, E)
            First sequence.
        mask1 : BoolTensor (B, L1)
            True for valid tokens in seq1, False for padding.
        seq2 : Tensor (B, L2, E)
            Second sequence.
        mask2 : BoolTensor (B, L2)
            True for valid tokens in seq2, False for padding.

    Returns
        concat_seq : Tensor (B, L1+L2, E)
            Valid tokens from both sequences, left-aligned; the right-hand
            padding region is zeroed.
        concat_mask: BoolTensor (B, L1+L2)
            Mask for the concatenated sequence.
        perm : LongTensor (B, L1+L2)
            Permutation mapping new positions to original indices; pass it to
            `restore_from_concat` to recover the original layout.
    """
    mask1, mask2 = mask1.bool(), mask2.bool()
    batch, len1, dim = seq1.shape
    total_len = len1 + seq2.size(1)

    tokens = torch.cat([seq1, seq2], dim=1)       # (B, L, E)
    validity = torch.cat([mask1, mask2], dim=1)   # (B, L)

    # Stable sort: each padding slot is penalized by +L, so it sorts after
    # every valid token while valid tokens keep their relative order.
    # https://github.com/huggingface/accelerate is unrelated; see argsort docs.
    base = torch.arange(total_len, device=tokens.device).unsqueeze(0)  # (1, L)
    scores = base + (~validity) * total_len
    perm = scores.argsort(dim=1, stable=True)     # (B, L)

    # Reorder tokens and mask according to the permutation.
    gather_idx = perm.unsqueeze(-1).expand(-1, -1, dim)  # (B, L, E)
    concat_seq = tokens.gather(1, gather_idx)
    concat_mask = validity.gather(1, perm)

    # Explicitly zero the trailing padding region for safety.
    concat_seq = concat_seq * concat_mask.unsqueeze(-1)

    return concat_seq, concat_mask, perm
219
+
220
+
221
def restore_from_concat(
    concat_seq: torch.Tensor, mask1: torch.BoolTensor, mask2: torch.BoolTensor,
    perm: torch.LongTensor
):
    """
    Recover (seq1, seq2) from a sequence packed by `concat_non_padding`.

    `perm` is the permutation returned by `concat_non_padding`; its inverse is
    built with a scatter so the whole restore is vectorised.
    """
    mask1, mask2 = mask1.bool(), mask2.bool()
    batch, len1 = mask1.shape
    len2 = mask2.size(1)
    dim = concat_seq.size(-1)

    # perm maps new_idx -> old_idx; scatter builds inv_perm: old_idx -> new_idx.
    positions = torch.arange(len1 + len2, device=perm.device)
    inv_perm = torch.empty_like(perm)
    inv_perm.scatter_(1, perm, positions.unsqueeze(0).expand(batch, -1))

    # Pull each token back to its original slot.
    gather_idx = inv_perm.unsqueeze(-1).expand(-1, -1, dim)
    restored = concat_seq.gather(1, gather_idx)  # (B, L1+L2, E)

    # Split into the two sequences and zero out padding positions.
    seq1_restore, seq2_restore = restored.split([len1, len2], dim=1)
    seq1_restore = seq1_restore * mask1.unsqueeze(-1)
    seq2_restore = seq2_restore * mask2.unsqueeze(-1)

    return seq1_restore, seq2_restore
252
+
253
+
254
def contains_nan(data):
    """Recursively check whether `data` contains any NaN value.

    Handles torch tensors, numpy arrays, floats, and arbitrarily nested
    lists/tuples/dicts; any other type is assumed NaN-free.
    """
    if isinstance(data, torch.Tensor):
        return torch.isnan(data).any().item()
    if isinstance(data, np.ndarray):
        return np.isnan(data).any()
    if isinstance(data, float):
        return math.isnan(data)
    if isinstance(data, (list, tuple)):
        return any(contains_nan(item) for item in data)
    if isinstance(data, dict):
        return any(contains_nan(value) for value in data.values())
    return False
267
+
268
+
269
def check_nan_in_batch(batch):
    """Check a collated batch for NaNs and return the offending audio ids.

    Args:
        batch: dict mapping field names to batched (indexable) values; must
            contain an "audio_id" entry listing per-sample identifiers.
            # NOTE(review): each non-"audio_id" value is assumed indexable by
            # sample position — confirm against the collate function.

    Returns:
        list: audio ids of samples whose fields contain at least one NaN.
    """
    # isinstance instead of `type(...) ==` (idiomatic type check).
    assert isinstance(batch, dict), "batch type error"
    nan_audio_ids = []
    # Check each sample directly instead of building an intermediate dict
    # keyed by audio_id, which silently collapsed duplicate ids.
    for idx, audio_id in enumerate(batch["audio_id"]):
        content = [v[idx] for k, v in batch.items() if k != "audio_id"]
        if contains_nan(content):
            nan_audio_ids.append(audio_id)
            print(f"{audio_id} contains NaN")
    return nan_audio_ids
288
+