guilinhu committed on
Commit
df9f13e
·
verified ·
1 Parent(s): 17dbfd4

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ wandb/
2
+ __pycache__/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Guilin Hu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,54 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Proactive Hearing Assistants that Isolate Egocentric Conversations
2
+
3
+ ## More Information
4
+
5
+ For more information, please refer to our website: [https://proactivehearing.cs.washington.edu/](https://proactivehearing.cs.washington.edu/).
6
+
7
+ ## Abstract
8
+
9
+ We introduce proactive hearing assistants that automatically identify and separate the wearer’s conversation partners, without requiring explicit prompts. Our system operates on egocentric binaural audio and uses the wearer’s self-speech as an anchor, leveraging turn-taking behavior and dialogue dynamics to infer conversational partners and suppress others. To enable real-time, on-device operation, we propose a dual-model architecture: a lightweight streaming model runs every 12.5 ms for low-latency extraction of the conversation partners, while a slower model runs less frequently to capture longer-range conversational dynamics. Results on real-world 2- and 3-speaker conversation test sets, collected with binaural egocentric hardware from 11 participants totaling 6.8 hours, show generalization in identifying and isolating conversational partners in multi-conversation settings. Our work marks a step toward hearing assistants that adapt proactively to conversational dynamics and engagement.
10
+
11
+
12
+ ## Training and Evaluation
13
+
14
+ ### 1. Installing Requirements
15
+
16
+ Before training or evaluating the model, please create an environment and install all dependencies:
17
+
18
+ ```
19
+ pip install -r requirements.txt
20
+ ```
21
+
22
+ ### 2. Model Training
23
+
24
+ To train the model, run:
25
+
26
+ ```
27
+ python src/train_joint.py --config <path_to_config> --run_dir <path_to_model_checkpoint>
28
+ ```
29
+
30
+ To resume training, make sure that `<path_to_model_checkpoint>` points to the same directory used previously, and rerun the command above.
31
+
32
+
33
+ ### 3. Model Evaluation
34
+
35
+ To evaluate the model, run:
36
+
37
+ ```
38
+ python eval.py <path to testing dataset> <path to model checkpoint> --use_cuda --save
39
+ ```
40
+
41
+
42
+ ## Citation
43
+
44
+ If you use our work, please cite:
45
+
46
+ ```
47
+ @inproceedings{hu2025proactive,
48
+ title={Proactive Hearing Assistants that Isolate Egocentric Conversations},
49
+ author={Hu, Guilin and Itani, Malek and Chen, Tuochao and Gollakota, Shyamnath},
50
+ booktitle={Proceedings of the 2025 Conference on Empirical Methods in Natural Language Processing},
51
+ pages={25377--25394},
52
+ year={2025}
53
+ }
54
+ ```
config/model_config.json ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "project_name": "magic_hear",
3
+ "pl_module": "src.hl_module.joint_train_hl_module_new.PLModule",
4
+ "pl_module_args": {
5
+ "freeze_model1": false,
6
+ "metrics": [
7
+ "snr_i",
8
+ "si_snr_i",
9
+ "si_sdr_i"
10
+ ],
11
+ "model": "src.models.network.net_conversation_joint.Net_Conversation",
12
+ "model_params": {
13
+ "model1_block_name": "src.models.blocks.model1_block.GridNetBlock",
14
+ "num_layers_model1": 6,
15
+ "latent_dim_model1": 32,
16
+ "use_speaker_emb_model1": false,
17
+ "use_self_speech_model2": false,
18
+ "one_emb_model1": true,
19
+ "model1_block_params": {
20
+ "emb_ks": 2,
21
+ "emb_hs": 2,
22
+ "hidden_channels": 64,
23
+ "n_head": 4
24
+ },
25
+ "model2_block_name": "src.models.blocks.model2_block.GridNetBlock",
26
+ "num_layers_model2": 6,
27
+ "latent_dim_model2": 32,
28
+ "lstm_fold_chunk": 80,
29
+ "model2_block_params": {
30
+ "emb_ks": 1,
31
+ "emb_hs": 1,
32
+ "hidden_channels": 64,
33
+ "n_head": 4,
34
+ "use_attention": false
35
+ },
36
+ "stft_chunk_size": 200,
37
+ "stft_pad_size": 32,
38
+ "stft_back_pad": 32,
39
+ "num_input_channels": 1,
40
+ "num_output_channels": 1,
41
+ "num_sources": 1,
42
+ "use_sp_feats": false,
43
+ "use_first_ln": true,
44
+ "n_imics": 1,
45
+ "window": "rect",
46
+ "E": 2
47
+ },
48
+ "loss": "src.losses.SNRLP.SNRLPLoss",
49
+ "loss_params": {
50
+ "snr_loss_name": "snr",
51
+ "neg_weight": 100
52
+ },
53
+ "optimizer": "torch.optim.AdamW",
54
+ "optimizer_params": {
55
+ "lr": 2e-3
56
+ },
57
+ "scheduler": "torch.optim.lr_scheduler.ReduceLROnPlateau",
58
+ "scheduler_params": {
59
+ "mode": "min",
60
+ "patience": 4,
61
+ "factor": 0.5,
62
+ "min_lr": 1e-6
63
+ },
64
+ "sr": 16000,
65
+ "grad_clip": 1,
66
+ "use_dp": true
67
+ },
68
+ "train_dataset": "src.datasets.joint_training_dataset.Dataset",
69
+ "train_data_args": {
70
+ "input_dir": [],
71
+ "output_conversation": 1,
72
+ "batch_size": 4,
73
+ "clean_embed": true,
74
+ "random_audio_length": 160000,
75
+ "required_first_speaker_as_self_speech": true,
76
+ "spk_emb_exist": false
77
+ },
78
+ "val_dataset": "src.datasets.joint_training_dataset.Dataset",
79
+ "val_data_args": {
80
+ "input_dir": [],
81
+ "output_conversation": 1,
82
+ "batch_size": 4,
83
+ "clean_embed": true,
84
+ "random_audio_length": 160000,
85
+ "required_first_speaker_as_self_speech": true,
86
+ "spk_emb_exist": false
87
+ },
88
+ "epochs": 130,
89
+ "batch_size": 4,
90
+ "eval_batch_size": 4,
91
+ "num_workers": 12
92
+ }
eval.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.metrics.metrics import Metrics
2
+ import src.utils as utils
3
+ import argparse
4
+ import os, json, glob
5
+ import numpy as np
6
+ import torch
7
+ import pandas as pd
8
+ import torchaudio
9
+ import matplotlib.pyplot as plt
10
+ import torch.nn as nn
11
+ import copy
12
+ import torch.nn.functional as F
13
+ from torchmetrics.functional import signal_noise_ratio as snr
14
+
15
+
def mod_pad(x, chunk_size, pad):
    """Right-pad x to a multiple of chunk_size, then apply an extra pad.

    Args:
        x: tensor padded along its last dimension.
        chunk_size: the last dimension is extended to a multiple of this.
        pad: additional (left, right) padding applied after rounding up.

    Returns:
        (padded tensor, number of samples added by the rounding step)
    """
    remainder = x.shape[-1] % chunk_size
    extra = chunk_size - remainder if remainder else 0
    padded = F.pad(F.pad(x, (0, extra)), pad)
    return padded, extra
25
+
26
+
class LayerNormPermuted(nn.LayerNorm):
    """LayerNorm applied over the channel dimension of [B, C, T, F] inputs.

    Standard LayerNorm normalizes over trailing dimensions, so the input is
    permuted to put channels last, normalized, then permuted back.
    """

    def __init__(self, *args, **kwargs):
        super(LayerNormPermuted, self).__init__(*args, **kwargs)

    def forward(self, x):
        """Normalize x of shape [B, C, T, F] over C; shape is preserved."""
        channels_last = x.permute(0, 2, 3, 1)  # [B, T, F, C]
        normed = super().forward(channels_last)
        return normed.permute(0, 3, 1, 2)  # back to [B, C, T, F]
40
+
41
+
def save_audio_file_torch(file_path, wavform, sample_rate=16000, rescale=False):
    """Write a waveform tensor to disk, optionally peak-scaled to 0.9 first."""
    out = wavform / torch.max(wavform) * 0.9 if rescale else wavform
    torchaudio.save(file_path, out, sample_rate)
46
+
47
+
def get_mixture_and_gt(curr_dir, rng, SHIFT_VALUE=0, noise_audio_list=None):
    """Assemble one evaluation example from a test-case directory.

    Loads self speech, the conversation partners' speech and interference audio,
    mixes them at a random SNR in [-10, 10] dB, optionally adds ambient noise,
    and peak-normalizes when the mixture clips.

    Args:
        curr_dir: directory holding metadata.json and the .wav stems.
        rng: np.random.RandomState used for SNR / noise draws.
        SHIFT_VALUE: unused; kept for call-site compatibility.
        noise_audio_list: optional list of noise .wav paths; None means no noise.

    Returns:
        (inputs dict, targets dict, metadata dict).

    Raises:
        FileNotFoundError: if no self-speech recording exists in curr_dir.
    """
    # Avoid the shared-mutable-default-argument pitfall.
    if noise_audio_list is None:
        noise_audio_list = []

    metadata2 = utils.read_json(os.path.join(curr_dir, "metadata.json"))
    diags = metadata2["target_dialogue"]

    if os.path.exists(os.path.join(curr_dir, "self_speech.wav")):
        self_speech = utils.read_audio_file_torch(os.path.join(curr_dir, "self_speech.wav"), 1)
    elif os.path.exists(os.path.join(curr_dir, "self_speech_original.wav")):
        self_speech = utils.read_audio_file_torch(os.path.join(curr_dir, "self_speech_original.wav"), 1)
    else:
        # Previously fell through and raised UnboundLocalError below; fail clearly.
        raise FileNotFoundError(f"No self-speech recording found in {curr_dir}")

    other_speech = torch.zeros_like(self_speech)

    # Sum the speech of every conversation partner (all dialogue members but the wearer).
    for i in range(len(diags) - 1):
        wav = utils.read_audio_file_torch(os.path.join(curr_dir, f"target_speech{i}.wav"), 1)
        other_speech += wav

    # NOTE: "intereference" spelling matches the on-disk dataset filenames — do not "fix" it.
    if os.path.exists(os.path.join(curr_dir, "intereference.wav")):
        interfere = utils.read_audio_file_torch(os.path.join(curr_dir, "intereference.wav"), 1)
    else:
        interfere = torch.zeros_like(self_speech)
        interfere += utils.read_audio_file_torch(os.path.join(curr_dir, "intereference0.wav"), 1)
        interfere += utils.read_audio_file_torch(os.path.join(curr_dir, "intereference1.wav"), 1)

    # Ground truth is the wearer's conversation; interference is scaled to a random SNR.
    gt = self_speech + other_speech
    tgt_snr = rng.uniform(-10, 10)
    interfere = scale_noise_to_snr(gt, interfere, tgt_snr)

    mixture = gt + interfere

    if noise_audio_list:
        print("added noise")
        noise_audio = noise_sample(noise_audio_list, mixture.shape[-1], rng)
        wham_scale = rng.uniform(0, 1)
        mixture += noise_audio * wham_scale

    embed_path = os.path.join(curr_dir, "embed.pt")
    if os.path.exists(embed_path):
        embed = torch.load(embed_path, weights_only=False)
        embed = torch.from_numpy(embed)
    else:
        embed = torch.zeros(256)

    L = mixture.shape[-1]

    # Peak-normalize only when clipping would occur; note other_speech and
    # interfere are deliberately left unscaled (only mixture/self/gt are returned scaled).
    peak = np.abs(mixture).max()
    if peak > 1:
        mixture /= peak
        self_speech /= peak
        gt /= peak

    inputs = {
        "mixture": mixture.float(),
        "embed": embed.float(),
        "self_speech": self_speech[0:1, :].float(),
    }

    targets = {
        "self": self_speech[0:1, :].numpy(),
        "other": other_speech[0:1, :].numpy(),
        "target": gt[0:1, :].float(),
    }

    return inputs, targets, metadata2
110
+
111
+
def scale_utterance(audio, timestamp, rng, db_change=7):
    """Randomly rescale ~30% of the (start, end) segments by up to ±db_change dB, in place."""
    for seg_start, seg_end in timestamp:
        # Draw order matters for rng reproducibility: one coin flip per segment,
        # then a gain draw only for the selected ones.
        if rng.uniform(0, 1) >= 0.3:
            continue
        gain_db = rng.uniform(-db_change, db_change)
        audio[..., seg_start:seg_end] *= 10 ** (gain_db / 20)

    return audio
120
+
121
+
def get_snr(target, mixture, EPS=1e-9):
    """
    Computes the average SNR across all channels.

    NOTE(review): torchmetrics' signal_noise_ratio signature is (preds, target);
    here `mixture` is passed as the prediction. `EPS` is accepted but unused —
    presumably kept for signature compatibility; confirm before removing.
    """
    return snr(mixture, target).mean()
127
+
128
+
def scale_noise_to_snr(target_speech: torch.Tensor, noise: torch.Tensor, target_snr: float):
    """Rescale `noise` so the resulting mixture reaches `target_snr` dB on average."""
    initial_snr = get_snr(target_speech, noise + target_speech)

    # Scaling the noise by k shifts the SNR by -20*log10(k); solve for k.
    gain = 10 ** ((initial_snr - target_snr) / 20)

    return gain * noise
136
+
137
+
def run_testcase(model, inputs, device) -> np.ndarray:
    """Run the model on one prepared test case and return the separated audio as numpy.

    NOTE(review): mutates `inputs` in place — tensors are trimmed to the first
    channel, given a batch dimension, and moved to `device`. Callers reading
    `inputs` afterwards see the modified tensors.
    """
    with torch.inference_mode():
        inputs["mixture"] = inputs["mixture"][0:1, ...].unsqueeze(0).to(device)
        inputs["embed"] = inputs["embed"].unsqueeze(0).to(device)
        inputs["self_speech"] = inputs["self_speech"][0:1, ...].unsqueeze(0).to(device)

        # Process the entire clip in one pass: the index window spans the whole mixture.
        inputs["start_idx"] = 0
        inputs["end_idx"] = inputs["mixture"].shape[-1]
        outputs = model(inputs)

        # Drop the batch dimension; remaining layout is model-defined
        # (presumably [channels, samples]) — confirm against the network.
        output_target = outputs["output"].squeeze(0)

        final_output = output_target.cpu().numpy()

    return final_output
153
+
154
+
def get_timestamp_mask(timestamps, mask_shape):
    """Return a zeros tensor of mask_shape with ones inside every (start, end) span."""
    mask = torch.zeros(mask_shape)
    for span in timestamps:
        mask[..., span[0]:span[1]] = 1.0
    return mask
161
+
162
+
def noise_sample(noise_file_list, audio_length, rng: np.random.RandomState):
    """Build an `audio_length`-sample noise clip by concatenating randomly chosen files.

    Files are drawn (with replacement) from `noise_file_list`, resampled to 16 kHz
    when needed, concatenated until the target length is exceeded, then truncated.
    Only the first channel of each noise file is used.
    """
    # NOTE: hardcoded. assume noise is 48k and target is 16k
    target_sr = 16000

    acc_len = 0
    concatenated_audio = None
    # Keep appending clips until we have strictly more than audio_length samples.
    while acc_len <= audio_length:
        noise_file = rng.choice(noise_file_list)
        info = torchaudio.info(noise_file)
        noise_sr = info.sample_rate

        noise_wav, _ = torchaudio.load(noise_file)
        noise_wav = noise_wav[0:1, ...]  # first channel only

        if noise_sr != target_sr:
            resampler = torchaudio.transforms.Resample(orig_freq=noise_sr, new_freq=target_sr)
            noise_wav = resampler(noise_wav)

        if concatenated_audio is None:
            concatenated_audio = noise_wav
        else:
            concatenated_audio = torch.cat((concatenated_audio, noise_wav), dim=1)

        acc_len = concatenated_audio.shape[-1]

    # Truncate the overshoot to exactly the requested length.
    concatenated_audio = concatenated_audio[..., :audio_length]

    assert concatenated_audio.shape[1] == audio_length

    return concatenated_audio
193
+
194
+
def main(args: argparse.Namespace):
    """Evaluate a pretrained checkpoint over up to 200 test cases and write SI-SDR metrics to CSV.

    For each case: builds the mixture, runs the model, and records SI-SDR for the
    whole clip plus masked "self" and "other speaker" regions. Optionally saves
    the audio when --save is given.
    """
    device = "cuda" if args.use_cuda else "cpu"

    # Load model
    model = utils.load_torch_pretrained(args.run_dir).model
    model_name = args.run_dir.split("/")[-1]
    model = model.to(device)
    model.eval()

    # Initialize metrics
    # NOTE(review): snr and snr_i are created but never used below; only si_sdr is.
    snr = Metrics("snr")
    snr_i = Metrics("snr_i")

    si_sdr = Metrics("si_sdr")

    records = []

    noise_audio_list = []
    if args.noise_dir is not None:
        noise_audio_sublist = glob.glob(os.path.join(args.noise_dir, "*.wav"))
        if not noise_audio_sublist:
            print("no noise file found")
        noise_audio_list.extend(noise_audio_sublist)

    # Test-case directories are zero-padded indices 00000..00199; missing ones are skipped.
    for i in range(0, 200):
        # Seed per test case so mixing SNRs / noise draws are reproducible.
        rng = np.random.RandomState(i)
        dataset_name = os.path.basename(args.test_dir)
        curr_dir = os.path.join(args.test_dir, "{:05d}".format(i))

        meta_dir = os.path.join(curr_dir, "metadata.json")

        if not os.path.exists(meta_dir):
            continue

        inputs, targets, metadata = get_mixture_and_gt(curr_dir, rng, noise_audio_list=noise_audio_list)

        if inputs is None:
            continue

        self_timestamps = metadata["target_dialogue"][0]["timestamp"]

        target_speech = targets["target"].cpu().numpy()
        row = {"test_case_index": i}
        mixture = inputs["mixture"].cpu().numpy()

        self_speech = inputs["self_speech"].squeeze(0).cpu().numpy()

        # Keep only the first channel for the model input and the reference.
        inputs["mixture"] = inputs["mixture"][0:1, ...]
        target_speech = target_speech[0:1, ...]

        # run_testcase mutates `inputs` (batched, moved to device) and returns numpy audio.
        output_target = run_testcase(model, inputs, device)

        self_timestamps = metadata["target_dialogue"][0]["timestamp"]
        self_mask = get_timestamp_mask(self_timestamps, target_speech.shape)
        # Zero out the first second (args.sr samples) — presumably a model warm-up
        # window excluded from scoring; confirm with the paper/protocol.
        self_mask[..., : args.sr] = 0

        if mixture.ndim == 1:
            mixture = mixture[np.newaxis, ...]

        # Whole-clip SI-SDR before (mixture as estimate) and after separation.
        total_input_sisdr = si_sdr(est=mixture[0:1], gt=target_speech, mix=mixture[0:1]).item()
        total_output_sisdr = si_sdr(est=output_target, gt=target_speech, mix=mixture[0:1]).item()

        row[f"sisdr_input_total"] = total_input_sisdr
        row[f"sisdr_output_total"] = total_output_sisdr

        # self

        self_sisdr_mix = si_sdr(
            est=self_mask * mixture[:1], gt=self_mask * target_speech, mix=self_mask * mixture[:1]
        ).item()
        self_sisdr_pred = si_sdr(
            est=self_mask * output_target, gt=self_mask * target_speech, mix=self_mask * mixture[:1]
        ).item()

        row[f"sisdr_mix_self"] = self_sisdr_mix
        row[f"sisdr_pred_self"] = self_sisdr_pred

        # ======other speaker======

        # Union of all non-wearer speakers' timestamps (dialogue entries 1..n-1).
        other_timestamps = metadata["target_dialogue"][1]["timestamp"]
        if len(metadata["target_dialogue"]) > 2:
            for j in range(2, len(metadata["target_dialogue"])):
                timestamp = metadata["target_dialogue"][j]["timestamp"]
                other_timestamps = other_timestamps + timestamp

        other_mask = get_timestamp_mask(other_timestamps, target_speech.shape)
        other_mask[..., : args.sr] = 0

        other_sisdr_mix = si_sdr(
            est=other_mask * mixture[:1], gt=other_mask * target_speech, mix=other_mask * mixture[:1]
        ).item()
        other_sisdr_pred = si_sdr(
            est=other_mask * output_target, gt=other_mask * target_speech, mix=other_mask * mixture[:1]
        ).item()

        row[f"sisdr_mix_other"] = other_sisdr_mix
        row[f"sisdr_pred_other"] = other_sisdr_pred

        print(i)
        records.append(row)

        # Output path encodes whether ambient noise was mixed in.
        if noise_audio_list != []:
            save_folder = f"./result_{dataset_name}_noise/{model_name}/{i}"
        else:
            save_folder = f"./result_{dataset_name}/{model_name}/{i}"
        os.makedirs(save_folder, exist_ok=True)

        if type(self_speech) == np.ndarray:
            self_speech = torch.from_numpy(self_speech)

        if self_speech.dim() == 1:
            self_speech = self_speech.unsqueeze(0)

        if args.save:
            save_audio_file_torch(
                f"{save_folder}/mix.wav", torch.from_numpy(mixture[0:1]), sample_rate=args.sr, rescale=False
            )
            save_audio_file_torch(f"{save_folder}/self.wav", self_speech, sample_rate=args.sr, rescale=False)
            save_audio_file_torch(
                f"{save_folder}/output_target.wav", torch.from_numpy(output_target), sample_rate=args.sr, rescale=False
            )
            save_audio_file_torch(
                f"{save_folder}/target_speech.wav", torch.from_numpy(target_speech), sample_rate=args.sr, rescale=False
            )

    results_df = pd.DataFrame.from_records(records)

    # Put the test-case index first for readability.
    columns = ["test_case_index"] + [col for col in results_df.columns if col != "test_case_index"]
    results_df = results_df[columns]

    if noise_audio_list != []:
        results_csv_path = f"./result_{dataset_name}_noise/{model_name}_multi.csv"
    else:
        results_csv_path = f"./result_{dataset_name}/{model_name}_multi.csv"
    results_df.to_csv(results_csv_path, index=False)
330
+
331
+
if __name__ == "__main__":
    # Command-line entry point: positional dataset/checkpoint paths plus eval options.
    parser = argparse.ArgumentParser()
    parser.add_argument("test_dir", type=str, help="Path to test dataset")
    parser.add_argument("run_dir", type=str, help="Path to model run checkpoint")
    parser.add_argument("--sr", type=int, default=16000, help="Project sampling rate")
    parser.add_argument("--noise_dir", type=str, default=None, help="Wham noise directory")
    parser.add_argument("--use_cuda", action="store_true", help="Whether to use cuda")
    parser.add_argument("--save", action="store_true", help="Whether to save output audio")
    main(parser.parse_args())
requirements.txt ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.3.1
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.11.16
4
+ aiosignal==1.3.2
5
+ annotated-types==0.7.0
6
+ antlr4-python3-runtime==4.9.3
7
+ asteroid==0.7.0
8
+ asteroid-filterbanks==0.4.0
9
+ asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
10
+ async-timeout==5.0.1
11
+ attrs==25.3.0
12
+ audioread==3.0.1
13
+ auraloss==0.4.0
14
+ beautifulsoup4==4.13.4
15
+ cached-property==2.0.1
16
+ certifi==2025.1.31
17
+ cffi==1.17.1
18
+ cftime==1.6.4.post1
19
+ charset-normalizer==3.4.1
20
+ ci_sdr==0.0.2
21
+ click==8.1.8
22
+ coloredlogs==15.0.1
23
+ comm @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_comm_1753453984/work
24
+ ConfigArgParse==1.7
25
+ contourpy==1.3.0
26
+ ctc_segmentation==1.7.4
27
+ cycler==0.12.1
28
+ Cython==3.0.12
29
+ DateTime==5.5
30
+ debugpy @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_debugpy_1752827114/work
31
+ decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
32
+ Distance==0.1.3
33
+ docker-pycreds==0.4.0
34
+ editdistance==0.8.1
35
+ einops==0.8.1
36
+ espnet==202412
37
+ espnet-tts-frontend==0.0.3
38
+ eval_type_backport==0.2.2
39
+ exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
40
+ executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
41
+ fast_bss_eval==0.1.3
42
+ filelock==3.18.0
43
+ flatbuffers==25.2.10
44
+ fonttools==4.57.0
45
+ frozenlist==1.5.0
46
+ fsspec==2025.3.2
47
+ g2p-en==2.1.0
48
+ gdown==5.2.0
49
+ gitdb==4.0.12
50
+ GitPython==3.1.44
51
+ grpcio==1.74.0
52
+ h5py==3.13.0
53
+ huggingface-hub==0.30.2
54
+ humanfriendly==10.0
55
+ hydra-core==1.3.2
56
+ HyperPyYAML==1.2.2
57
+ idna==3.10
58
+ importlib-metadata==4.13.0
59
+ importlib_resources==6.5.2
60
+ inflect==7.5.0
61
+ intervaltree==3.1.0
62
+ ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1753749834440/work
63
+ ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1701831663892/work
64
+ jaconv==0.4.0
65
+ jamo==0.4.1
66
+ jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
67
+ Jinja2==3.1.6
68
+ jiwer==4.0.0
69
+ joblib==1.4.2
70
+ julius==0.2.7
71
+ jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
72
+ jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1748333051527/work
73
+ kaldiio==2.18.1
74
+ kiwisolver==1.4.7
75
+ lazy_loader==0.4
76
+ librosa==0.9.2
77
+ lightning-utilities==0.14.3
78
+ llvmlite==0.43.0
79
+ Markdown==3.9
80
+ MarkupSafe==3.0.2
81
+ matplotlib==3.9.4
82
+ matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
83
+ mir_eval==0.8.2
84
+ more-itertools==10.6.0
85
+ mpmath==1.3.0
86
+ msgpack==1.1.0
87
+ multidict==6.4.3
88
+ nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
89
+ netCDF4==1.7.2
90
+ networkx==3.2.1
91
+ nltk==3.9.1
92
+ noisereduce==3.0.3
93
+ numba==0.60.0
94
+ numpy==1.23.5
95
+ nvidia-cublas-cu12==12.4.5.8
96
+ nvidia-cuda-cupti-cu12==12.4.127
97
+ nvidia-cuda-nvrtc-cu12==12.4.127
98
+ nvidia-cuda-runtime-cu12==12.4.127
99
+ nvidia-cudnn-cu12==9.1.0.70
100
+ nvidia-cufft-cu12==11.2.1.3
101
+ nvidia-curand-cu12==10.3.5.147
102
+ nvidia-cusolver-cu12==11.6.1.9
103
+ nvidia-cusparse-cu12==12.3.1.170
104
+ nvidia-cusparselt-cu12==0.6.2
105
+ nvidia-nccl-cu12==2.21.5
106
+ nvidia-nvjitlink-cu12==12.4.127
107
+ nvidia-nvtx-cu12==12.4.127
108
+ omegaconf==2.3.0
109
+ onnxruntime==1.19.2
110
+ openai-whisper==20250625
111
+ opt_einsum==3.4.0
112
+ packaging==24.2
113
+ pandas==2.2.3
114
+ parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
115
+ pb-bss-eval==0.0.2
116
+ pesq==0.0.4
117
+ pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
118
+ pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
119
+ pillow==11.2.1
120
+ platformdirs==4.3.7
121
+ pooch==1.8.2
122
+ prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
123
+ propcache==0.3.1
124
+ protobuf==5.29.4
125
+ psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1740663125313/work
126
+ ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
127
+ pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
128
+ pybind11==2.13.6
129
+ pycparser==2.22
130
+ pydantic==2.11.3
131
+ pydantic_core==2.33.1
132
+ pydub==0.25.1
133
+ Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1750615794071/work
134
+ pyparsing==3.2.3
135
+ pypinyin==0.44.0
136
+ pyroomacoustics==0.8.3
137
+ PySocks==1.7.1
138
+ pystoi==0.4.1
139
+ python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dateutil_1751104122/work
140
+ python-sofa==0.2.0
141
+ pytorch-lightning==2.5.1
142
+ pytorch-ranger==0.1.1
143
+ pytz==2025.2
144
+ pyworld==0.3.5
145
+ PyYAML==6.0.2
146
+ pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1749898437650/work
147
+ RapidFuzz==3.13.0
148
+ regex==2024.11.6
149
+ requests==2.32.3
150
+ resampy==0.4.3
151
+ Resemblyzer==0.1.4
152
+ ruamel.yaml==0.18.15
153
+ ruamel.yaml.clib==0.2.12
154
+ safetensors==0.5.3
155
+ scikit-learn==1.6.1
156
+ scipy==1.13.1
157
+ sentencepiece==0.1.97
158
+ sentry-sdk==2.26.0
159
+ setproctitle==1.3.5
160
+ silero-vad==5.1.2
161
+ six @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_six_1753199211/work
162
+ smmap==5.0.2
163
+ sortedcontainers==2.4.0
164
+ soundfile==0.13.1
165
+ soupsieve==2.7
166
+ sox==1.5.0
167
+ soxbindings==1.2.3
168
+ soxr==0.5.0.post1
169
+ stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
170
+ sympy==1.13.1
171
+ tensorboard==2.20.0
172
+ tensorboard-data-server==0.7.2
173
+ tensorboardX==2.6.4
174
+ threadpoolctl==3.6.0
175
+ tiktoken==0.9.0
176
+ tokenizers==0.21.1
177
+ torch==2.6.0
178
+ torch-complex==0.4.4
179
+ torch-optimizer==0.1.0
180
+ torch-stoi==0.2.3
181
+ torchaudio==2.6.0
182
+ torchmetrics==0.11.4
183
+ torchvision==0.21.0
184
+ tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1748003328568/work
185
+ tqdm==4.67.1
186
+ traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
187
+ transformers==4.51.3
188
+ triton==3.2.0
189
+ typeguard==4.4.2
190
+ typing==3.7.4.3
191
+ typing-inspection==0.4.0
192
+ typing_extensions==4.13.2
193
+ tzdata==2025.2
194
+ Unidecode==1.3.8
195
+ urllib3==2.4.0
196
+ wandb==0.19.9
197
+ wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
198
+ webrtcvad==2.0.10
199
+ Werkzeug==3.1.3
200
+ yarl==1.19.0
201
+ zipp==3.21.0
202
+ zope.interface==7.2
src/datasets/joint_training_dataset.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Torch dataset object for synthetically rendered
3
+ spatial data
4
+ """
5
+ import random
6
+
7
+ from typing import Tuple
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ import numpy as np
12
+ import os, glob
13
+
14
+ import src.utils as utils
15
+ from .noise import WhitePinkBrownAugmentation
16
+ import torchaudio
17
+ from torchmetrics.functional import signal_noise_ratio as snr
18
+ from torch.utils.data._utils.collate import default_collate
19
+
20
+ MAX_LEN = 50
21
+
22
+ def save_audio_file_torch(file_path, wavform, sample_rate = 16000, rescale = False):
23
+ if rescale:
24
+ wavform = wavform/torch.max(wavform)*0.9
25
+ torchaudio.save(file_path, wavform, sample_rate)
26
+
27
+ def perturb_amplitude_db(audio, db_change=10):
28
+ random_db = np.random.uniform(-db_change, db_change)
29
+ amplitude_factor = 10 ** (random_db / 20)
30
+ audio = audio * amplitude_factor
31
+ return audio
32
+
33
+
def scale_to_tgt_pwr(audio: torch.Tensor, timestamp, tgt_pwr_dB: float, EPS=1e-9):
    """Scale `audio` so the mean power over the `timestamp` segments equals `tgt_pwr_dB`.

    Args:
        audio: waveform tensor [..., T]. (The original annotation said
            np.ndarray, but the body uses torch ops like .size(-1), so a
            torch.Tensor is required — annotation fixed accordingly.)
        timestamp: iterable of (start, end) sample indices; out-of-range
            endpoints are clamped to the valid range.
        tgt_pwr_dB: desired average power of the selected segments, in dB.
        EPS: numerical floor inside log10 to avoid -inf on silent segments.

    Returns:
        The scaled audio (a new tensor; the input is not modified in place).
    """
    segments = []
    for start_time, end_time in timestamp:
        start_time = max(0, start_time)
        end_time = min(audio.size(-1), end_time)
        segments.append(audio[..., start_time:end_time])

    # Power statistics are computed over the concatenation of all segments.
    concatenated = torch.cat(segments, dim=-1)

    avg_pwr = torch.mean(concatenated**2)
    avg_pwr_dB = 10 * torch.log10(avg_pwr + EPS)
    scale = 10 ** ((tgt_pwr_dB - avg_pwr_dB) / 20)

    audio_scaled = scale * audio
    concatenated_scaled = scale * concatenated

    # Sanity check: achieved power must land within 0.1 dB of the target.
    scaled_pwr_dB = 10 * torch.log10(torch.mean(concatenated_scaled**2) + EPS)
    assert torch.abs(tgt_pwr_dB - scaled_pwr_dB) < 0.1

    return audio_scaled
59
+
60
+
def scale_utterance(audio, timestamp, rng, db_change=7):
    """Randomly rescale ~30% of the (start, end) segments by up to ±db_change dB, in place."""
    for seg_start, seg_end in timestamp:
        # Preserve the rng draw order: one selection draw per segment, then a
        # gain draw only when the segment is selected.
        if rng.uniform(0, 1) >= 0.3:
            continue
        gain_db = rng.uniform(-db_change, db_change)
        audio[..., seg_start:seg_end] *= 10 ** (gain_db / 20)

    return audio
69
+
70
+
def get_snr(target, mixture, EPS=1e-9):
    """
    Computes the average SNR across all channels.

    NOTE(review): torchmetrics' signal_noise_ratio signature is (preds, target);
    here `mixture` is passed as the prediction. `EPS` is accepted but unused —
    presumably kept for signature compatibility; confirm before removing.
    """
    return snr(mixture, target).mean()
76
+
77
+
def scale_noise_to_snr(target_speech: torch.Tensor, noise: torch.Tensor, target_snr: float):
    """Rescale a BINAURAL noise signal so the channel-averaged SNR equals `target_snr`.

    With a noise gain of k, each channel's SNR shifts by -20*log10(k), so the
    average SNR becomes avg_snr_initial - 20*log10(k); solving for k gives the
    gain computed below.
    """
    avg_snr_now = get_snr(target_speech, noise + target_speech)

    exponent = (avg_snr_now - target_snr) / 20
    gain = 10 ** exponent

    return gain * noise
92
+
93
+
def custom_collate_fn(batch):
    """Collate a batch of (inputs_dict, targets_dict) pairs.

    Every input key is stacked with torch's default_collate except
    'self_timestamp', which has a variable length per sample and is therefore
    kept as a plain Python list (one entry per batch element). Targets are
    fixed-length tensors and collated normally.
    """
    input_dicts = [pair[0] for pair in batch]
    target_dicts = [pair[1] for pair in batch]

    collated_inputs = {}
    for key in input_dicts[0]:
        values = [d[key] for d in input_dicts]
        # Variable-length timestamps cannot be stacked into a tensor.
        collated_inputs[key] = values if key == "self_timestamp" else default_collate(values)

    return collated_inputs, default_collate(target_dicts)
119
+
120
+
121
class Dataset(torch.utils.data.Dataset):
    """
    Dataset of mixed waveforms and their corresponding ground truth waveforms
    recorded at different microphone.

    Data format is a pair of Tensors containing mixed waveforms and
    ground truth waveforms respectively. The tensor's dimension is formatted
    as (n_microphone, duration).

    Each scenario is represented by a folder. Multiple datapoints are generated per
    scenario. This can be customized using the points_per_scenario parameter.
    """
    def __init__(self, input_dir, n_mics=1, sr=8000,
                 sig_len = 30, downsample = 1,
                 split = 'val', output_conversation = 0,
                 batch_size = 8,
                 clean_embed=False,
                 noise_dir = None,
                 random_audio_length=800,
                 required_first_speaker_as_self_speech=True,
                 spk_emb_exist=True,
                 amplitude_aug_range=0,
                 noise_amplitude_aug_range=7,
                 utter_db_aug=7,
                 input_mean="L",
                 min_snr=-10,
                 max_snr=10,
                 original_val=False,
                 apply_timestamp_aug=False,
                 snr_control=True
                 ):
        # input_dir: iterable of roots; each digit-prefixed subfolder is one scenario.
        # noise_dir: iterable of folders containing background-noise .wav files.
        # input_mean: "L"/"R" picks one channel of a stereo file; True averages channels.
        # utter_db_aug: per-utterance gain augmentation range in dB (used by scale_utterance).
        # snr_control: when True, rescale interference so mixture SNR is in [min_snr, max_snr].
        super().__init__()

        # Collect scenario folders that have the files required on disk.
        self.dirs = []
        self.spk_emb_exist=spk_emb_exist
        for _dir in input_dir:
            # Scenario folders are named with a leading digit.
            dir_list = sorted(list(Path(_dir).glob('[0-9]*')))
            for dest in dir_list:
                meta_path = os.path.join(dest, 'metadata.json')
                embed_path = os.path.join(dest, 'embed.pt')
                self_speech_path=os.path.join(dest, 'self_speech.wav')  # NOTE(review): computed but never checked

                # A scenario is usable if it has metadata (plus a speaker
                # embedding file when spk_emb_exist is set).
                if self.spk_emb_exist and os.path.exists(meta_path) and os.path.exists(embed_path):
                    self.dirs.append(dest)
                elif not self.spk_emb_exist and os.path.exists(meta_path):
                    self.dirs.append(dest)

        # Flat list of background-noise wav paths gathered from all noise folders.
        self.noise_dirs = []
        if noise_dir is not None:
            for sub_dir in noise_dir:
                noise_audio_list = glob.glob(os.path.join(sub_dir, '*.wav'))
                # NOTE(review): this tests `noise_dir` (always truthy inside this
                # branch); it presumably was meant to test `noise_audio_list`.
                if not noise_dir:
                    print("no noise file found")
                self.noise_dirs.extend(noise_audio_list)


        self.clean_embed = clean_embed
        self.n_mics = n_mics
        # Target signal length in samples after downsampling.
        self.sig_len = int(sig_len*sr/downsample)
        self.sr = sr
        self.downsample = downsample
        self.scales = [-3, 3]
        self.output_conversation = output_conversation
        self.apply_timestamp_aug = apply_timestamp_aug

        # Data augmentation
        ### calculate the stat
        self.batch_size = batch_size
        self.split = split
        # Effective dataset size (truncated to a multiple of batch_size).
        print(self.split, (len(self.dirs)//batch_size)*batch_size)

        # Length (in samples) of the random crop described by start/end_idx_list.
        self.random_audio_length=random_audio_length
        self.required_first_speaker_as_self_speech=required_first_speaker_as_self_speech

        self.amplitude_aug_range=amplitude_aug_range
        self.noise_amplitude_aug_range=noise_amplitude_aug_range

        # Power threshold in dB (not used in the visible code paths).
        self.pwr_thresh = -60
        self.min_snr=min_snr
        self.max_snr=max_snr
        self.utter_db_aug=utter_db_aug
        self.input_mean=input_mean
        self.original_val=original_val
        self.snr_control=snr_control


    def __len__(self) -> int:
        # Truncate to a whole number of batches.
        return (len(self.dirs)//self.batch_size)*self.batch_size


    def noise_sample(self, noise_file_list, audio_length, rng: np.random.RandomState):
        """Concatenate randomly chosen noise files (resampled to 16 kHz) until at
        least audio_length samples are accumulated, then crop to audio_length."""
        # NOTE: hardcoded. assume noise is 48k and target is 16k
        target_sr = 16000

        acc_len=0
        concatenated_audio = None
        while acc_len<=audio_length:
            noise_file=rng.choice(noise_file_list)
            info = torchaudio.info(noise_file)
            noise_sr=info.sample_rate

            noise_wav, _ = torchaudio.load(noise_file)
            # Channel selection mirrors the input_mean convention used for speech.
            if noise_wav.shape[0]>1 and self.input_mean=="L":
                noise_wav=noise_wav[0:1, ...]
            elif noise_wav.shape[0]>1 and self.input_mean=="R":
                noise_wav=noise_wav[1:2, ...]
            elif noise_wav.shape[0]>1 and self.input_mean==True:
                noise_wav=torch.mean(noise_wav, dim=0)
                noise_wav=noise_wav.unsqueeze(0)

            if noise_sr != target_sr:
                resampler = torchaudio.transforms.Resample(orig_freq=noise_sr, new_freq=target_sr)
                noise_wav = resampler(noise_wav)

            if concatenated_audio is None:
                concatenated_audio = noise_wav
            else:
                concatenated_audio = torch.cat((concatenated_audio, noise_wav), dim=1)

            acc_len=concatenated_audio.shape[-1]


        concatenated_audio=concatenated_audio[..., :audio_length]

        assert concatenated_audio.shape[1]==audio_length

        return concatenated_audio


    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            mixed_data - M x T
            target_voice_data - M x T
            window_idx_one_hot - 1-D
        """

        # Training samples get a fresh random seed each epoch; validation is
        # seeded by index so it is reproducible across runs.
        if self.split == 'train':
            seed = idx + np.random.randint(1000000)
        else:
            seed = idx
        rng = np.random.RandomState(seed)

        curr_dir = self.dirs[idx%len(self.dirs)]
        return self.get_mixture_and_gt(curr_dir, rng)

    def diffuse_speech_pattern(self, audio: torch.Tensor, timestamps: list, rng: np.random.RandomState, beta=8000):
        """Randomly redistribute the silent gaps between utterances while keeping
        utterance order and content; total audio length is unchanged."""
        # Gap lengths: before the first utterance, between utterances, after the last.
        zero_segments = np.array([timestamps[0][0]] + [timestamps[i+1][0] - timestamps[i][1] for i in range(len(timestamps) - 1)] + [audio.shape[-1] - timestamps[-1][1]])
        total_zeros = sum(zero_segments)

        # Add noise "diffusion" (a single Gaussian scalar offset applied to every gap).
        noise = rng.normal(loc=0, scale=beta)
        zero_segments = zero_segments + noise

        # Ensure all elements are still positive
        zero_segments[zero_segments <= 0] = 1

        # Normalize so that sum is 1
        zero_segments = zero_segments / zero_segments.sum()
        zero_segments = zero_segments * total_zeros

        # Floor indices so that we don't exceed audio size
        zero_segments = np.floor(zero_segments).astype(np.int32)

        assert zero_segments.sum() <= total_zeros

        # Fill in time stamps: lay each utterance down after its (new) leading gap.
        new_audio = torch.zeros_like(audio)
        start_index = 0
        for z, (s, e) in zip(zero_segments[:-1], timestamps):
            start_index += z
            new_audio[..., start_index:start_index+(e-s)] = audio[..., s:e]
            start_index += (e - s)

        return new_audio


    def process_audio(self, audio, timestamp, rng, utter_db_aug, tgt_pwr_dB):
        """Optionally diffuse the speech pattern, then normalize power and apply
        per-utterance gain augmentation (scale_to_tgt_pwr / scale_utterance are
        defined elsewhere in this file)."""
        # NOTE(review): the timestamp augmentation runs before the empty-timestamp
        # check; diffuse_speech_pattern would index timestamps[0] on an empty list.
        if self.apply_timestamp_aug:
            audio = self.diffuse_speech_pattern(audio, timestamp, rng, beta=16000)

        if timestamp==[]:
            return audio
        else:
            audio = scale_to_tgt_pwr(audio, timestamp, tgt_pwr_dB)
            audio=scale_utterance(audio, timestamp, rng, utter_db_aug)
            return audio


    def get_mixture_and_gt(self, curr_dir, rng):
        """Build one training example from a scenario folder: self speech +
        conversation partners form the target; interferers (+ optional recorded
        noise) form the residual; returns (inputs dict, targets dict)."""
        metadata2 = utils.read_json(os.path.join(curr_dir, 'metadata.json'))


        # process self speech
        self_speech = utils.read_audio_file_torch(os.path.join(curr_dir, 'self_speech.wav'), 1, self.input_mean)
        self_speech_original=None
        if os.path.exists(os.path.join(curr_dir, 'self_speech_original.wav')):
            self_speech_original=utils.read_audio_file_torch(os.path.join(curr_dir, 'self_speech_original.wav'), 1, self.input_mean)

        # The wearer (self) is always the first entry of target_dialogue.
        self_timestamp=metadata2['target_dialogue'][0]['timestamp']

        if self_speech_original is not None:
            # Scale both versions jointly so they receive identical per-utterance gains.
            list_of_self=[self_speech, self_speech_original]
            concat_self_speech=torch.cat(list_of_self, dim=0)
            utterance_adj_concat_self=scale_utterance(concat_self_speech, self_timestamp, rng, self.utter_db_aug)
            self_speech=utterance_adj_concat_self[0:1, ...]
            self_speech_original=utterance_adj_concat_self[1:2, ...]
        else:
            self_speech=scale_utterance(self_speech, self_timestamp, rng, self.utter_db_aug)

        # process interference speech
        # (The 'intereference' spelling matches the on-disk file names — do not "fix".)
        if os.path.exists(os.path.join(curr_dir, f'intereference.wav')):
            interfere = utils.read_audio_file_torch(os.path.join(curr_dir, f'intereference.wav'), 1, self.input_mean)
            scale = 0.8  # NOTE(review): `scale` is never used afterwards
        else:
            interfers = metadata2["interference"]
            interfere = torch.zeros_like(self_speech)
            if os.path.exists(os.path.join(curr_dir, f'intereference0.wav')):
                for i in range(0, len(interfers)):
                    current_inter=utils.read_audio_file_torch(os.path.join(curr_dir, f'intereference{i}.wav'), 1, self.input_mean)
                    inter_timestamp=metadata2['interference'][i]['timestamp']

                    current_inter=scale_utterance(current_inter, inter_timestamp, rng, self.utter_db_aug)
                    interfere += current_inter
            elif os.path.exists(os.path.join(curr_dir, f'interference0.wav')):
                # Fallback for folders using the correctly spelled file names.
                for i in range(0, len(interfers)):
                    current_inter= utils.read_audio_file_torch(os.path.join(curr_dir, f'interference{i}.wav'), 1, self.input_mean)
                    inter_timestamp=metadata2['interference'][i]['timestamp']

                    current_inter=scale_utterance(current_inter, inter_timestamp, rng, self.utter_db_aug)
                    interfere += current_inter
            scale = 1

        # process other speech (the wearer's conversation partners)
        other_speech = torch.zeros_like(self_speech)
        if self.output_conversation:
            diags = metadata2["target_dialogue"]
            for i in range(len(diags) - 1):
                if os.path.exists(os.path.join(curr_dir, f'target_speech{i}.wav')):
                    wav = utils.read_audio_file_torch(os.path.join(curr_dir, f'target_speech{i}.wav'), 1, self.input_mean)
                    other_timestamp=metadata2['target_dialogue'][i+1]['timestamp']
                    wav=scale_utterance(wav, other_timestamp, rng, self.utter_db_aug)
                    other_speech += wav

                elif os.path.exists(os.path.join(curr_dir, f'other_speech{i}.wav')):
                    wav = utils.read_audio_file_torch(os.path.join(curr_dir, f'other_speech{i}.wav'), 1, self.input_mean)
                    other_timestamp=metadata2['target_dialogue'][i+1]['timestamp']
                    wav=scale_utterance(wav, other_timestamp, rng, self.utter_db_aug)
                    other_speech += wav
                else:
                    raise Exception("no audio file to load")

        # add noise, e.g. WHAM
        # NOTE(review): random.random() bypasses the per-sample seeded rng, so
        # this branch (and the colored-noise one below) is not reproducible for
        # a fixed validation seed.
        if self.noise_dirs!=[] and random.random() < 0.3:
            audio_length=interfere.shape[1]
            noise=self.noise_sample(self.noise_dirs, audio_length, rng)
            wham_scale = rng.uniform(0, 1)
            interfere += noise*wham_scale


        # Ground truth is self speech (clean/original when available) + partners.
        if self_speech_original is not None:
            gt = self_speech_original + other_speech
        else:
            gt = self_speech + other_speech

        mixture=gt+interfere

        # Rescale the residual (mixture - gt) to hit a random target SNR.
        if self.snr_control==True:
            tgt_snr = rng.uniform(self.min_snr, self.max_snr)
            noise = scale_noise_to_snr(gt, mixture - gt, tgt_snr)

            mixture = noise + gt

        noise_augmentor = WhitePinkBrownAugmentation(
            max_white_level=1e-2, # Adjust as needed
            max_pink_level=5e-2, # Adjust as needed
            max_brown_level=5e-2 # Adjust as needed
        )

        # Colored-noise augmentation on the mixture only (gt is untouched).
        if self.split=="train" and random.random() < 0.3:
            mixture, gt = noise_augmentor(mixture, gt, rng)


        # Speaker embedding of the wearer ("reverb_path" name is historical).
        reverb_path = os.path.join(curr_dir, f'embed.pt')

        if self.spk_emb_exist:
            embed = torch.load(reverb_path, weights_only=False)
            embed = torch.from_numpy(embed)
        else:
            embed=torch.zeros(256)

        self.output_conversation  # NOTE(review): no-op statement

        input_length=self_speech.shape[1]

        # Random crop window forwarded to the model via start/end_idx_list.
        start_idx=rng.randint(input_length-self.random_audio_length)
        end_idx=start_idx+self.random_audio_length

        # ====peak normalization======
        # Applied jointly so relative levels between mixture/gt/self are preserved.
        peak = torch.abs(mixture).max()
        if peak > 1:
            mixture /= peak
            gt /= peak
            self_speech /= peak


        inputs = {
            'mixture': mixture.float(),
            'embed': embed.float(),
            'self_speech': self_speech[0:1, :].float(),
            'start_idx_list': start_idx,
            'end_idx_list': end_idx
        }

        targets = {
            'target': gt[0:1, :].float()
        }

        return inputs, targets
src/datasets/noise.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
def generate_white_noise(noise_shape, max_level, rng: np.random.RandomState):
    """Draw white (flat-spectrum) Gaussian noise scaled by a random level in [0, max_level)."""
    # Level first, then samples — the rng consumption order matters for reproducibility.
    level = max_level * rng.rand()
    samples = rng.normal(0, 1, size=noise_shape)
    return level * torch.from_numpy(samples).float()
13
+
14
def generate_pink_noise(noise_shape, max_level, rng: np.random.RandomState):
    """
    Draw pink (1/f) noise scaled by a random level in [0, max_level).

    Fix: the previous version called powerlaw_psd_gaussian with a hard-coded
    random_state=0, so the underlying pink-noise waveform was byte-identical on
    every call and ignored the caller's rng. The rng is now routed through, so
    each call produces a fresh, reproducible waveform.
    """
    # Random amplitude in [0, max_level).
    pink_noise_level = max_level * rng.rand()

    # Generate pink noise (exponent beta = 1), driven by the caller's rng.
    pink_noise = powerlaw_psd_gaussian(1, noise_shape, random_state=rng)
    return pink_noise_level * torch.from_numpy(pink_noise).float()
24
+
25
def generate_brown_noise(noise_shape, max_level, rng: np.random.RandomState):
    """
    Draw brown (1/f^2) noise scaled by a random level in [0, max_level).

    Fix: the previous version called powerlaw_psd_gaussian with a hard-coded
    random_state=0, so the underlying brown-noise waveform was byte-identical
    on every call and ignored the caller's rng. The rng is now routed through.
    """
    # Random amplitude in [0, max_level).
    brown_noise_level = max_level * rng.rand()

    # Generate brown noise (exponent beta = 2), driven by the caller's rng.
    brown_noise = powerlaw_psd_gaussian(2, noise_shape, random_state=rng)
    return brown_noise_level * torch.from_numpy(brown_noise).float()
35
+
36
+ """Generate colored noise."""
37
+
38
+ from numpy import sqrt, newaxis, integer
39
+ from numpy.fft import irfft, rfftfreq
40
+ from numpy.random import default_rng, Generator, RandomState
41
+ from numpy import sum as npsum
42
+
43
+
44
def powerlaw_psd_gaussian(exponent, size, fmin=0, random_state=None):
    """Gaussian (1/f)**beta noise.

    Based on the algorithm in:
    Timmer, J. and Koenig, M.:
    On generating power law noise.
    Astron. Astrophys. 300, 707-710 (1995)

    Normalised to unit variance

    Parameters:
    -----------

    exponent : float
        The power-spectrum of the generated noise is proportional to

        S(f) = (1 / f)**beta
        flicker / pink noise:   exponent beta = 1
        brown noise:            exponent beta = 2

        Furthermore, the autocorrelation decays proportional to lag**-gamma
        with gamma = 1 - beta for 0 < beta < 1.
        There may be finite-size issues for beta close to one.

    shape : int or iterable
        The output has the given shape, and the desired power spectrum in
        the last coordinate. That is, the last dimension is taken as time,
        and all other components are independent.

    fmin : float, optional
        Low-frequency cutoff.
        Default: 0 corresponds to original paper.

        The power-spectrum below fmin is flat. fmin is defined relative
        to a unit sampling rate (see numpy's rfftfreq). For convenience,
        the passed value is mapped to max(fmin, 1/samples) internally
        since 1/samples is the lowest possible finite frequency in the
        sample. The largest possible value is fmin = 0.5, the Nyquist
        frequency. The output for this value is white noise.

    random_state : int, numpy.integer, numpy.random.Generator, numpy.random.RandomState,
                   optional
        Optionally sets the state of NumPy's underlying random number generator.
        Integer-compatible values or None are passed to np.random.default_rng.
        np.random.RandomState or np.random.Generator are used directly.
        Default: None.

    Returns
    -------
    out : array
        The samples.


    Examples:
    ---------

    # generate 1/f noise == pink noise == flicker noise
    >>> import colorednoise as cn
    >>> y = cn.powerlaw_psd_gaussian(1, 5)
    """

    # Make sure size is a list so we can iterate it and assign to it.
    try:
        size = list(size)
    except TypeError:
        size = [size]

    # The number of samples in each time series
    samples = size[-1]

    # Calculate Frequencies (we asume a sample rate of one)
    # Use fft functions for real output (-> hermitian spectrum)
    f = rfftfreq(samples)

    # Validate / normalise fmin
    if 0 <= fmin <= 0.5:
        fmin = max(fmin, 1./samples) # Low frequency cutoff
    else:
        raise ValueError("fmin must be chosen between 0 and 0.5.")

    # Build scaling factors for all frequencies
    # NOTE: s_scale aliases f here; the cutoff assignment below mutates f in
    # place, but f is not used again afterwards.
    s_scale = f
    ix = npsum(s_scale < fmin) # Index of the cutoff
    if ix and ix < len(s_scale):
        # Flatten the spectrum below the cutoff frequency.
        s_scale[:ix] = s_scale[ix]
    s_scale = s_scale**(-exponent/2.)

    # Calculate theoretical output standard deviation from scaling
    # (DC component excluded; the Nyquist bin is halved for even lengths).
    w = s_scale[1:].copy()
    w[-1] *= (1 + (samples % 2)) / 2. # correct f = +-0.5
    sigma = 2 * sqrt(npsum(w**2)) / samples

    # Adjust size to generate one Fourier component per frequency
    size[-1] = len(f)

    # Add empty dimension(s) to broadcast s_scale along last
    # dimension of generated random power + phase (below)
    dims_to_add = len(size) - 1
    s_scale = s_scale[(newaxis,) * dims_to_add + (Ellipsis,)]

    # prepare random number generator (helper defined below in this module)
    normal_dist = _get_normal_distribution(random_state)

    # Generate scaled random power + phase
    sr = normal_dist(scale=s_scale, size=size)
    si = normal_dist(scale=s_scale, size=size)

    # If the signal length is even, frequencies +/- 0.5 are equal
    # so the coefficient must be real.
    if not (samples % 2):
        si[..., -1] = 0
        sr[..., -1] *= sqrt(2)    # Fix magnitude

    # Regardless of signal length, the DC component must be real
    si[..., 0] = 0
    sr[..., 0] *= sqrt(2)    # Fix magnitude

    # Combine power + corrected phase to Fourier components
    s = sr + 1J * si

    # Transform to real time series & scale to unit variance
    y = irfft(s, n=samples, axis=-1) / sigma

    return y
168
+
169
+
170
+ def _get_normal_distribution(random_state):
171
+ normal_dist = None
172
+ if isinstance(random_state, (integer, int)) or random_state is None:
173
+ random_state = default_rng(random_state)
174
+ normal_dist = random_state.normal
175
+ elif isinstance(random_state, (Generator, RandomState)):
176
+ normal_dist = random_state.normal
177
+ else:
178
+ raise ValueError(
179
+ "random_state must be one of integer, numpy.random.Generator, or None"
180
+ "numpy.random.Randomstate"
181
+ )
182
+ return normal_dist
183
+
184
+
185
class WhitePinkBrownAugmentation:
    """
    Additive colored-noise augmentation.

    Adds a random mixture of white, pink (1/f) and brown (1/f^2) noise to the
    input audio; the ground-truth audio is passed through unchanged. Each
    color's amplitude is drawn independently in [0, max_*_level) by the
    respective generate_*_noise helper.
    """

    def __init__(self, max_white_level=1e-3, max_pink_level=5e-3, max_brown_level=5e-3):
        # Upper bounds for the randomly drawn per-color noise amplitudes.
        self.max_white_level = max_white_level
        self.max_pink_level = max_pink_level
        self.max_brown_level = max_brown_level

    def __call__(self, audio_data, gt_audio, rng: np.random.RandomState):
        # Draw the three noise colors in a fixed order (white, pink, brown)
        # so rng consumption stays deterministic for a given seed.
        shape = audio_data.shape
        colored = (
            generate_white_noise(shape, self.max_white_level, rng)
            + generate_pink_noise(shape, self.max_pink_level, rng)
            + generate_brown_noise(shape, self.max_brown_level, rng)
        )
        return audio_data + colored, gt_audio
src/hl_module/joint_train_hl_module_new.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import wandb
7
+ import torch
8
+ from numpy import mean
9
+ from src.metrics.metrics import Metrics
10
+ import src.utils as utils
11
+ import numpy as np
12
+
13
+
14
class FakeModel(nn.Module):
    """
    Thin nn.Module wrapper that stores the wrapped model under a ``model``
    attribute, so its state_dict keys carry a ``model.`` prefix. Used when
    loading checkpoints saved with that key layout (see PLModule.__init__).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
18
+
19
+
20
+ class PLModule(object):
21
    def __init__(
        self,
        model,
        model_params,
        sr,
        optimizer,
        optimizer_params,
        scheduler=None,
        scheduler_params=None,
        loss=None,
        loss_params=None,
        metrics=[],
        slow_model_ckpt=None,
        prev_ckpt=None,
        grad_clip=None,
        use_dp=True,
        val_log_interval=10,  # Unused, only kept for compatibility TODO: Remove
        samples_per_speaker_number=3,
        freeze_model1=False,
    ):
        # model / optimizer / scheduler / loss are dotted import paths resolved
        # via utils.import_attr; the *_params dicts are their constructor kwargs.
        self.model = utils.import_attr(model)(**model_params)

        # Wrap in DataParallel for multi-GPU data parallelism.
        self.use_dp = use_dp
        if use_dp:
            self.model = nn.DataParallel(self.model)

        # Audio sample rate.
        self.sr = sr

        # Log a val sample every this many intervals
        # self.val_log_interval = val_log_interval
        self.samples_per_speaker_number = samples_per_speaker_number

        # Initialize metrics
        self.metrics = [Metrics(metric) for metric in metrics]

        # Metric values, keyed by epoch -> metric name (see log_metric).
        self.metric_values = {}

        # Dataset statistics accumulated via log_statistic.
        self.statistics = {}

        # Assign metric to monitor, and how to judge different models based on it
        # i.e. How do we define the best model (Here, we minimize val loss)
        self.monitor = "val/loss"
        self.monitor_mode = "min"

        # Mode, either train or val
        self.mode = None

        self.val_samples = {}
        self.train_samples = {}

        self.input_snr_calculated = False
        self.input_snr = []
        self.snr_metric = Metrics("snr")

        # Initialize loss function
        self.loss_fn = utils.import_attr(loss)(**loss_params)

        # Initialize weights if a checkpoint is provided.
        # prev_ckpt is for the checkpoint of the complete joint model
        # (fast+slow) you want to train from.
        if prev_ckpt is not None:
            if prev_ckpt.endswith(".ckpt"):
                # Lightning-style checkpoint: keys are prefixed "model.", so
                # load through a FakeModel wrapper to match the key layout.
                print("load prev model", prev_ckpt)
                state = torch.load(prev_ckpt)["state_dict"]
                # print(state.keys())
                print(state["current_epoch"])
                if self.use_dp:
                    _model = self.model.module
                else:
                    _model = self.model

                mdl = FakeModel(_model)
                mdl.load_state_dict(state)
                self.model = nn.DataParallel(mdl.model)
            else:
                # Plain dump_state-style checkpoint: weights under "model".
                print("load prev model", prev_ckpt)

                state = torch.load(prev_ckpt)
                print(state["current_epoch"])
                state = state["model"]
                if self.use_dp:
                    self.model.module.load_state_dict(state)
                else:
                    self.model.load_state_dict(state)

        # slow_model_ckpt holds the slow model's initial weights; only the
        # "tce_model." subtree is extracted and loaded into model1.
        elif slow_model_ckpt is not None:
            print(f"Loading model 1 weights from checkpoint: {slow_model_ckpt}")
            model1_ckpt = torch.load(slow_model_ckpt)
            print("current epoch is {}".format(model1_ckpt["current_epoch"]))

            # Strip the "tce_model." prefix and keep only those entries.
            model1_state_dict = {
                key.replace("tce_model.", ""): value
                for key, value in model1_ckpt["model"].items()
                if key.startswith("tce_model.")
            }

            if self.use_dp:
                self.model.module.model1.load_state_dict(model1_state_dict, strict=False)
            else:
                self.model.model1.load_state_dict(model1_state_dict, strict=False)

        else:
            print("Loading model from scratch, no slow model init ckpt or joint model init ckpt")

        # Whether to freeze the slow model (model1) during training. When
        # frozen, the optimizer only receives trainable parameters.
        self.freeze = freeze_model1
        if freeze_model1:
            self.freeze_model1()
            params_to_optimize = filter(lambda p: p.requires_grad, self.model.parameters())
            # Initialize optimizer
            self.optimizer = utils.import_attr(optimizer)(params_to_optimize, **optimizer_params)
            self.optim_name = optimizer
            self.opt_params = optimizer_params
        else:
            # Initialize optimizer
            self.optimizer = utils.import_attr(optimizer)(self.model.parameters(), **optimizer_params)
            self.optim_name = optimizer
            self.opt_params = optimizer_params

        # Grad clip (norm threshold; None disables clipping — loudly).
        self.grad_clip = grad_clip

        if self.grad_clip is not None:
            print(f"USING GRAD CLIP: {self.grad_clip}")
        else:
            print("ERROR! NOT USING GRAD CLIP" * 100)

        # Initialize scheduler (init_scheduler is defined elsewhere on this class).
        self.scheduler = self.init_scheduler(scheduler, scheduler_params)
        self.scheduler_name = scheduler
        self.scheduler_params = scheduler_params

        self.epoch = 0
157
+
158
+ def freeze_model1(self):
159
+ """Freezes the weights of model1."""
160
+ print("Freezing model1 weights")
161
+ model1 = self.model.module.model1 if self.use_dp else self.model.model1
162
+ for param in model1.parameters():
163
+ param.requires_grad = False
164
+ print("Model1 weights frozen.")
165
+
166
+ def load_state(self, path, map_location=None):
167
+ state = torch.load(path, map_location=map_location)
168
+
169
+ if self.use_dp:
170
+ self.model.module.load_state_dict(state["model"])
171
+ else:
172
+ self.model.load_state_dict(state["model"])
173
+
174
+ # Re-initialize optimizer
175
+ if not self.freeze:
176
+ self.optimizer = utils.import_attr(self.optim_name)(self.model.parameters(), **self.opt_params)
177
+ else:
178
+ params_to_optimize = filter(lambda p: p.requires_grad, self.model.parameters())
179
+ self.optimizer = utils.import_attr(self.optim_name)(params_to_optimize, **self.opt_params)
180
+
181
+ # Re-initialize scheduler (Order might be important?)
182
+ if self.scheduler is not None:
183
+ self.scheduler = self.init_scheduler(self.scheduler_name, self.scheduler_params)
184
+
185
+ self.optimizer.load_state_dict(state["optimizer"])
186
+
187
+ if self.scheduler is not None:
188
+ self.scheduler.load_state_dict(state["scheduler"])
189
+
190
+ self.epoch = state["current_epoch"]
191
+ print("Load model from epoch", self.epoch)
192
+ self.metric_values = state["metric_values"]
193
+
194
+ if "statistics" in self.statistics:
195
+ self.statistics = state["statistics"]
196
+
197
+ def dump_state(self, path):
198
+ if self.use_dp:
199
+ _model = self.model.module
200
+ else:
201
+ _model = self.model
202
+
203
+ state = dict(
204
+ model=_model.state_dict(),
205
+ optimizer=self.optimizer.state_dict(),
206
+ current_epoch=self.epoch,
207
+ metric_values=self.metric_values,
208
+ statistics=self.statistics,
209
+ )
210
+
211
+ if self.scheduler is not None:
212
+ state["scheduler"] = self.scheduler.state_dict()
213
+ print("save to " + path)
214
+ torch.save(state, path)
215
+
216
+ def get_current_lr(self):
217
+ for param_group in self.optimizer.param_groups:
218
+ return param_group["lr"]
219
+
220
+ def on_epoch_start(self):
221
+ print()
222
+ print("=" * 25, "STARTING EPOCH", self.epoch, "=" * 25)
223
+ print()
224
+
225
+ def get_avg_metric_at_epoch(self, metric, epoch=None):
226
+ if epoch is None:
227
+ epoch = self.epoch
228
+
229
+ return self.metric_values[epoch][metric]["epoch"] / self.metric_values[epoch][metric]["num_elements"]
230
+
231
    def on_epoch_end(self, best_path, wandb_run):
        """End-of-epoch bookkeeping: save the checkpoint if it is the best so
        far (per self.monitor/monitor_mode), push metrics and statistics to
        wandb, step the scheduler, and advance the epoch counter."""
        assert self.epoch + 1 == len(
            self.metric_values
        ), "Current epoch must be equal to length of metrics (0-indexed)"

        monitor_metric_last = self.get_avg_metric_at_epoch(self.monitor)

        # Go over all epochs
        save = True
        for epoch in range(len(self.metric_values) - 1):
            monitor_metric_at_epoch = self.get_avg_metric_at_epoch(self.monitor, epoch)

            if self.monitor_mode == "max":
                # If there is any model with monitor larger than current, then
                # this is not the best model
                if monitor_metric_last < monitor_metric_at_epoch:
                    save = False
                    break

            if self.monitor_mode == "min":
                # If there is any model with monitor smaller than current, then
                # this is not the best model
                if monitor_metric_last > monitor_metric_at_epoch:
                    save = False
                    break

        # If this is best, save it
        if save:
            print("Current checkpoint is the best! Saving it...")
            self.dump_state(best_path)

        val_loss = self.get_avg_metric_at_epoch("val/loss")
        val_snr_i = self.get_avg_metric_at_epoch("val/snr_i")
        val_si_snr_i = self.get_avg_metric_at_epoch("val/si_snr_i")

        print(f"Val loss: {val_loss:.02f}")
        print(f"Val SNRi: {val_snr_i:.02f}dB")
        print(f"Val SI-SDRi: {val_si_snr_i:.02f}dB")

        # Log stuff on wandb (commit=False batches everything into one step).
        wandb_run.log({"lr-Adam": self.get_current_lr()}, commit=False, step=self.epoch + 1)

        for metric in self.metric_values[self.epoch]:
            wandb_run.log({metric: self.get_avg_metric_at_epoch(metric)}, commit=False, step=self.epoch + 1)

        # Statistics are logged once each, with their configured reduction.
        for statistic in self.statistics:
            if not self.statistics[statistic]["logged"]:
                data = self.statistics[statistic]["data"]
                reduction = self.statistics[statistic]["reduction"]
                if reduction == "mean":
                    val = mean(data)
                elif reduction == "sum":
                    val = sum(data)
                elif reduction == "histogram":
                    data = [[d] for d in data]
                    table = wandb.Table(data=data, columns=[statistic])
                    val = wandb.plot.histogram(table, statistic, title=statistic)
                else:
                    assert 0, f"Unknown reduction {reduction}."
                wandb_run.log({statistic: val}, commit=False)
                self.statistics[statistic]["logged"] = True

        # Final commit=True flushes the whole step to wandb.
        wandb_run.log({"epoch": self.epoch}, commit=True, step=self.epoch + 1)

        if self.scheduler is not None:
            if type(self.scheduler) == torch.optim.lr_scheduler.ReduceLROnPlateau:
                # Plateau scheduler needs the monitored metric value.
                self.scheduler.step(monitor_metric_last)
            else:
                self.scheduler.step()

        self.epoch += 1
303
+
304
+ def log_statistic(self, name, value, reduction="mean"):
305
+ if name not in self.statistics:
306
+ self.statistics[name] = dict(logged=False, data=[], reduction=reduction)
307
+
308
+ self.statistics[name]["data"].append(value)
309
+
310
+ def log_metric(self, name, value, batch_size=1, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True):
311
+ """
312
+ Logs a metric
313
+ value must be the AVERAGE value across the batch
314
+ Must provide batch size for accurate average computation
315
+ """
316
+
317
+ epoch_str = self.epoch
318
+ if epoch_str not in self.metric_values:
319
+ self.metric_values[epoch_str] = {}
320
+
321
+ if name not in self.metric_values[epoch_str]:
322
+ self.metric_values[epoch_str][name] = dict(step=None, epoch=None)
323
+
324
+ if type(value) == torch.Tensor:
325
+ value = value.item()
326
+
327
+ if on_step:
328
+ if self.metric_values[epoch_str][name]["step"] is None:
329
+ self.metric_values[epoch_str][name]["step"] = []
330
+
331
+ self.metric_values[epoch_str][name]["step"].append(value)
332
+
333
+ if on_epoch:
334
+ if self.metric_values[epoch_str][name]["epoch"] is None:
335
+ self.metric_values[epoch_str][name]["epoch"] = 0
336
+ self.metric_values[epoch_str][name]["num_elements"] = 0
337
+
338
+ self.metric_values[epoch_str][name]["epoch"] += value * batch_size
339
+ self.metric_values[epoch_str][name]["num_elements"] += batch_size
340
+
341
+ def val_naive(self, batch, batch_idx):
342
+ inputs, targets = batch
343
+ a = torch.cuda.memory_allocated(inputs["mixture"].device)
344
+ outputs = self.model(inputs)
345
+ b = torch.cuda.memory_allocated(inputs["mixture"].device)
346
+ print("Infer consume M", (b - a) / 1e6)
347
+
348
+ return outputs
349
+
350
    def train_naive(self, batch, batch_idx):
        """One training step (forward, loss, backward, optimizer step) with GPU
        memory reporting at each stage. reset_grad() and backprop() are defined
        elsewhere on this class."""
        self.reset_grad()
        inputs, targets = batch
        a = torch.cuda.memory_allocated(inputs["mixture"].device)
        # print("a", a/1e9 )
        outputs = self.model(inputs)

        est = outputs["output"]
        gt = targets["target"]

        # Compute loss
        loss = self.loss_fn(est=est, gt=gt).mean()
        b = torch.cuda.memory_allocated(inputs["mixture"].device)

        # NOTE(review): retain_graph=True keeps the autograd graph alive after
        # backward, which inflates memory; confirm it is actually required here.
        loss.backward(retain_graph=True)
        c = torch.cuda.memory_allocated(inputs["mixture"].device)

        self.backprop()
        d = torch.cuda.memory_allocated(inputs["mixture"].device)

        # Memory deltas in GB: forward, forward+backward, optimizer step, baseline.
        print("Training consume G", (b - a) / 1e9, (c - a) / 1e9, (d - c) / 1e9, a / 1e9)
        return outputs
372
+
373
def silence_audio(self, input, timestamp):
    """Return a copy of `input` with each [start, end) span zeroed.

    NOTE(review): the slice zeroes along the *first* dimension — assumes 1-D
    audio (samples-first); confirm against callers if input is multi-channel.
    """
    muted = input.clone()
    for begin, stop in timestamp:
        muted[begin:stop] = 0.0
    return muted
379
+
380
def _step(self, batch, batch_idx, step="train"):
    """Shared train/val step: forward, loss, metric logging.

    Returns (loss, sample) where `sample` is a dict of mixture/output/target
    tensors suitable for wandb audio logging. `step` selects the metric
    prefix and whether per-step loss logging is enabled.
    """
    inputs, targets = batch
    batch_size = inputs["mixture"].shape[0]

    # Per-batch crop range; only the first sample's indices are used,
    # so all samples in the batch are assumed to share the same range.
    start_idx = inputs["start_idx_list"][0].item()
    end_idx = inputs["end_idx_list"][0].item()
    inputs["start_idx"] = start_idx
    inputs["end_idx"] = end_idx

    outputs = self.model(inputs)
    est = outputs["output"].clone()

    if "audio_range" in outputs:
        # Model reports a valid [start, end) sample range per batch item;
        # slice gt/mix/self-speech accordingly so the loss/metrics only
        # cover the region the model actually produced.
        audio_range = outputs["audio_range"]
        start_indices = audio_range[:, 0]  # Shape: [batch]
        end_indices = audio_range[:, 1]
        sliced_gt = []
        sliced_mix = []
        sliced_self = []
        # masked_est_list=[]

        gt_clone = targets["target"].clone()
        mix_clone = inputs["mixture"][:, 0:1].clone()  # first (reference) channel only
        full_self_speech_clone = inputs["self_speech"].clone()

        for index in range(est.size(0)):
            start = start_indices[index].item()
            end = end_indices[index].item()

            sliced_gt.append(gt_clone[index, :, start:end])
            sliced_mix.append(mix_clone[index, :, start:end])
            sliced_self.append(full_self_speech_clone[index, :, start:end])

        # Stack the sliced audio to form the final tensor
        gt = torch.stack(sliced_gt, dim=0)
        mix = torch.stack(sliced_mix, dim=0)
        self_speech_final = torch.stack(sliced_self, dim=0)

    else:
        mix = inputs["mixture"][:, 0:1].clone()
        gt = targets["target"].clone()
        # NOTE(review): this branch reads self_speech from `targets` while the
        # branch above reads it from `inputs` — confirm both dicts carry the
        # same tensor, otherwise one of these is a bug.
        self_speech_final = targets["self_speech"].clone()

    # Compute loss
    loss = self.loss_fn(est=est, gt=gt).mean()

    est_detached = est.detach().clone()

    with torch.no_grad():
        # Log loss
        self.log_metric(
            f"{step}/loss",
            loss.item(),
            batch_size=batch_size,
            on_step=(step == "train"),
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

        # Log metrics
        for metric in self.metrics:
            # PESQ/STOI are expensive; skip them during training.
            if step == "train" and (metric.name == "PESQ" or metric.name == "STOI"):
                continue
            metric_val = metric(est=est_detached, gt=gt, mix=mix, self_speech=self_speech_final)
            for i in range(batch_size):
                # if gt is all zero, cannot compute metric
                if torch.all(gt[i] == 0):
                    # print(f"Skipping sample {i} in batch because gt is all zeros.")
                    continue
                val = metric_val[i].item()
                self.log_metric(
                    f"{step}/{metric.name}",
                    val,
                    batch_size=1,
                    on_step=False,
                    on_epoch=True,
                    prog_bar=True,
                    sync_dist=True,
                )

    # Create collection of things to show in a sample on wandb
    sample = {
        "mixture": mix,
        "output": est_detached,
        "target": gt,
    }

    return loss, sample
469
+
470
def train(self):
    """Put the wrapper and its model into training mode."""
    self.mode = "train"
    self.model.train()
473
+
474
def eval(self):
    """Put the wrapper and its model into evaluation mode."""
    self.mode = "val"
    self.model.eval()
477
+
478
def training_step(self, batch, batch_idx):
    """Run one training step; returns (loss, batch_size)."""
    loss, sample = self._step(batch, batch_idx, step="train")
    return loss, sample["target"].shape[0]
484
+
485
def validation_step(self, batch, batch_idx):
    """Run one validation step; returns (loss, batch_size)."""
    loss, sample = self._step(batch, batch_idx, step="val")
    return loss, sample["target"].shape[0]
491
+
492
def reset_grad(self):
    """Clear accumulated gradients before the next backward pass."""
    self.optimizer.zero_grad()
494
+
495
def backprop(self):
    """Apply an optimizer step, optionally clipping the global grad norm first."""
    if self.grad_clip is not None:
        # Global-norm clipping over all model parameters.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
    self.optimizer.step()
504
+
505
def configure_optimizers(self):
    """Return the optimizer (and scheduler config) in Lightning-style format.

    No scheduler -> bare optimizer. ReduceLROnPlateau -> a config dict that
    names the metric to monitor; any other scheduler is passed through as-is.
    """
    if self.scheduler is None:
        return self.optimizer

    if type(self.scheduler) == torch.optim.lr_scheduler.ReduceLROnPlateau:
        # Plateau scheduling needs to know which logged metric to watch.
        scheduler_cfg = dict(
            scheduler=self.scheduler,
            interval="epoch",
            frequency=1,
            monitor=self.monitor,
            strict=False,
        )
    else:
        scheduler_cfg = self.scheduler
    return [self.optimizer], [scheduler_cfg]
521
+
522
def init_scheduler(self, scheduler, scheduler_params):
    """Instantiate the LR scheduler named by `scheduler` (None passes through).

    "sequential" builds a SequentialLR from a list of per-phase specs, where
    each spec's `epochs` extends the cumulative milestone list; the final
    boundary is implied by the total epoch count and dropped.
    """
    if scheduler is None:
        return scheduler

    if scheduler != "sequential":
        return utils.import_attr(scheduler)(self.optimizer, **scheduler_params)

    stages = []
    milestones = []
    total_epochs = 0
    for spec in scheduler_params:
        stages.append(utils.import_attr(spec["name"])(self.optimizer, **spec["params"]))
        total_epochs += spec["epochs"]
        milestones.append(total_epochs)

    # Remove last milestone as it is implied by num epochs.
    milestones.pop()

    return torch.optim.lr_scheduler.SequentialLR(self.optimizer, stages, milestones)
src/losses/.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/losses/SNRLP.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ from src.losses.SNRLosses import SNRLosses
6
+ from src.losses.LogPowerLoss import LogPowerLoss
7
+
8
+
9
class SNRLPLoss(nn.Module):
    """Composite loss: an SNR-family loss on active samples plus an L1
    energy penalty on samples whose target is pure silence.

    Args:
        snr_loss_name: which SNR loss variant to use for active samples.
        neg_weight: weight applied to the silent-sample (negative) penalty.
    """

    def __init__(self, snr_loss_name="snr", neg_weight=1) -> None:
        super().__init__()
        self.snr_loss = SNRLosses(snr_loss_name)
        # L1 against an all-zero target acts as an output-energy penalty.
        self.lp_loss = nn.L1Loss()  # LogPowerLoss()
        self.neg_weight = neg_weight

    def forward(self, est: torch.Tensor, gt: torch.Tensor, **kwargs):
        """
        input: (B, C, T) (B, C, T)
        Returns a per-sample loss tensor of shape (B,).
        """
        comp_loss = torch.zeros((est.shape[0]), device=est.device)
        # A sample is "negative" when its ground truth is entirely zero.
        mask = torch.max(torch.max(torch.abs(gt), dim=2)[0], dim=1)[0] == 0

        # mask.any() stays on-device; the previous `any(mask)` iterated the
        # tensor element-by-element in Python (one sync per element on GPU).
        if mask.any():
            # Silent targets: penalize residual output energy.
            neg_loss = self.lp_loss(est[mask], gt[mask])
            comp_loss[mask] = neg_loss * self.neg_weight

        pos_mask = ~mask
        if pos_mask.any():
            # Active targets: standard SNR-family loss.
            comp_loss[pos_mask] = self.snr_loss(est[pos_mask], gt[pos_mask])

        return comp_loss
src/metrics/metrics.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torchaudio.functional import resample
5
+
6
+ from torchmetrics.functional import(
7
+ scale_invariant_signal_distortion_ratio as si_sdr,
8
+ scale_invariant_signal_noise_ratio as si_snr,
9
+ signal_noise_ratio as snr)
10
+
11
+ from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility as STOI
12
+ from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality as PESQ
13
+ import numpy as np
14
+ import copy
15
+ from src.losses.MultiResoLoss import MultiResoFuseLoss
16
+ from src.losses.Perceptual_Loss import PLCPALoss
17
+
18
def compute_decay(est, mix):
    """Average per-channel power drop (dB) from mixture to estimate.

    est, mix: [*, C, T] torch tensors or numpy arrays (both the same type).
    Returns P_mix - P_est in dB, averaged over channels: shape [*].
    """
    kind = type(est)
    assert type(mix) == kind, "All arrays must be the same type"
    if kind == np.ndarray:
        est, mix = torch.from_numpy(est), torch.from_numpy(mix)

    # Work on copies so callers' tensors are never modified.
    est = est.clone()
    mix = mix.clone()

    power_est = 10 * torch.log10((est ** 2).sum(dim=-1))  # [*, C]
    power_mix = 10 * torch.log10((mix ** 2).sum(dim=-1))  # [*, C]

    return (power_mix - power_est).mean(dim=-1)  # [*]
35
+
36
class Metrics(nn.Module):
    """Name-dispatched audio metric wrapper.

    Builds `self.func` once from `name`; forward evaluates it per channel and
    averages over the channel axis. All metric callables share the signature
    (est, gt, mix, self_speech) even when some arguments are unused.
    """

    def __init__(self, name, fs=24000, **kwargs) -> None:
        super().__init__()
        self.fs = fs
        self.func = None
        self.name = name
        if name == 'snr':
            self.func = lambda est, gt, mix, self_speech: snr(preds=est, target=gt)
        elif name == 'snr_i':
            # "_i" variants are improvement metrics: score(est) - score(mixture).
            self.func = lambda est, gt, mix, self_speech: snr(preds=est, target=gt) - snr(preds=mix, target=gt)
        elif name == 'si_snr':
            self.func = lambda est, gt, mix, self_speech: si_snr(preds=est, target=gt)
        elif name == 'si_snr_i':
            self.func = lambda est, gt, mix, self_speech: si_snr(preds=est, target=gt) - si_snr(preds=mix, target=gt)
        elif name == 'si_sdr':
            self.func = lambda est, gt, mix, self_speech: si_sdr(preds=est, target=gt)
        elif name == 'si_sdr_i':
            self.func = lambda est, gt, mix, self_speech: si_sdr(preds=est, target=gt) - si_sdr(preds=mix, target=gt)
        elif name == 'si_sdr_i_adj':
            # Adjusted baseline: the mixture is compared against gt + wearer's
            # own speech, so self-speech is not counted as interference.
            self.func = lambda est, gt, mix, self_speech: si_sdr(preds=est, target=gt) - si_sdr(preds=mix, target=gt+self_speech)
        elif name == 'STOI':
            self.func = lambda est, gt, mix, self_speech: STOI(preds=est, target=gt, fs=fs)
        elif name == 'PESQ':
            # PESQ is only defined for 8k/16k; resample to 16 kHz narrow-band.
            fs_new = 16000
            self.func = lambda est, gt, mix, self_speech: PESQ(preds=resample(est, fs, fs_new), target=resample(gt, fs, fs_new), fs=fs_new, mode = "nb")
        elif name == 'Multi_Reso_L1':
            mult_ireso_loss = MultiResoFuseLoss(**kwargs)
            self.func = lambda est, gt, mix, self_speech: mult_ireso_loss(est = est, gt = gt)
        elif name == 'PLCPALoss':
            plcpa = PLCPALoss(**kwargs)
            self.func = lambda est, gt, mix, self_speech: plcpa(est = est, gt = gt)
        else:
            raise NotImplementedError(f"Metric {name} not implemented!")

    def forward(self, est, gt, mix, self_speech=None):
        """
        input: (*, C, T)
        output: (*)  — metric averaged over the channel axis.
        PLCPALoss returns a 3-tuple of channel-averaged components instead.
        """
        types = type(est)
        assert type(gt) == types and type(mix) == types, "All arrays must be the same type"
        if types == np.ndarray:
            est, gt, mix = torch.from_numpy(est), torch.from_numpy(gt), torch.from_numpy(mix)

        # Ensure that, no matter what, we do not modify the original arrays
        est = est.clone()
        gt = gt.clone()
        mix = mix.clone()

        if self_speech is not None:
            if type(self_speech) == np.ndarray:
                self_speech = torch.from_numpy(self_speech)
            self_speech = self_speech.clone()

        # print("shape of est in metrics is {}".format(est.shape)) [1, 1, 160000]
        # print("shape of gt is {}".format(gt.shape))
        # print("mix has shape {}".format(mix.shape))

        # per_channel_metrics = self.func(est=est, gt=gt, mix=mix) # [*, C]
        per_channel_metrics = self.func(est=est, gt=gt, mix=mix, self_speech=self_speech)  # [*, C]

        if self.name == "PLCPALoss":
            return per_channel_metrics[0].mean(dim=-1), per_channel_metrics[1].mean(dim=-1), per_channel_metrics[2].mean(dim=-1)
        else:
            return per_channel_metrics.mean(dim=-1)  # [*]
src/models/blocks/model1_block.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import time
3
+ from collections import OrderedDict
4
+ from typing import Dict, List, Optional, Tuple
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from espnet2.torch_utils.get_layer_from_string import get_layer
10
+ from torch.nn import init
11
+ from torch.nn.parameter import Parameter
12
+ import src.utils as utils
13
+
14
+
15
class Lambda(nn.Module):
    """Wrap a bare lambda as an nn.Module so it can sit inside nn.Sequential."""

    def __init__(self, lambd):
        super().__init__()
        import types

        self.lambd = lambd
        # Only genuine lambdas are accepted, matching the construction sites.
        assert type(lambd) is types.LambdaType

    def forward(self, x):
        # Delegate straight to the wrapped callable.
        return self.lambd(x)
25
+
26
+
27
class LayerNormPermuted(nn.LayerNorm):
    """LayerNorm over the channel axis of a channels-first [B, C, T, F] tensor.

    nn.LayerNorm normalizes trailing dims, so the input is moved to
    channels-last, normalized, and moved back.
    """

    def forward(self, x):
        """
        Args:
            x: [B, C, T, F]
        """
        channels_last = x.permute(0, 2, 3, 1)  # [B, T, F, C]
        normed = super().forward(channels_last)
        return normed.permute(0, 3, 1, 2)  # [B, C, T, F]
40
+
41
+
42
+ # Use native layernorm implementation
43
class LayerNormalization4D(nn.Module):
    """Thin wrapper around nn.LayerNorm over the trailing channel dimension."""

    def __init__(self, C, eps=1e-5, preserve_outdim=False):
        super().__init__()
        self.preserve_outdim = preserve_outdim  # kept for interface compat; unused here
        self.norm = nn.LayerNorm(C, eps=eps)

    def forward(self, x: torch.Tensor):
        """Normalize input of shape (*, C) over its last axis."""
        return self.norm(x)
55
+
56
+
57
class LayerNormalization4DCF(nn.Module):
    """LayerNorm over a flattened (freq * channel) feature axis."""

    def __init__(self, input_dimension, eps=1e-5):
        assert len(input_dimension) == 2
        n_freqs, n_chan = input_dimension
        super().__init__()
        self.norm = nn.LayerNorm(n_freqs * n_chan, eps=eps)

    def forward(self, x: torch.Tensor):
        """
        input: (B, T, Q * C); normalized over the last axis.
        """
        return self.norm(x)
71
+
72
+
73
class LayerNormalization4D_old(nn.Module):
    """Hand-rolled per-channel LayerNorm for [B, C, T, F] tensors (stats over dim 1)."""

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        shape = (1, input_dimension, 1, 1)
        # Affine parameters: scale starts at 1, shift at 0.
        self.gamma = Parameter(torch.ones(shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        mean = x.mean(dim=(1,), keepdim=True)  # [B,1,T,F]
        std = torch.sqrt(x.var(dim=(1,), unbiased=False, keepdim=True) + self.eps)  # [B,1,T,F]
        return (x - mean) / std * self.gamma + self.beta
93
+
94
+
95
def mod_pad(x, chunk_size, pad):
    """Right-pad x (last dim) to a multiple of chunk_size, then apply `pad`.

    Returns the padded tensor and the number of alignment samples added
    (excluding the caller-specified `pad`).
    """
    remainder = x.shape[-1] % chunk_size
    mod = chunk_size - remainder if remainder != 0 else 0

    x = F.pad(x, (0, mod))  # align to an integer number of chunks
    x = F.pad(x, pad)       # caller-specified extra padding

    return x, mod
106
+
107
+
108
class Attention_STFT_causal(nn.Module):
    """Causal multi-head self-attention over time for [B, T, Q, C] features.

    Q/K projections use a reduced per-head dim E (derived from approx_qk_dim);
    V keeps emb_dim // n_head per head. local_context_len limits how far back
    each frame may attend (-1 = unlimited causal context). With
    dim_feedforward == -1 the output goes through a concat projection,
    otherwise through a residual feed-forward block + LayerNorm.
    """

    def __getitem__(self, key):
        # Allows self["attn_conv_Q"]-style access to submodules.
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        n_freqs,
        approx_qk_dim=512,
        n_head=4,
        activation="prelu",
        eps=1e-5,
        skip_conn=True,
        use_flash_attention=False,
        dim_feedforward=-1,
        local_context_len=-1,
        # 6
    ):
        super().__init__()
        self.position_code = utils.PositionalEncoding(emb_dim * n_freqs, max_len=5000)

        self.skip_conn = skip_conn
        self.n_freqs = n_freqs
        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)  # approx_qk_dim is only approximate
        self.n_head = n_head
        self.V_dim = emb_dim // n_head
        self.emb_dim = emb_dim
        assert emb_dim % n_head == 0
        E = self.E

        self.use_flash_attention = use_flash_attention

        self.local_context_len = local_context_len

        self.add_module(
            "attn_conv_Q",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),  # [B, T, Q, HE]
                get_layer(activation)(),
                # [B, T, Q, H, E] -> [B, H, T, Q, E] -> [B * H, T, Q * E]
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),  # (BH, T, Q * E)
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        self.add_module(
            "attn_conv_K",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        self.add_module(
            "attn_conv_V",
            nn.Sequential(
                nn.Linear(emb_dim, (emb_dim // n_head) * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, (emb_dim // n_head))
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * (emb_dim // n_head))
                ),
                LayerNormalization4DCF((n_freqs, emb_dim // n_head), eps=eps),
            ),
        )

        self.dim_feedforward = dim_feedforward

        if dim_feedforward == -1:
            # Head-concat projection path (original GridNet-style output).
            self.add_module(
                "attn_concat_proj",
                nn.Sequential(
                    nn.Linear(emb_dim, emb_dim),
                    get_layer(activation)(),
                    Lambda(lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])),
                    LayerNormalization4DCF((n_freqs, emb_dim), eps=eps),
                ),
            )
        else:
            # Transformer-style residual feed-forward path.
            self.linear1 = nn.Linear(emb_dim, dim_feedforward)
            self.dropout = nn.Dropout(p=0.1)
            self.activation = nn.ReLU()
            self.linear2 = nn.Linear(dim_feedforward, emb_dim)
            self.dropout2 = nn.Dropout(p=0.1)
            self.norm = LayerNormalization4DCF((n_freqs, emb_dim), eps=eps)

    def _ff_block(self, x):
        """Position-wise feed-forward: linear -> ReLU -> dropout -> linear -> dropout."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def get_lookahead_mask(self, seq_len, device):
        """Build a [T, T] boolean mask: True where attention is allowed.

        Causal (lower-triangular) everywhere; when local_context_len >= 0
        the band is additionally limited to that many past frames.
        """
        if self.local_context_len == -1:
            mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1)

            return mask.detach().to(device)

        else:
            # Intersection of "causal" and "within local_context_len" bands.
            mask1 = torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1
            mask2 = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=self.local_context_len) == 0
            mask = (mask1 * mask2).transpose(0, 1)

            return mask.detach().to(device)

    def forward(self, batch):
        ### input/output B T F C
        # attention
        inputs = batch
        B0, T0, Q0, C0 = batch.shape

        # positional encoding
        pos_code = self.position_code(batch)  # 1, T, embed_dim
        _, T, QC = pos_code.shape
        pos_code = pos_code.reshape(1, T, Q0, C0)
        batch = batch + pos_code

        Q = self["attn_conv_Q"](batch)  # [B', T, Q * C]
        K = self["attn_conv_K"](batch)  # [B', T, Q * C]
        V = self["attn_conv_V"](batch)  # [B', T, Q * C]

        emb_dim = Q.shape[-1]

        local_mask = self.get_lookahead_mask(batch.shape[1], batch.device)

        # Scaled dot-product attention with causal masking.
        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)  # [B', T, T]
        attn_mat.masked_fill_(local_mask == 0, -float("Inf"))
        attn_mat = F.softmax(attn_mat, dim=2)  # [B', T, T]

        V = torch.matmul(attn_mat, V)  # [B', T, Q*C]
        V = V.reshape(-1, T0, V.shape[-1])  # [BH, T, Q * C]
        V = V.transpose(1, 2)  # [B', Q * C, T]

        # Un-fold heads back into the channel dimension.
        batch = V.reshape(B0, self.n_head, self.n_freqs, self.V_dim, T0)  # [B, H, Q, C, T]
        batch = batch.transpose(2, 3)  # [B, H, C, Q, T]
        batch = batch.reshape(B0, self.n_head * self.V_dim, self.n_freqs, T0)  # [B, HC, Q, T]
        batch = batch.permute(0, 3, 2, 1)  # [B, T, Q, C]

        if self.dim_feedforward == -1:
            batch = self["attn_concat_proj"](batch)  # [B, T, Q * C]
        else:
            batch = batch + self._ff_block(batch)  # [B, T, Q, C]
            batch = batch.reshape(batch.shape[0], batch.shape[1], batch.shape[2] * batch.shape[3])
            batch = self.norm(batch)
            batch = batch.reshape(batch.shape[0], batch.shape[1], Q0, C0)  # [B, T, Q, C])

        # Add batch if attention is performed
        if self.skip_conn:
            return batch + inputs
        else:
            return batch
267
+
268
+
269
class GridNetBlock(nn.Module):
    """One GridNet-style block: intra-frequency BiLSTM, chunked inter-time
    BiLSTM, then causal attention over pooled chunk summaries.

    Input/output: [B, C, T, Q] (channels, time, frequency). When `last` is
    True the block returns the pooled, attended chunk features directly
    instead of broadcasting them back onto the full time resolution.
    """

    def __getitem__(self, key):
        # Allows self["name"]-style access to submodules.
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        lstm_fold_chunk,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
        pool="mean",
        last=False,
        local_context_len=-1,
        # 6
    ):
        super().__init__()
        bidirectional = True  # bidirectional within the intra frame lstm

        self.global_atten_causal = True

        self.last = last

        self.pool = pool

        # Time-axis chunk length for the inter RNN fold.
        self.lstm_fold_chunk = lstm_fold_chunk
        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)  # approx_qk_dim is only approximate

        self.V_dim = emb_dim // n_head
        self.H = hidden_channels
        in_channels = emb_dim * emb_ks
        self.in_channels = in_channels
        self.n_freqs = n_freqs

        ## intra RNN can be optimized by conv or linear because the frequence length are not very large
        self.intra_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=True)
        self.intra_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs)
        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs

        # inter RNN
        self.inter_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=bidirectional)
        self.inter_linear = nn.ConvTranspose1d(hidden_channels * (bidirectional + 1), emb_dim, emb_ks, stride=emb_hs)

        # attention
        self.pool_atten_causal = Attention_STFT_causal(
            emb_dim=emb_dim,
            n_freqs=n_freqs,
            approx_qk_dim=approx_qk_dim,
            n_head=n_head,
            activation=activation,
            eps=eps,
            local_context_len=local_context_len,
        )

    def _unfold_timedomain(self, x):
        """Split the time axis into fixed-size chunks: [BQ, C, T] -> [BQ, Num_chunk, chunk, C]."""
        BQ, C, T = x.shape
        x = torch.split(x, self.lstm_fold_chunk, dim=-1)  # [Num_chunk, BQ, C, 100]
        x = torch.cat(x, dim=0).reshape(-1, BQ, C, self.lstm_fold_chunk)  # [Num_chunk, BQ, C, 100]
        x = x.permute(1, 0, 3, 2)  # [BQ, Num_chunk, 100, C]
        return x

    def forward(self, x, init_state=None):
        """GridNetBlock Forward.

        Args:
            x: [B, C, T, Q]
            out: [B, C, T, Q]
        """
        B, C, old_T, old_Q = x.shape
        # Pad T and Q so the kernel/stride tiling divides evenly.
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # ===========================Intra RNN start================================
        # define intra RNN
        input_ = x
        intra_rnn = self.intra_norm(input_)  # [B, C, T, Q]
        intra_rnn = intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q)  # [BT, C, Q]

        intra_rnn = torch.split(intra_rnn, self.emb_ks, dim=-1)  # [Q/I, BT, C, I]
        intra_rnn = torch.stack(intra_rnn, dim=0)
        intra_rnn = intra_rnn.permute(1, 2, 3, 0).flatten(1, 2)  # [BT, CI, Q/I]
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, -1, nC*emb_ks]
        self.intra_rnn.flatten_parameters()

        # apply intra frame LSTM
        intra_rnn, _ = self.intra_rnn(intra_rnn)  # [BT, -1, H]
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, H, -1]
        intra_rnn = self.intra_linear(intra_rnn)  # [BT, C, Q]
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()  # [B, C, T, Q]
        intra_rnn = intra_rnn + input_  # [B, C, T, Q]  (residual)
        intra_rnn = intra_rnn[:, :, :, :old_Q]  # [B, C, T, Q]  drop frequency padding
        Q = old_Q
        # ===========================Intra RNN end================================


        # ===========================Inter RNN start================================
        # fold the time domain to chunk
        inter_rnn = self.inter_norm(intra_rnn)  # [B, C, T, F]
        inter_rnn = inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)  # [BF, C, T]


        inter_rnn = self._unfold_timedomain(inter_rnn)  ### BQ, NUM_CHUNK, CHUNK_SIZE, C

        BQ, NUM_CHUNK, CHUNKSIZE, C = inter_rnn.shape

        inter_rnn = inter_rnn.reshape(BQ * NUM_CHUNK, CHUNKSIZE, C)  ### BQ* NUM_CHUNK, CHUNK_SIZE, C
        inter_rnn = inter_rnn.transpose(2, 1)  # [B, C, T]
        input_ = inter_rnn

        inter_rnn = torch.split(inter_rnn, self.emb_ks, dim=-1)

        inter_rnn = torch.stack(inter_rnn, dim=0)
        inter_rnn = inter_rnn.permute(1, 2, 3, 0)

        BF, C, EO, _T = inter_rnn.shape
        inter_rnn = inter_rnn.reshape(BF, C * EO, _T)

        inter_rnn = inter_rnn.transpose(1, 2)

        self.inter_rnn.flatten_parameters()
        inter_rnn, _ = self.inter_rnn(inter_rnn)  # [BF, -1, H]
        inter_rnn = inter_rnn.transpose(1, 2)  # [BF, H, -1]
        inter_rnn = self.inter_linear(inter_rnn)  # [BF, C, T]
        inter_rnn = inter_rnn + input_  # [BQ* NUM_CHUNK, C, T]  (residual)

        inter_rnn = inter_rnn.reshape(B, Q, NUM_CHUNK, C, CHUNKSIZE)
        inter_rnn = inter_rnn.permute(0, 1, 2, 4, 3)  # B, Q, NUM_CHUNK, CHUNKSIZE, C

        input_ = inter_rnn  # B, Q, NUM_CHUNK, CHUNKSIZE, C
        # Pool each chunk down to a single summary frame for the attention.
        if self.pool == "mean":
            inter_rnn = torch.mean(inter_rnn, dim=3)  # B, Q, NUM_CHUNK, C
        elif self.pool == "max":
            inter_rnn, _ = torch.max(inter_rnn, dim=3)  # B, Q, NUM_CHUNK, C
        else:
            raise ValueError("INvalid pool type!")
        # ===========================Inter RNN end================================

        # ===========================attention start================================
        inter_rnn = inter_rnn.transpose(1, 2)  # B, NUM_CHUNK, Q, C
        inter_rnn = self.pool_atten_causal(inter_rnn)  # B T Q C
        inter_rnn = inter_rnn.transpose(1, 2)  # B Q T C

        if self.last == True:
            # Final block: return the chunk-rate features directly.
            return inter_rnn, init_state

        else:
            # Broadcast each attended chunk summary back over its chunk
            # (residual add), then restore [B, C, T, Q] at full time rate.
            inter_rnn = inter_rnn.unsqueeze(3)
            inter_rnn = input_ + inter_rnn  # B, Q, NUM_CHUNK, CHUNKSIZE, C

            inter_rnn = inter_rnn.reshape(B, Q, T, C)
            inter_rnn = inter_rnn.permute(0, 3, 2, 1)  # B C T Q
            inter_rnn = inter_rnn[..., :old_T, :]  # drop time padding
            # ===========================attention end================================

        return inter_rnn, init_state
src/models/blocks/model2_block.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import time
3
+ from collections import OrderedDict
4
+ from typing import Dict, List, Optional, Tuple
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from espnet2.torch_utils.get_layer_from_string import get_layer
10
+ from torch.nn import init
11
+ from torch.nn.parameter import Parameter
12
+ import src.utils as utils
13
+
14
+
15
class Lambda(nn.Module):
    """Wrap a bare lambda as an nn.Module so it can sit inside nn.Sequential."""

    def __init__(self, lambd):
        super().__init__()
        import types

        self.lambd = lambd
        # Only genuine lambdas are accepted, matching the construction sites.
        assert type(lambd) is types.LambdaType

    def forward(self, x):
        # Delegate straight to the wrapped callable.
        return self.lambd(x)
25
+
26
+
27
class LayerNormPermuted(nn.LayerNorm):
    """LayerNorm over the channel axis of a channels-first [B, C, T, F] tensor.

    nn.LayerNorm normalizes trailing dims, so the input is moved to
    channels-last, normalized, and moved back.
    """

    def forward(self, x):
        """
        Args:
            x: [B, C, T, F]
        """
        channels_last = x.permute(0, 2, 3, 1)  # [B, T, F, C]
        normed = super().forward(channels_last)
        return normed.permute(0, 3, 1, 2)  # [B, C, T, F]
40
+
41
+
42
+ # Use native layernorm implementation
43
class LayerNormalization4D(nn.Module):
    """Thin wrapper around nn.LayerNorm over the trailing channel dimension."""

    def __init__(self, C, eps=1e-5, preserve_outdim=False):
        super().__init__()
        self.preserve_outdim = preserve_outdim  # kept for interface compat; unused here
        self.norm = nn.LayerNorm(C, eps=eps)

    def forward(self, x: torch.Tensor):
        """Normalize input of shape (*, C) over its last axis."""
        return self.norm(x)
55
+
56
+
57
class LayerNormalization4DCF(nn.Module):
    """LayerNorm over a flattened (freq * channel) feature axis."""

    def __init__(self, input_dimension, eps=1e-5):
        assert len(input_dimension) == 2
        n_freqs, n_chan = input_dimension
        super().__init__()
        self.norm = nn.LayerNorm(n_freqs * n_chan, eps=eps)

    def forward(self, x: torch.Tensor):
        """
        input: (B, T, Q * C); normalized over the last axis.
        """
        return self.norm(x)
71
+
72
+
73
class LayerNormalization4D_old(nn.Module):
    """Hand-rolled per-channel LayerNorm for [B, C, T, F] tensors (stats over dim 1)."""

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        shape = (1, input_dimension, 1, 1)
        # Affine parameters: scale starts at 1, shift at 0.
        self.gamma = Parameter(torch.ones(shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        mean = x.mean(dim=(1,), keepdim=True)  # [B,1,T,F]
        std = torch.sqrt(x.var(dim=(1,), unbiased=False, keepdim=True) + self.eps)  # [B,1,T,F]
        return (x - mean) / std * self.gamma + self.beta
93
+
94
+
95
def mod_pad(x, chunk_size, pad):
    """Right-pad x (last dim) to a multiple of chunk_size, then apply `pad`.

    Returns the padded tensor and the number of alignment samples added
    (excluding the caller-specified `pad`).
    """
    remainder = x.shape[-1] % chunk_size
    mod = chunk_size - remainder if remainder != 0 else 0

    x = F.pad(x, (0, mod))  # align to an integer number of chunks
    x = F.pad(x, pad)       # caller-specified extra padding

    return x, mod
106
+
107
+
108
class Attention_STFT_causal(nn.Module):
    """Causal multi-head self-attention over the time axis of an STFT-domain
    feature map.

    Input and output are shaped [B, T, Q, C], where Q is the number of
    frequency bins (``n_freqs``) and C the per-bin embedding dimension
    (``emb_dim``).  Frequency bins are flattened into the feature axis, so
    attention runs across time frames only; a lower-triangular mask prevents
    any frame from attending to the future.
    """

    def __getitem__(self, key):
        # Lets submodules registered via add_module be fetched as self["name"].
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        n_freqs,
        approx_qk_dim=512,
        n_head=4,
        activation="prelu",
        eps=1e-5,
        skip_conn=True,
        use_flash_attention=False,  # NOTE(review): accepted but never used below — confirm intent
        dim_feedforward=-1,
    ):
        super().__init__()
        # Sinusoidal positional code over the flattened (Q * C) feature axis.
        self.position_code = utils.PositionalEncoding(emb_dim * n_freqs, max_len=5000)

        self.skip_conn = skip_conn
        self.n_freqs = n_freqs
        # Per-head Q/K width; approx_qk_dim is only approximate because of the ceil.
        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        self.n_head = n_head
        self.V_dim = emb_dim // n_head  # per-head value width
        self.emb_dim = emb_dim
        assert emb_dim % n_head == 0
        E = self.E

        # Q/K/V projections: Linear over the channel axis, then fold the head
        # dimension into the batch axis: [B, T, Q, H*E] -> [B*H, T, Q*E].
        self.add_module(
            "attn_conv_Q",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),  # [B, T, Q, HE]
                get_layer(activation)(),
                # [B, T, Q, H, E] -> [B, H, T, Q, E] -> [B * H, T, Q * E]
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),  # (BH, T, Q * E)
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        self.add_module(
            "attn_conv_K",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        self.add_module(
            "attn_conv_V",
            nn.Sequential(
                nn.Linear(emb_dim, (emb_dim // n_head) * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, (emb_dim // n_head))
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * (emb_dim // n_head))
                ),
                LayerNormalization4DCF((n_freqs, emb_dim // n_head), eps=eps),
            ),
        )

        self.dim_feedforward = dim_feedforward

        if dim_feedforward == -1:
            # Plain output projection applied after the heads are re-assembled.
            self.add_module(
                "attn_concat_proj",
                nn.Sequential(
                    nn.Linear(emb_dim, emb_dim),
                    get_layer(activation)(),
                    Lambda(lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])),
                    LayerNormalization4DCF((n_freqs, emb_dim), eps=eps),
                ),
            )
        else:
            # Transformer-style position-wise feed-forward block instead of a
            # single output projection.
            self.linear1 = nn.Linear(emb_dim, dim_feedforward)
            self.dropout = nn.Dropout(p=0.1)
            self.activation = nn.ReLU()
            self.linear2 = nn.Linear(dim_feedforward, emb_dim)
            self.dropout2 = nn.Dropout(p=0.1)
            self.norm = LayerNormalization4DCF((n_freqs, emb_dim), eps=eps)

    def _ff_block(self, x):
        # Position-wise feed-forward: linear -> ReLU -> dropout -> linear -> dropout.
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def get_lookahead_mask(self, seq_len, device):
        """Create a lower-triangular boolean mask (True = may attend).

        Arguments
        ---------
        seq_len: int
            Length of the sequence.
        device: torch.device
            The device on which to create the mask.

        Returns a [seq_len, seq_len] boolean tensor; position (i, j) is True
        iff j <= i, i.e. frame i may only attend to itself and the past.
        """
        mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1)

        return mask.detach().to(device)

    def forward(self, batch):
        """Apply causal self-attention.  Input/output: [B, T, Q, C]."""
        inputs = batch
        B0, T0, Q0, C0 = batch.shape

        # Additive sinusoidal positional encoding, broadcast over the batch.
        pos_code = self.position_code(batch)  # (1, T, Q * C)
        _, T, QC = pos_code.shape
        pos_code = pos_code.reshape(1, T, Q0, C0)
        batch = batch + pos_code

        Q = self["attn_conv_Q"](batch)  # [B*H, T, Q * E]
        K = self["attn_conv_K"](batch)  # [B*H, T, Q * E]
        V = self["attn_conv_V"](batch)  # [B*H, T, Q * V_dim]

        emb_dim = Q.shape[-1]

        # Lower-triangular mask: False entries are future frames.
        local_mask = self.get_lookahead_mask(batch.shape[1], batch.device)

        # Scaled dot-product attention over time.
        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)  # [B*H, T, T]
        attn_mat.masked_fill_(local_mask == 0, -float("Inf"))  # block attention to the future
        attn_mat = F.softmax(attn_mat, dim=2)  # [B*H, T, T]

        V = torch.matmul(attn_mat, V)  # [B*H, T, Q*C']
        V = V.reshape(-1, T0, V.shape[-1])  # [B*H, T, Q * C']
        V = V.transpose(1, 2)  # [B*H, Q * C', T]

        # Undo the head folding: [B*H, Q*C', T] -> [B, T, Q, H*C'].
        batch = V.reshape(B0, self.n_head, self.n_freqs, self.V_dim, T0)  # [B, H, Q, C', T]
        batch = batch.transpose(2, 3)  # [B, H, C', Q, T]
        batch = batch.reshape(B0, self.n_head * self.V_dim, self.n_freqs, T0)  # [B, HC', Q, T]
        batch = batch.permute(0, 3, 2, 1)  # [B, T, Q, C]

        if self.dim_feedforward == -1:
            batch = self["attn_concat_proj"](batch)  # [B, T, Q * C]
        else:
            batch = batch + self._ff_block(batch)  # residual feed-forward, [B, T, Q, C]
            batch = batch.reshape(batch.shape[0], batch.shape[1], batch.shape[2] * batch.shape[3])
            batch = self.norm(batch)
        batch = batch.reshape(batch.shape[0], batch.shape[1], Q0, C0)  # [B, T, Q, C]

        # Residual connection around the whole attention block.
        if self.skip_conn:
            return batch + inputs
        else:
            return batch
276
+
277
+
278
class GridNetBlock(nn.Module):
    """One TF-GridNet-style block: intra-frame (frequency) BLSTM, inter-frame
    (time) uni-directional LSTM, and optional causal self-attention over time.

    Input/output layout: [B, C, T, Q] (channels, time frames, frequency bins).
    """

    def __getitem__(self, key):
        # Lets submodules be fetched as self["name"].
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
        pool="mean",
        use_attention=False,
    ):
        super().__init__()
        # The inter (time) LSTM is uni-directional to keep the block causal.
        bidirectional = False

        self.global_atten_causal = True

        self.pool = pool

        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)  # approx_qk_dim is only approximate

        self.V_dim = emb_dim // n_head
        self.H = hidden_channels
        in_channels = emb_dim * emb_ks
        self.in_channels = in_channels
        self.n_freqs = n_freqs

        # Intra-frame path: bi-directional LSTM along frequency.  Being
        # non-causal in frequency is fine; causality only matters along time.
        # (Could be replaced by a conv/linear since n_freqs is small.)
        self.intra_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=True)
        self.intra_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs)
        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs

        # Inter-frame path: uni-directional LSTM along time.
        # NOTE(review): the inter path feeds features of width emb_dim to an
        # LSTM built with input size emb_dim * emb_ks — this only matches when
        # emb_ks == 1 (and T is preserved only when emb_hs == 1); confirm the
        # configs always use emb_ks == emb_hs == 1.
        self.inter_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=bidirectional)
        self.inter_linear = nn.ConvTranspose1d(hidden_channels * (bidirectional + 1), emb_dim, emb_ks, stride=emb_hs)

        # Optional causal self-attention over time.
        self.use_attention = use_attention

        if self.use_attention:
            self.pool_atten_causal = Attention_STFT_causal(
                emb_dim=emb_dim,
                n_freqs=n_freqs,
                approx_qk_dim=approx_qk_dim,
                n_head=n_head,
                activation=activation,
                eps=eps,
            )

    def init_buffers(self, batch_size, device):
        # This block keeps no streaming state of its own.
        return None

    def forward(self, x, init_state=None):
        """GridNetBlock Forward.

        Args:
            x: [B, C, T, Q]
            init_state: opaque streaming state, passed through unchanged.
        Returns:
            (out, init_state) with out: [B, C, T, Q]
        """
        B, C, old_T, old_Q = x.shape
        # Pad T and Q up so unfolding with (emb_ks, emb_hs) tiles them exactly.
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # ===========================Intra RNN start================================
        input_ = x
        intra_rnn = self.intra_norm(input_)  # [B, C, T, Q]
        intra_rnn = intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q)  # [BT, C, Q]

        # Unfold frequency into chunks of emb_ks bins.
        intra_rnn = torch.split(intra_rnn, self.emb_ks, dim=-1)  # [Q/I, BT, C, I]
        intra_rnn = torch.stack(intra_rnn, dim=0)
        intra_rnn = intra_rnn.permute(1, 2, 3, 0).flatten(1, 2)  # [BT, CI, Q/I]
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, -1, nC*emb_ks]
        self.intra_rnn.flatten_parameters()

        # Bi-directional LSTM across the frequency chunks.
        intra_rnn, _ = self.intra_rnn(intra_rnn)  # [BT, -1, H]
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, H, -1]
        intra_rnn = self.intra_linear(intra_rnn)  # [BT, C, Q]
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()  # [B, C, T, Q]
        intra_rnn = intra_rnn + input_  # residual, [B, C, T, Q]
        intra_rnn = intra_rnn[:, :, :, :old_Q]  # drop the frequency padding
        Q = old_Q
        # ===========================Intra RNN end================================

        # ===========================Inter RNN start================================
        input_ = intra_rnn

        inter_rnn = self.inter_norm(intra_rnn)  # [B, C, T, Q]
        inter_rnn = inter_rnn.transpose(1, 3).reshape(B * Q, T, C)  # [BQ, T, C]

        self.inter_rnn.flatten_parameters()
        # Uni-directional (causal) LSTM across time frames.
        inter_rnn, _ = self.inter_rnn(inter_rnn)  # [BQ, T, H]
        inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, H, T]
        inter_rnn = self.inter_linear(inter_rnn)  # [BQ, C, T]

        _, new_C, new_T = inter_rnn.shape
        inter_rnn = inter_rnn.reshape(B, Q, new_C, new_T)
        inter_rnn = inter_rnn.permute(0, 2, 3, 1)  # [B, C, T, Q]
        inter_rnn = inter_rnn + input_  # residual
        # ===========================Inter RNN end================================

        # ===========================attention start================================
        if self.use_attention:
            out = inter_rnn  # [B, C, T, Q]

            # Attention expects channels-last [B, T, Q, C].
            inter_rnn = inter_rnn.permute(0, 2, 3, 1)
            inter_rnn = self.pool_atten_causal(inter_rnn)  # [B, T, Q, C]
            inter_rnn = inter_rnn.permute(0, 3, 1, 2)  # [B, C, T, Q]
            inter_rnn = out + inter_rnn  # residual around the attention

        inter_rnn = inter_rnn[..., :old_T, :]  # drop the time padding
        # ===========================attention end================================

        return inter_rnn, init_state
src/models/network/model1.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import src.utils as utils
4
+ # from src.models.common.film import FiLM
5
+
6
+
7
class FilmLayer(nn.Module):
    """Feature-wise linear modulation (FiLM): a per-(channel, frequency)
    scale and shift predicted from a conditioning embedding.
    """

    def __init__(self, D, C, nF, groups=1):
        super().__init__()
        self.D = D   # conditioning (speaker) embedding dim, e.g. 256
        self.C = C   # modulated latent channel dim, e.g. 16
        self.nF = nF  # number of frequency bins
        self.weight = nn.Conv1d(self.D, self.C * nF, 1, groups=groups)
        self.bias = nn.Conv1d(self.D, self.C * nF, 1, groups=groups)

    def forward(self, x: torch.Tensor, embedding: torch.Tensor):
        """
        x: (B, C, F, T); embedding: (B, D, 1)
        Returns x scaled and shifted per (channel, frequency).
        """
        batch = x.shape[0]
        freq_bins = x.shape[2]
        scale = self.weight(embedding).reshape(batch, self.C, freq_bins, 1)  # (B, C, F, 1)
        shift = self.bias(embedding).reshape(batch, self.C, freq_bins, 1)    # (B, C, F, 1)
        return x * scale + shift
27
+
28
+
29
class LayerNormPermuted(nn.LayerNorm):
    """LayerNorm over the channel axis of a channels-first [B, C, T, F] tensor.

    Moves channels last, applies the standard LayerNorm, then moves them back.
    """

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)   # [B, T, F, C]
        normed = super().forward(channels_last)
        return normed.permute(0, 3, 1, 2)       # [B, C, T, F]
42
+
43
+
44
class Conv_Emb_Generator(nn.Module):
    """Model 1 (slow path): encodes the STFT mixture into a conversation
    embedding, optionally FiLM-conditioned on the wearer's speaker embedding.

    Structure: input Conv2d (+ optional LayerNorm) followed by ``n_layers``
    streaming blocks loaded dynamically via ``block_model_name``.

    Fix vs. previous revision: ``forward`` now persists the updated
    ``conv_buf`` into ``input_state`` (it was computed but dropped, so
    streaming calls kept consuming the stale zero buffer — the sibling
    ``TSH.forward`` already writes it back).
    """

    def __init__(
        self,
        block_model_name,
        block_model_params,
        spk_dim=256,
        n_srcs=1,
        n_fft=128,
        latent_dim=16,
        num_inputs=1,
        n_layers=6,
        use_first_ln=True,
        n_imics=1,
        lstm_fold_chunk=400,
        E=2,
        use_speaker_emb=True,
        one_emb=True,
        local_context_len=-1,
    ):
        super().__init__()
        self.n_srcs = n_srcs
        self.n_layers = n_layers
        self.num_inputs = num_inputs
        assert n_fft % 2 == 0
        n_freqs = n_fft // 2 + 1
        self.n_freqs = n_freqs
        self.latent_dim = latent_dim

        self.use_speaker_emb = use_speaker_emb
        self.one_emb = one_emb

        # Attention Q/K width is tied to the number of frequency bins.
        attn_approx_qk_dim = E * n_freqs

        self.n_fft = n_fft

        self.eps = 1.0e-5

        # Temporal kernel size of the input conv; (t_ksize - 1) past frames
        # are carried in a streaming buffer to keep the conv causal.
        t_ksize = 3
        self.t_ksize = t_ksize
        ks, padding = (t_ksize, t_ksize), (0, 1)

        self.n_imics = n_imics
        # Without a speaker embedding, self-speech is fed as an extra
        # real/imag channel pair instead.
        if not use_speaker_emb:
            self.n_imics = self.n_imics + 1

        module_list = [nn.Conv2d(2 * self.n_imics, latent_dim, ks, padding=padding)]
        if use_first_ln:
            module_list.append(LayerNormPermuted(latent_dim))
        self.conv = nn.Sequential(*module_list)

        # FiLM conditioning layers (one per block after the first, or a
        # single shared one when one_emb is set).
        self.embeds = nn.ModuleList([])

        self.local_context_len = local_context_len

        self.blocks = nn.ModuleList([])
        for _i in range(n_layers - 1):
            self.blocks.append(
                utils.import_attr(block_model_name)(
                    emb_dim=latent_dim,
                    n_freqs=n_freqs,
                    approx_qk_dim=attn_approx_qk_dim,
                    lstm_fold_chunk=lstm_fold_chunk,
                    last=False,
                    local_context_len=local_context_len,
                    **block_model_params,
                )
            )
        self.blocks.append(
            utils.import_attr(block_model_name)(
                emb_dim=latent_dim,
                n_freqs=n_freqs,
                approx_qk_dim=attn_approx_qk_dim,
                lstm_fold_chunk=lstm_fold_chunk,
                local_context_len=local_context_len,
                last=True,
                **block_model_params,
            )
        )

        if self.use_speaker_emb and not self.one_emb:
            # One FiLM layer applied before each block except the first.
            for _i in range(n_layers - 1):
                self.embeds.append(FilmLayer(spk_dim, latent_dim, n_freqs, 1))
        elif self.use_speaker_emb and self.one_emb:
            # A single FiLM layer applied once, before the second block.
            self.embeds.append(FilmLayer(spk_dim, latent_dim, n_freqs, 1))

    def init_buffers(self, batch_size, device):
        """Allocate streaming state: the input-conv history buffer plus one
        (initially empty) slot per block — blocks manage their own state."""
        conv_buf = torch.zeros(batch_size, 2 * self.n_imics, self.t_ksize - 1, self.n_freqs,
                               device=device)

        deconv_buf = torch.zeros(batch_size, self.latent_dim, self.t_ksize - 1, self.n_freqs,
                                 device=device)

        block_buffers = {f'buf{i}': None for i in range(len(self.blocks))}

        return dict(conv_buf=conv_buf, deconv_buf=deconv_buf,
                    block_bufs=block_buffers)

    def forward(self, current_input: torch.Tensor, embedding: torch.Tensor, input_state, quantized=False) -> torch.Tensor:
        """
        B: batch, M: mic, F: freq bin, C: real/imag, T: time frame
        D: dimension of the embedding vector
        current_input: (B, CM, T, F)
        embedding: (B, D, 1) or None (only read when use_speaker_emb)
        Returns (conversation_emb, input_state).
        """
        n_batch, _, n_frames, n_freqs = current_input.shape
        batch = current_input

        if input_state is None:
            input_state = self.init_buffers(current_input.shape[0], current_input.device)

        conv_buf = input_state['conv_buf']
        gridnet_buf = input_state['block_bufs']

        # Prepend the conv history (or zero-pad when running quantized) so the
        # temporal conv stays causal across streaming chunks.
        if quantized:
            batch = nn.functional.pad(batch, (0, 0, self.t_ksize - 1, 0))
        else:
            batch = torch.cat((conv_buf, batch), dim=2)

        conv_buf = batch[:, :, -(self.t_ksize - 1):, :]
        batch = self.conv(batch)  # [B, D, T, F]

        if self.use_speaker_emb:
            if not self.one_emb:
                # FiLM before every block except the first.
                assert len(self.blocks) == self.n_layers
                assert len(self.embeds) == self.n_layers - 1
                for ii in range(self.n_layers - 1):
                    batch = batch.transpose(2, 3)
                    if ii > 0:
                        batch = self.embeds[ii - 1](batch, embedding)
                    batch = batch.transpose(2, 3)
                    batch, gridnet_buf[f'buf{ii}'] = self.blocks[ii](batch, gridnet_buf[f'buf{ii}'])

                batch = batch.transpose(2, 3)
                batch = self.embeds[-1](batch, embedding)
                batch = batch.transpose(2, 3)
                batch, gridnet_buf[f'buf{self.n_layers-1}'] = self.blocks[self.n_layers - 1](batch, gridnet_buf[f'buf{self.n_layers-1}'])

            else:
                # Single FiLM layer, applied once before the second block.
                assert len(self.blocks) == self.n_layers
                assert len(self.embeds) == 1
                for ii in range(self.n_layers):
                    batch = batch.transpose(2, 3)
                    if ii == 1:
                        batch = self.embeds[ii - 1](batch, embedding)
                    batch = batch.transpose(2, 3)
                    batch, gridnet_buf[f'buf{ii}'] = self.blocks[ii](batch, gridnet_buf[f'buf{ii}'])

        else:
            # No FiLM conditioning at all.
            assert len(self.blocks) == self.n_layers
            for ii in range(self.n_layers):
                batch, gridnet_buf[f'buf{ii}'] = self.blocks[ii](batch, gridnet_buf[f'buf{ii}'])

        conversation_emb = batch

        # Persist the updated streaming buffers (bug fix: conv_buf was
        # previously computed but never written back, matching TSH.forward).
        input_state['conv_buf'] = conv_buf
        input_state['block_bufs'] = gridnet_buf

        return conversation_emb, input_state

    def edge_mode(self):
        # Switch every block into its on-device (edge) inference mode.
        for i in range(len(self.blocks)):
            self.blocks[i].edge_mode()
194
+
195
+ if __name__ == "__main__":
196
+ pass
src/models/network/model2_joint.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import src.utils as utils
4
+ # from src.models.common.film import FiLM
5
+
6
+
7
class FilmLayer(nn.Module):
    """FiLM conditioning: predicts a per-(channel, frequency) affine
    transform of the input from a conditioning embedding."""

    def __init__(self, D, C, nF, groups=1):
        super().__init__()
        self.D = D    # conditioning embedding dim
        self.C = C    # modulated latent channel dim
        self.nF = nF  # number of frequency bins
        self.weight = nn.Conv1d(self.D, self.C * nF, 1, groups=groups)
        self.bias = nn.Conv1d(self.D, self.C * nF, 1, groups=groups)

    def forward(self, x: torch.Tensor, embedding: torch.Tensor):
        """x: (B, C, F, T); embedding: (B, D, 1) -> modulated x."""
        n_batch, _, n_freq, _ = x.shape
        gamma = self.weight(embedding).reshape(n_batch, self.C, n_freq, 1)  # (B, C, F, 1)
        beta = self.bias(embedding).reshape(n_batch, self.C, n_freq, 1)     # (B, C, F, 1)
        return gamma * x + beta
26
+
27
+
28
class LayerNormPermuted(nn.LayerNorm):
    """LayerNorm applied over the channel axis of a channels-first
    [B, C, T, F] tensor (permute to channels-last, normalize, permute back).
    """

    def forward(self, x):
        # [B, C, T, F] -> [B, T, F, C] -> LayerNorm over C -> [B, C, T, F]
        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
41
+
42
+
43
class TSH(nn.Module):
    """Model 2 (fast path): streaming target-conversation extraction.

    Consumes the STFT mixture and the conversation embedding produced by
    model 1, and outputs the separated target TF representation.
    Structure: input Conv2d (+ optional LayerNorm) -> ``n_layers`` streaming
    blocks (the embedding multiplicatively gates the activations before the
    second block) -> transposed conv back to the TF domain.
    """

    def __init__(
        self,
        block_model_name,
        block_model_params,
        spk_dim=256,
        latent_dim=48,
        n_srcs=1,
        n_fft=128,
        num_inputs=1,
        n_layers=6,
        use_first_ln=True,
        n_imics=1,
        lstm_fold_chunk=400,
        stft_chunk_size=200,
        latent_dim_model1=16,
        use_speaker_emb=True,
        use_self_speech_model2=True
    ):
        super().__init__()
        self.n_srcs = n_srcs
        self.n_layers = n_layers
        self.num_inputs = num_inputs
        assert n_fft % 2 == 0
        n_freqs = n_fft // 2 + 1
        self.n_freqs = n_freqs
        self.latent_dim = latent_dim
        self.lstm_fold_chunk = lstm_fold_chunk
        self.stft_chunk_size = stft_chunk_size

        self.n_fft = n_fft

        self.eps = 1.0e-5

        # Temporal kernel size of the input conv; (t_ksize - 1) past frames
        # are carried in a streaming buffer to keep the conv causal.
        t_ksize = 3
        self.t_ksize = t_ksize
        ks, padding = (t_ksize, t_ksize), (0, 1)

        self.n_imics = n_imics

        self.use_self_speech_model2 = use_self_speech_model2

        # When no speaker embedding is used, self-speech is supplied as an
        # extra real/imag input channel pair instead.
        if not use_speaker_emb and use_self_speech_model2:
            self.n_imics = self.n_imics + 1

        module_list = [nn.Conv2d(2 * self.n_imics, latent_dim, ks, padding=padding)]

        if use_first_ln:
            module_list.append(LayerNormPermuted(latent_dim))

        self.conv = nn.Sequential(
            *module_list
        )

        # FiLM layer (unused here; kept for interface parity with model 1).
        self.embeds = nn.ModuleList([])

        # Process through a stack of blocks.
        self.blocks = nn.ModuleList([])
        for _i in range(n_layers):
            self.blocks.append(utils.import_attr(block_model_name)(emb_dim=latent_dim, n_freqs=n_freqs, **block_model_params))

        # Project back to TF-Domain.
        self.deconv = nn.ConvTranspose2d(latent_dim, n_srcs * 2, ks, padding=(self.t_ksize - 1, 1))

        self.latent_dim_model1 = latent_dim_model1

        # NOTE(review): projection_layer is created when the two latent dims
        # differ, but forward() never applies it — confirm the embedding is
        # projected elsewhere (or that the dims always match in practice).
        if latent_dim_model1 != latent_dim:
            self.projection_layer = nn.Conv2d(latent_dim_model1, latent_dim, kernel_size=1)

    def init_buffers(self, batch_size, device):
        """Allocate streaming state: conv/deconv history buffers plus each
        block's own buffer (delegated to the block)."""
        conv_buf = torch.zeros(batch_size, 2 * self.n_imics, self.t_ksize - 1, self.n_freqs,
                               device=device)

        deconv_buf = torch.zeros(batch_size, self.latent_dim, self.t_ksize - 1, self.n_freqs,
                                 device=device)

        block_buffers = {}
        for i in range(len(self.blocks)):
            block_buffers[f'buf{i}'] = self.blocks[i].init_buffers(batch_size, device)

        return dict(conv_buf=conv_buf, deconv_buf=deconv_buf,
                    block_bufs=block_buffers)

    def forward(self, current_input: torch.Tensor, embedding: torch.Tensor, input_state, quantized=False) -> torch.Tensor:
        """
        B: batch, M: mic, F: freq bin, C: real/imag, T: time frame
        D: dimension of the embedding vector
        current_input: (B, CM, T, F)
        embedding: conversation embedding from model 1
        output: (B, S, T, C*F)
        """
        n_batch, _, n_frames, n_freqs = current_input.shape
        batch = current_input

        if input_state is None:
            input_state = self.init_buffers(current_input.shape[0], current_input.device)

        conv_buf = input_state['conv_buf']
        gridnet_buf = input_state['block_bufs']

        # Prepend the conv history (or zero-pad when running quantized) so the
        # temporal conv stays causal across streaming chunks.
        if quantized:
            batch = nn.functional.pad(batch, (0, 0, self.t_ksize - 1, 0))
        else:
            batch = torch.cat((conv_buf, batch), dim=2)

        conv_buf = batch[:, :, -(self.t_ksize - 1):, :]
        batch = self.conv(batch)  # [B, D, T, F]

        # Align the conversation embedding with batch for element-wise gating.
        # NOTE(review): transpose(1, 3) implies embedding is 4-D here (the
        # model-1 latent map), not the (B, D, F) of the docstring — confirm.
        embedding = embedding.transpose(1, 3)

        for ii in range(self.n_layers):
            # Gate with the conversation embedding before the second block.
            if ii == 1:
                batch = batch * embedding
            batch, gridnet_buf[f'buf{ii}'] = self.blocks[ii](batch, gridnet_buf[f'buf{ii}'])

        # NOTE(review): deconv_buf is re-zeroed on every call and the copy in
        # input_state is never updated, so no deconv history is carried
        # across streaming chunks — confirm this is intentional.
        deconv_buf = torch.zeros(n_batch, self.latent_dim, self.t_ksize - 1, self.n_freqs,
                                 device=current_input.device)
        if quantized:
            batch = nn.functional.pad(batch, (0, 0, self.t_ksize - 1, 0))
        else:
            batch = torch.cat((deconv_buf, batch), dim=2)

        batch = self.deconv(batch)  # [B, n_srcs*C, T, F]

        # Interleave real/imag with frequency: [B, S, T, 2 * F].
        batch = batch.view([n_batch, self.n_srcs, 2, n_frames, n_freqs])  # [B, n_srcs, 2, n_frames, n_freqs]
        batch = batch.transpose(2, 3).reshape(n_batch, self.n_srcs, n_frames, 2 * n_freqs)  # [B, S, T, F]

        # Persist the updated streaming buffers.
        input_state['conv_buf'] = conv_buf
        input_state['block_bufs'] = gridnet_buf

        return batch, input_state

    def edge_mode(self):
        # Switch every block into its on-device (edge) inference mode.
        for i in range(len(self.blocks)):
            self.blocks[i].edge_mode()
184
+
185
+ if __name__ == "__main__":
186
+ pass
src/models/network/net_conversation_joint.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .model1 import Conv_Emb_Generator
5
+ from .model2_joint import TSH
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import copy
9
+
10
+
11
def mod_pad(x, chunk_size, pad):
    """Zero-pad the last axis of `x` up to a multiple of `chunk_size`, then
    apply the extra (front, back) padding in `pad`.

    Returns (padded_tensor, n_alignment_zeros).
    """
    mod = (-x.shape[-1]) % chunk_size  # zeros needed to reach a chunk boundary
    x = F.pad(x, (0, mod))
    x = F.pad(x, pad)
    return x, mod
20
+
21
+
22
+ # A TF-domain network guided by an embedding vector
23
+ class Net_Conversation(nn.Module):
24
+ def __init__(self,
25
+ model1_block_name,
26
+ model1_block_params,
27
+ model2_block_name,
28
+ model2_block_params,
29
+ stft_chunk_size=64,
30
+ stft_pad_size=32,
31
+ stft_back_pad=32,
32
+ num_input_channels=1,
33
+ num_output_channels=1,
34
+ num_sources=1,
35
+ speaker_embed = 256,
36
+ num_layers_model1=3,
37
+ num_layers_model2=3,
38
+ latent_dim_model1=16,
39
+ latent_dim_model2=32,
40
+ use_sp_feats=False,
41
+ use_first_ln=True,
42
+ n_imics=1,
43
+ window="hann",
44
+ lstm_fold_chunk=400,
45
+ E=2,
46
+ use_speaker_emb_model1=True,
47
+ one_emb_model1=True,
48
+ use_self_speech_model2=True,
49
+ local_context_len=-1
50
+ ):
51
+ super(Net_Conversation, self).__init__()
52
+
53
+ assert num_sources == 1
54
+
55
+ # num input/output channels
56
+ self.nI = num_input_channels
57
+ self.nO = num_output_channels
58
+
59
+ # num channels to the TF-network
60
+ num_separator_inputs = self.nI * 2 + use_sp_feats * (3 * (self.nI - 1))
61
+
62
+ self.stft_chunk_size = stft_chunk_size
63
+ self.stft_pad_size = stft_pad_size
64
+ self.stft_back_pad = stft_back_pad
65
+ self.n_srcs = num_sources
66
+ self.use_sp_feats = use_sp_feats
67
+
68
+ # Input conv to convert input audio to a latent representation
69
+ self.nfft = stft_back_pad + stft_chunk_size + stft_pad_size
70
+
71
+ self.nfreqs = self.nfft//2 + 1
72
+
73
+ self.lstm_fold_chunk=lstm_fold_chunk
74
+
75
+ # Construct synthesis/analysis windows (rect)
76
+ if window=="hann":
77
+ window_fn = lambda x: np.hanning(x)
78
+ elif window=="rect":
79
+ window_fn = lambda x: np.ones(x)
80
+ else:
81
+ raise ValueError("Invalid window type!")
82
+
83
+ if ((stft_pad_size) % stft_chunk_size) == 0:
84
+ print("Using perfect STFT windows")
85
+ self.analysis_window = torch.from_numpy(window_fn(self.nfft)).float()
86
+
87
+ # eg. inverse SFTF
88
+ self.synthesis_window = torch.zeros(stft_pad_size + stft_chunk_size).float()
89
+
90
+ A = self.synthesis_window.shape[0]
91
+ B = self.stft_chunk_size
92
+ N = self.analysis_window.shape[0]
93
+
94
+ assert (A % B) == 0
95
+ for i in range(A):
96
+ num = self.analysis_window[N - A + i]
97
+
98
+ denom = 0
99
+ for k in range(A//B):
100
+ denom += (self.analysis_window[N - A + (i % B) + k * B] ** 2)
101
+
102
+ self.synthesis_window[i] = num / denom
103
+ else:
104
+ print("Using imperfect STFT windows")
105
+ self.analysis_window = torch.from_numpy( window_fn(self.nfft) ).float()
106
+ self.synthesis_window = torch.from_numpy( window_fn(stft_chunk_size + stft_pad_size) ).float()
107
+
108
+ self.istft_lookback = 1 + (self.synthesis_window.shape[0] - 1) // self.stft_chunk_size
109
+
110
+ if local_context_len!=-1:
111
+ local_context_len=local_context_len//stft_chunk_size//lstm_fold_chunk
112
+
113
+ self.model1 = Conv_Emb_Generator(
114
+ model1_block_name,
115
+ model1_block_params,
116
+ spk_dim = speaker_embed,
117
+ latent_dim = latent_dim_model1,
118
+ n_srcs = num_output_channels * num_sources,
119
+ n_fft = self.nfft,
120
+ num_inputs = num_separator_inputs,
121
+ n_layers = num_layers_model1,
122
+ use_first_ln=use_first_ln,
123
+ n_imics=n_imics,
124
+ lstm_fold_chunk=lstm_fold_chunk,
125
+ E=E,
126
+ use_speaker_emb=use_speaker_emb_model1,
127
+ one_emb=one_emb_model1,
128
+ local_context_len=local_context_len
129
+ )
130
+
131
+ self.quantized = False
132
+
133
+ self.use_self_speech_model2=use_self_speech_model2
134
+
135
+ self.model2=TSH(
136
+ model2_block_name,
137
+ model2_block_params,
138
+ spk_dim = speaker_embed,
139
+ latent_dim = latent_dim_model2,
140
+ latent_dim_model1=latent_dim_model1,
141
+ n_srcs = num_output_channels * num_sources,
142
+ n_fft = self.nfft,
143
+ num_inputs = num_separator_inputs,
144
+ n_layers = num_layers_model2,
145
+ use_first_ln=use_first_ln,
146
+ n_imics=n_imics,
147
+ lstm_fold_chunk=lstm_fold_chunk,
148
+ stft_chunk_size=stft_chunk_size,
149
+ use_speaker_emb=use_speaker_emb_model1,
150
+ use_self_speech_model2=use_self_speech_model2
151
+ )
152
+
153
+ self.use_speaker_emb_model1=use_speaker_emb_model1
154
+
155
+ def init_buffers(self, batch_size, device):
156
+ buffers = {}
157
+
158
+ buffers['model1_bufs'] = self.model1.init_buffers(batch_size, device)
159
+
160
+ buffers['model2_bufs'] = self.model2.init_buffers(batch_size, device)
161
+
162
+ buffers['istft_buf'] = torch.zeros(batch_size * self.n_srcs * self.nO,
163
+ self.synthesis_window.shape[0],
164
+ self.istft_lookback, device=device)
165
+
166
+ return buffers
167
+
168
+ # compute STFT
169
+ def extract_features(self, x):
170
+ """
171
+ x: (B, M, T)
172
+ returns: (B, C*M, T, F)
173
+ """
174
+ B, M, T = x.shape
175
+
176
+ x = x.reshape(B*M, T)
177
+ x = torch.stft(x, n_fft = self.nfft, hop_length = self.stft_chunk_size,
178
+ win_length = self.nfft, window=self.analysis_window.to(x.device),
179
+ center=False, normalized=False, return_complex=True)
180
+
181
+ x = torch.view_as_real(x) # [B*M, F, T, 2]
182
+ BM, _F, T, C = x.shape
183
+
184
+ x = x.reshape(B, M, _F, T, C) # [B, M, F, T, 2]
185
+
186
+ x = x.permute(0, 4, 1, 3, 2) # [B, 2, M. T, F]
187
+
188
+ x = x.reshape(B, C*M, T, _F)
189
+
190
+ return x
191
+
192
def synthesis(self, x, input_state):
    """
    Streaming inverse STFT with overlap-add across calls.

    x: (B, S, T, C*F) TF-domain output; the last axis interleaves C=2
       real/imag parts with F frequency bins. S = output streams.
    input_state: dict carrying 'istft_buf', the overlap-add tail kept
       between streaming calls.
    returns: ((B, S, t) time-domain signal, updated input_state)
    """
    istft_buf = input_state['istft_buf']

    x = x.transpose(2, 3) # [B, S, CF, T]

    B, S, CF, T = x.shape
    X = x.reshape(B*S, CF, T)
    # Separate the interleaved real/imag halves and rebuild complex spectra
    X = X.reshape(B*S, 2, -1, T).permute(0, 2, 3, 1) # [BS, F, T, C]
    X = X[..., 0] + 1j * X[..., 1]

    x = torch.fft.irfft(X, dim=1) # [BS, iW, T]
    # Keep only the tail of each inverse frame spanned by the synthesis window
    x = x[:, -self.synthesis_window.shape[0]:] # [BS, oW, T]

    # Apply synthesis window
    x = x * self.synthesis_window.unsqueeze(0).unsqueeze(-1).to(x.device)

    oW = self.synthesis_window.shape[0]

    # Concatenate blocks from previous IFFTs
    x = torch.cat([istft_buf, x], dim=-1)
    # NOTE(review): shape[1] is the window length oW, while the buffer's
    # frame depth lives in shape[-1] (istft_lookback) — this only keeps the
    # intended number of frames if the two are equal; confirm.
    istft_buf = x[..., -istft_buf.shape[1]:] # Update buffer

    # Overlap-add: fold sums overlapping oW-long frames at chunk stride
    x = F.fold(x, output_size=(self.stft_chunk_size * x.shape[-1] + (oW - self.stft_chunk_size), 1),
               kernel_size=(oW, 1), stride=(self.stft_chunk_size, 1)) # [BS, 1, t]

    # Drop warm-up samples at the front and the look-ahead pad at the end
    x = x[:, :, -T * self.stft_chunk_size - self.stft_pad_size: - self.stft_pad_size]
    x = x.reshape(B, S, -1) # [B, S, t]

    input_state['istft_buf'] = istft_buf

    return x, input_state
228
+
229
+
230
def predict_model1(self, x, input_state, speaker_embedding, pad=True):
    """Run the slow model over a time-domain mixture.

    x: (B, M, t) time-domain input (B batch, M channels).
    input_state: streaming buffer dict; 'model1_bufs' is updated in place.
    speaker_embedding: optional conditioning embedding (a singleton axis is
        inserted at dim 2 before it is passed on), or None.
    pad: when True, pad x to a whole number of STFT chunks first.

    Returns (conversation embedding, updated input_state).
    """
    if pad:
        x, _ = mod_pad(x, chunk_size=self.stft_chunk_size,
                       pad=(self.stft_back_pad, self.stft_pad_size))

    # Time-domain to TF-domain
    feats = self.extract_features(x)  # [B, RM, T, F]

    if speaker_embedding is not None:
        speaker_embedding = speaker_embedding.unsqueeze(2)

    conversation_emb, input_state['model1_bufs'] = self.model1(
        feats, speaker_embedding, input_state['model1_bufs'], self.quantized)

    return conversation_emb, input_state
253
+
254
def predict_model2(self, x, conversation_emb, input_state, pad=True):
    """Run the fast model and synthesize its output back to audio.

    x: (B, M, t) time-domain input.
    conversation_emb: conditioning embedding produced by the slow model.
    input_state: streaming buffer dict; 'model2_bufs' is updated in place.
    pad: when True, pad x to a whole number of STFT chunks first.

    Returns ((B, S, t) time-domain output, updated state).
    """
    trim = 0
    if pad:
        x, trim = mod_pad(x, chunk_size=self.stft_chunk_size,
                          pad=(self.stft_back_pad, self.stft_pad_size))

    feats = self.extract_features(x)

    out_tf, input_state['model2_bufs'] = self.model2(
        feats, conversation_emb, input_state['model2_bufs'], self.quantized)

    # TF-domain back to time-domain
    out, next_state = self.synthesis(out_tf, input_state)  # [B, S * M, t]

    # Remove any samples introduced by the chunk padding above
    if trim != 0:
        out = out[:, :, :-trim]

    return out, next_state
278
+
279
+
280
+ def forward(self, inputs, input_state = None, pad=True):
281
+ x = inputs['mixture']
282
+
283
+ start_idx_input=inputs['start_idx']
284
+ end_idx_input=inputs['end_idx']
285
+
286
+ assert ((end_idx_input - start_idx_input) % self.stft_chunk_size) == 0
287
+
288
+ # Snap start and end to chunk
289
+ start_idx_input = (start_idx_input // self.stft_chunk_size) * self.stft_chunk_size
290
+ end_idx_input = (end_idx_input // self.stft_chunk_size) * self.stft_chunk_size
291
+
292
+ B, M, t=x.shape
293
+
294
+ audio_range=torch.tensor([start_idx_input, end_idx_input]).to(x.device)
295
+ audio_range = audio_range.unsqueeze(0).repeat(B, 1)
296
+
297
+ spk_embed = inputs['embed']
298
+ self_speech=None
299
+
300
+ if not self.use_speaker_emb_model1:
301
+ self_speech=inputs['self_speech']
302
+
303
+ combined_audio = torch.cat((x, self_speech), dim=1)
304
+ x=combined_audio
305
+
306
+ if input_state is None:
307
+ input_state = self.init_buffers(x.shape[0], x.device)
308
+
309
+ B, M, t = x.shape
310
+
311
+ # enter slow model
312
+ conversation_emb, input_state = self.predict_model1(x, input_state, spk_embed, pad=pad) # [B, F, T, C]
313
+
314
+ # slice conv embedding and corresponding audio
315
+ B, _F, T, C = conversation_emb.shape
316
+ conversation_emb = conversation_emb.permute(0, 1, 3, 2) # [B, F, C, T]
317
+ conversation_emb = torch.roll(conversation_emb, 1, dims=-1)
318
+ conversation_emb[..., 0] = 0
319
+ conversation_emb = conversation_emb.flatten(0,3).unsqueeze(1) # [*, 1]
320
+ multiplier = torch.tile(conversation_emb, (1, self.lstm_fold_chunk)) # [*, L]
321
+ multiplier = multiplier.reshape(B, _F, C, T, self.lstm_fold_chunk).flatten(3,4) # [B, F, C, T*L]
322
+ multiplier = multiplier.permute(0, 1, 3, 2) # [B, F, T*L, C]
323
+
324
+ slicing_length=end_idx_input-start_idx_input+self.stft_back_pad+self.stft_pad_size
325
+
326
+ padded_start=start_idx_input-self.stft_back_pad
327
+ padded_end=end_idx_input+self.stft_pad_size
328
+
329
+ pad_left=max(-padded_start, 0)
330
+ pad_right=max(padded_end-t, 0)
331
+
332
+ actual_start=max(padded_start, 0)
333
+ actual_end=min(padded_end, t)
334
+
335
+ if self.use_self_speech_model2:
336
+ sliced_x=x[:, :, actual_start:actual_end]
337
+ else:
338
+ x_no_self_speech=inputs["mixture"]
339
+ sliced_x=x_no_self_speech[:, :, actual_start:actual_end]
340
+
341
+ padding = (pad_left, pad_right, 0, 0, 0, 0)
342
+
343
+ sliced_x=F.pad(sliced_x, padding, "constant", 0)
344
+
345
+ converted_start_idx=start_idx_input//self.stft_chunk_size
346
+ converted_end_idx=end_idx_input//self.stft_chunk_size
347
+
348
+ sliced_emb=multiplier[:, :, converted_start_idx:converted_end_idx, :]
349
+
350
+ assert sliced_x.shape[2]==slicing_length
351
+ assert sliced_emb.shape[2]==(slicing_length-self.stft_back_pad-self.stft_pad_size)//self.stft_chunk_size
352
+
353
+ model2_output, input_state = self.predict_model2(sliced_x, sliced_emb, input_state, pad=False)
354
+ model2_output = model2_output.reshape(B, self.n_srcs, self.nO, model2_output.shape[-1])
355
+
356
+ return {'output': model2_output[:, 0], 'next_state': input_state, 'audio_range': audio_range}
357
+
358
+
359
# Import-only module: no standalone entry point is provided.
if __name__ == "__main__":
    pass
src/train_joint.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The main training script for training on synthetic data
3
+ """
4
+
5
+ import torch
6
+ import torch.utils.data
7
+ import torch.nn as nn
8
+
9
+ import argparse
10
+ import json
11
+ import os
12
+ import multiprocessing
13
+ import time
14
+
15
+ import numpy as np
16
+ import src.utils as utils
17
+ from src.training.tain_val import train_epoch, test_epoch
18
+ import shutil
19
+ import sys
20
+
21
+ import wandb
22
+
23
+ VAL_SEED = 0
24
+ CURRENT_EPOCH = 0
25
+
26
def seed_from_epoch(seed):
    """Seed all RNGs with `seed` offset by the current epoch counter.

    Reads the module-level CURRENT_EPOCH so each epoch (and its dataloader
    workers) gets a distinct but reproducible seed.
    """
    utils.seed_all(seed + CURRENT_EPOCH)
+
31
def print_metrics(metrics: list):
    """Print mean input/output SI-SDR and the mean improvement (SI-SDRi).

    metrics: list of dicts, each containing 'input_si_sdr' and 'si_sdr'.
    """
    in_sdr = np.array([m['input_si_sdr'] for m in metrics])
    out_sdr = np.array([m['si_sdr'] for m in metrics])

    print("Average Input SI-SDR: {:03f}, Average Output SI-SDR: {:03f}, Average SI-SDRi: {:03f}".format(
        np.mean(in_sdr), np.mean(out_sdr), np.mean(out_sdr - in_sdr)))
+
37
+
38
def train(args: argparse.Namespace):
    """
    Train the network described by the experiment config in args.config.

    Builds datasets and dataloaders, restores the last (or, with --best,
    the best) checkpoint found in args.run_dir, then runs the epoch loop
    with wandb logging. Checkpoints land in <run_dir>/checkpoints.
    """
    # Fix random seeds
    utils.seed_all(args.seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    # Turn on deterministic algorithms if specified (Note: slower training).
    if torch.cuda.is_available():
        if args.use_nondeterministic_cudnn:
            torch.backends.cudnn.deterministic = False
        else:
            torch.backends.cudnn.deterministic = True

    # Load experiment description
    with open(args.config, 'rb') as f:
        params = json.load(f)

    # Initialize datasets
    data_train = utils.import_attr(params['train_dataset'])(**params['train_data_args'], split='train')
    data_val = utils.import_attr(params['val_dataset'])(**params['val_data_args'], split='val')

    # Set up the device and workers.
    # BUGFIX: use_cuda was hard-coded to True, which crashed on CPU-only
    # hosts when constructing the 'cuda' device.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print("Using device {}".format('cuda' if use_cuda else 'cpu'))

    # Set multiprocessing params
    num_workers = min(multiprocessing.cpu_count(), params['num_workers'])
    kwargs = {
        'num_workers': num_workers,
        # Workers are re-seeded from the current epoch so shuffling and
        # augmentation differ per epoch but remain reproducible
        'worker_init_fn': lambda x: seed_from_epoch(args.seed),
        'pin_memory': False
    } if use_cuda else {}

    # Set up data loaders
    train_loader = torch.utils.data.DataLoader(data_train,
                                               batch_size=params['batch_size'],
                                               shuffle=True,
                                               **kwargs)

    # Validation always uses the fixed VAL_SEED for comparable batches
    kwargs['worker_init_fn'] = lambda x: utils.seed_all(VAL_SEED)
    test_loader = torch.utils.data.DataLoader(data_val,
                                              batch_size=params['eval_batch_size'],
                                              **kwargs)

    # Initialize HL module
    hl_module = utils.import_attr(params['pl_module'])(**params['pl_module_args'])
    hl_module.model.to(device)

    # Get run name from run dir
    run_name = os.path.basename(args.run_dir.rstrip('/'))
    checkpoints_dir = os.path.join(args.run_dir, 'checkpoints')

    # Set up checkpoints
    if not os.path.exists(checkpoints_dir):
        os.makedirs(checkpoints_dir)

    # Copy json so the run directory is self-describing
    shutil.copyfile(args.config, os.path.join(args.run_dir, 'config.json'))

    # Check if a model state path exists for this model, if it does, load it
    best_path = os.path.join(checkpoints_dir, 'best.pt')
    state_path = os.path.join(checkpoints_dir, 'last.pt')
    if args.best and os.path.exists(best_path):
        print("load best state path .....")
        hl_module.load_state(best_path)
    elif os.path.exists(state_path):
        print("load state path .....")
        hl_module.load_state(state_path)

    start_epoch = hl_module.epoch

    # Project name comes from the config when present
    if "project_name" in params.keys():
        project_name = params["project_name"]
    else:
        project_name = "AcousticBubble"

    # Initialize wandb
    wandb_run = wandb.init(
        project=project_name,
        name=run_name,
        notes='Example of a note',
        tags=['speech', 'audio', 'embedded-systems']
    )

    # Training loop
    try:
        # Go over remaining epochs
        for epoch in range(start_epoch, params['epochs']):
            global CURRENT_EPOCH, VAL_SEED
            CURRENT_EPOCH = epoch
            seed_from_epoch(args.seed)

            hl_module.on_epoch_start()

            current_lr = hl_module.get_current_lr()
            print("CURRENT learning rate: {:0.08f}".format(current_lr))

            print("[TRAINING]")

            t1 = time.time()
            train_loss = train_epoch(hl_module, train_loader, device)
            t2 = time.time()
            print(f"Train epoch time: {t2 - t1:02f}s")

            print("\nTrain set: Average Loss: {:.4f}\n".format(train_loss))

            print()
            if np.isnan(train_loss):
                raise ValueError("Got NAN in training")
            utils.seed_all(VAL_SEED)

            # Run testing step
            print("[TESTING]")

            test_loss = test_epoch(hl_module, test_loader, device)

            print("\nTest set: Average Loss: {:.4f}\n".format(test_loss))

            # Let the module save its best checkpoint, then always dump last
            hl_module.on_epoch_end(best_path, wandb_run)
            hl_module.dump_state(state_path)

            print()
            print("=" * 25, "FINISHED EPOCH", epoch, "=" * 25)
            print()

    except KeyboardInterrupt:
        print("Interrupted")
    except Exception:
        # Top-level boundary: report the failure instead of crashing silently
        import traceback
        traceback.print_exc()
+
176
# CLI entry point: build the argument parser and launch training.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Experiment Params
    parser.add_argument('--config', type=str,
                        help='Path to experiment config')

    parser.add_argument('--run_dir', type=str,
                        help='Path to experiment directory')

    parser.add_argument('--best', action='store_true',
                        help="load from best checkpoint instead of last checkpoint")

    # Randomization Params
    parser.add_argument('--seed', type=int, default=10,
                        help='Random seed for reproducibility')
    parser.add_argument('--use_nondeterministic_cudnn',
                        action='store_true',
                        help="If using cuda, chooses whether or not to use \
                        non-deterministic cudDNN algorithms. Training will be\
                        faster, but the final results may differ slighty.")

    # wandb params
    # NOTE(review): --project_name is parsed but never read by train(),
    # which takes the project name from the config JSON — confirm intent.
    parser.add_argument('--project_name',
                        type=str,
                        default='AcousticBubble',
                        help='Project name that shows up on wandb')
    train(parser.parse_args())
src/training/tain_val.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The main training script for training on synthetic data
3
+ """
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ import os
9
+ import tqdm
10
+
11
+
12
def to_device(batch, device):
    """Recursively move every tensor in a (possibly nested) batch to `device`.

    Dicts are updated in place; lists and tuples are both rebuilt as lists
    (matching the original behavior); non-tensor leaves pass through.

    FIX: use isinstance instead of `type(x) == T`, so Tensor subclasses
    (e.g. nn.Parameter) and dict subclasses (e.g. OrderedDict) are handled
    instead of being silently left on their original device.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        for key in batch:
            batch[key] = to_device(batch[key], device)
        return batch
    if isinstance(batch, (list, tuple)):
        return [to_device(item, device) for item in batch]
    return batch
24
+
25
def test_epoch(hl_module, test_loader, device) -> float:
    """
    Evaluate the network for one pass over test_loader.

    Returns the sample-weighted average validation loss.
    """
    hl_module.eval()

    total_loss = 0.0
    total_count = 0

    progress = tqdm.tqdm(total=len(test_loader))

    with torch.no_grad():
        for step, batch in enumerate(test_loader):
            batch = to_device(batch, device)

            loss, batch_size = hl_module.validation_step(batch, step)

            # Weight each batch loss by its number of samples
            total_loss += loss.item() * batch_size
            total_count += batch_size

            progress.set_postfix(loss='%.05f' % (loss.item()))
            progress.update()

    return total_loss / total_count
50
+
51
def train_epoch(hl_module, train_loader, device) -> float:
    """
    Train the network for a single epoch.

    Returns the sample-weighted average training loss.
    """
    # Set the model to training.
    hl_module.train()

    running_loss = 0.0
    seen = 0

    progress = tqdm.tqdm(total=len(train_loader))

    for step, batch in enumerate(train_loader):
        batch = to_device(batch, device)

        # Zero gradients, then run the forward pass
        hl_module.reset_grad()
        loss, batch_size = hl_module.training_step(batch, step)

        # Backpropagate and let the module apply its optimizer step
        loss.backward()
        hl_module.backprop()

        # Accumulate the detached loss, weighted by batch size
        loss = loss.detach()
        running_loss += loss.item() * batch_size
        seen += batch_size

        progress.set_postfix(loss='%.05f' % (loss.item()))
        progress.update()

    return running_loss / seen
src/utils.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import importlib
4
+ import json
5
+
6
+ import librosa
7
+ import soundfile as sf
8
+ import torch
9
+ import torchaudio
10
+ import math
11
+ import torch.nn as nn
12
+
13
+
14
class PositionalEncoding(nn.Module):
    """Absolute sinusoidal positional encoding.

    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

    Arguments
    ---------
    input_size: int
        Embedding dimension (must be even).
    max_len : int, optional
        Max length of the input sequences (default 2500).

    Example
    -------
    >>> a = torch.rand((8, 120, 512))
    >>> enc = PositionalEncoding(input_size=a.shape[-1])
    >>> enc(a).shape
    torch.Size([1, 120, 512])
    """

    def __init__(self, input_size, max_len=2500):
        super().__init__()
        if input_size % 2 != 0:
            raise ValueError(f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})")
        self.max_len = max_len

        # Precompute the (max_len, input_size) table once; it is a buffer,
        # so it moves with the module but is never trained.
        table = torch.zeros(self.max_len, input_size, requires_grad=False)
        positions = torch.arange(0, self.max_len).unsqueeze(1).float()
        inv_freq = torch.exp(
            torch.arange(0, input_size, 2).float() * -(math.log(10000.0) / input_size)
        )
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        """
        Arguments
        ---------
        x : tensor
            Input feature of shape (batch, time, fea).

        Returns the (1, time, fea) encoding for the first x.size(1) steps.
        """
        return self.pe[:, : x.size(1)].clone().detach()
55
+
56
+
57
def count_parameters(model):
    """
    Count the number of parameters in a PyTorch model.

    Also prints the count in millions.

    Parameters:
        model (torch.nn.Module): The PyTorch model.

    Returns:
        int: Number of parameters in the model.
    """
    n_param = sum(p.numel() for p in model.parameters())
    print(f"Model params number {n_param/1e6} M")
    # FIX: the docstring always promised an int, but the function returned
    # None; returning the count is backward-compatible for print-only callers.
    return n_param
69
+
70
+
71
def import_attr(import_path):
    """Resolve a dotted path like "pkg.mod.attr" to the attribute object."""
    module_path, _, attr_name = import_path.rpartition(".")
    return getattr(importlib.import_module(module_path), attr_name)
74
+
75
+
76
class Params:
    """Hyperparameter container loaded from a JSON file.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5 # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        # Initial load is just an update onto an empty namespace
        self.update(json_path)

    def save(self, json_path):
        """Write the current parameters back out as pretty-printed JSON."""
        with open(json_path, "w") as f:
            json.dump(self.__dict__, f, indent=4)

    def update(self, json_path):
        """Loads parameters from json file"""
        with open(json_path) as f:
            self.__dict__.update(json.load(f))

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']"""
        return self.__dict__
105
+
106
+
107
def load_net_torch(expriment_config, return_params=False):
    """Build a pl_module from a config JSON with checkpoint/DP loading disabled.

    Returns the module and, when return_params is True, the *raw* config
    dict — note it does NOT reflect the three overrides applied below.
    """
    params = Params(expriment_config)
    # Construct the module without touching other runs: no slow-model
    # checkpoint, no DataParallel, no previous checkpoint.
    params.pl_module_args["slow_model_ckpt"] = None
    params.pl_module_args["use_dp"] = False
    params.pl_module_args["prev_ckpt"] = None
    pl_module = import_attr(params.pl_module)(**params.pl_module_args)

    # Re-read the file so the returned params are the unmodified JSON
    with open(expriment_config) as f:
        params = json.load(f)

    if return_params:
        return pl_module, params
    else:
        return pl_module
121
+
122
+
123
def load_net(expriment_config, return_params=False):
    """Build a pl_module from a config JSON with DataParallel disabled.

    Returns the module and, when return_params is True, the *raw* config
    dict — note it does NOT reflect the use_dp override applied below.
    """
    params = Params(expriment_config)
    params.pl_module_args["use_dp"] = False
    pl_module = import_attr(params.pl_module)(**params.pl_module_args)

    # Re-read the file so the returned params are the unmodified JSON
    with open(expriment_config) as f:
        params = json.load(f)

    if return_params:
        return pl_module, params
    else:
        return pl_module
135
+
136
+
137
def load_pretrained(run_dir, return_params=False, map_location="cpu", use_last=False):
    """Build a module from a run directory and restore its checkpoint.

    @param run_dir: directory containing config.json and checkpoints/
    @param return_params: also return the raw config dict
    @param map_location: device mapping forwarded to load_state
    @param use_last: restore checkpoints/last.pt instead of best.pt
    Raises FileNotFoundError when the requested checkpoint is missing.
    """
    config_path = os.path.join(run_dir, "config.json")

    pl_module, params = load_net(config_path, return_params=True)

    # Pick which checkpoint file to restore
    if use_last:
        name = "last.pt"
    else:
        name = "best.pt"
    ckpt_path = os.path.join(run_dir, f"checkpoints/{name}")

    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Given run ({run_dir}) doesn't have any pretrained checkpoints!")

    print("Loading checkpoint from", ckpt_path)

    # Load checkpoint
    pl_module.load_state(ckpt_path, map_location)
    print("Loaded module at epoch", pl_module.epoch)

    if return_params:
        return pl_module, params
    else:
        return pl_module
163
+
164
+
165
def load_pretrained_with_last(run_dir, return_params=False, map_location="cpu", use_last=False):
    """Build a module from a run directory and restore its checkpoint.

    This was a byte-for-byte duplicate of load_pretrained; it now simply
    delegates to it. Kept so existing callers keep working.

    @param run_dir: directory containing config.json and checkpoints/
    @param return_params: also return the raw config dict
    @param map_location: device mapping forwarded to load_state
    @param use_last: restore checkpoints/last.pt instead of best.pt
    Raises FileNotFoundError when the requested checkpoint is missing.
    """
    return load_pretrained(run_dir, return_params=return_params,
                           map_location=map_location, use_last=use_last)
191
+
192
+
193
def load_pretrained2(run_dir, return_params=False, map_location="cpu"):
    """Build a module from a run directory and restore checkpoints/best.pt.

    NOTE(review): unlike load_pretrained, map_location is accepted but NOT
    forwarded to load_state, and no existence check precedes the load —
    confirm whether that is intentional.
    """
    config_path = os.path.join(run_dir, "config.json")
    pl_module, params = load_net(config_path, return_params=True)

    ckpt_path = os.path.join(run_dir, "checkpoints", "best.pt")
    print("Loading checkpoint from", ckpt_path)

    # Load checkpoint
    pl_module.load_state(ckpt_path)

    if return_params:
        return pl_module, params
    else:
        return pl_module
208
+
209
+
210
def load_torch_pretrained(run_dir, return_params=False, map_location="cpu", model_epoch="best"):
    """Build a module via load_net_torch and restore a named checkpoint.

    @param run_dir: directory containing config.json and checkpoints/
    @param return_params: also return the raw config dict
    @param map_location: device mapping forwarded to load_state
    @param model_epoch: checkpoint stem to load, e.g. "best" or "last"
    Raises FileNotFoundError when the requested checkpoint is missing.
    """
    config_path = os.path.join(run_dir, "config.json")

    print(config_path)
    pl_module, params = load_net_torch(config_path, return_params=True)

    # Resolve the requested checkpoint file
    ckpt_path = os.path.join(run_dir, f"checkpoints/{model_epoch}.pt")

    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Given run ({run_dir}) doesn't have any pretrained checkpoints!")

    print("Loading checkpoint from", ckpt_path)

    # Load checkpoint
    pl_module.load_state(ckpt_path, map_location)
    print("Loaded module at epoch", pl_module.epoch)

    if return_params:
        return pl_module, params
    else:
        return pl_module
233
+
234
+
235
def read_audio_file(file_path, sr):
    """
    Load an audio file at sampling rate `sr` without mono downmixing.

    Returns only the sample array; the rate librosa reports is discarded.
    """
    samples, _ = librosa.core.load(file_path, mono=False, sr=sr)
    return samples
240
+
241
+
242
def read_audio_file_torch(file_path, downsample=1, input_mean=False):
    """Load an audio file as a torch waveform with optional resampling and
    channel selection.

    @param file_path: audio file to read
    @param downsample: integer factor; when > 1, resamples the audio to
        sample_rate // downsample
    @param input_mean: channel policy for multi-channel input:
        True  -> average all channels to mono, kept as shape (1, T)
        "L"   -> keep only the first channel
        "R"   -> keep only the second channel
        False -> keep all channels (default)
    Returns a (channels, T) waveform tensor.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    if downsample > 1:
        waveform = torchaudio.functional.resample(waveform, sample_rate, sample_rate // downsample)

    # `input_mean == True` (not `is True`) is deliberate here only insofar as
    # it distinguishes the bool from the "L"/"R" string options below
    if waveform.shape[0] > 1 and input_mean == True:
        waveform = torch.mean(waveform, dim=0)
        waveform = waveform.unsqueeze(0)

    elif waveform.shape[0] > 1 and input_mean == "L":
        waveform = waveform[0:1, ...]

    elif waveform.shape[0] > 1 and input_mean == "R":
        waveform = waveform[1:2, ...]

    return waveform
258
+
259
+
260
def write_audio_file(file_path, data, sr, subtype="PCM_16"):
    """
    Writes audio file to system memory.
    @param file_path: Path of the file to write to
    @param data: Audio signal to write (n_channels x n_samples); transposed
        to (n_samples x n_channels), the layout soundfile expects
    @param sr: Sampling rate
    @param subtype: soundfile subtype string (default: 16-bit PCM)
    """
    sf.write(file_path, data.T, sr, subtype)
268
+
269
+
270
def read_json(path):
    """Load and return the JSON content stored at `path`."""
    with open(path, "rb") as fp:
        return json.load(fp)
273
+
274
+
275
+ import random
276
+ import numpy as np
277
+
278
+
279
def seed_all(seed):
    """Seed Python's, NumPy's, and torch's RNGs (plus CUDA when available)
    for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)