Upload folder using huggingface_hub
Browse files- .gitignore +2 -0
- LICENSE +21 -0
- README.md +54 -3
- config/model_config.json +92 -0
- eval.py +346 -0
- requirements.txt +202 -0
- src/datasets/joint_training_dataset.py +441 -0
- src/datasets/noise.py +202 -0
- src/hl_module/joint_train_hl_module_new.py +543 -0
- src/losses/.DS_Store +0 -0
- src/losses/SNRLP.py +42 -0
- src/metrics/metrics.py +100 -0
- src/models/blocks/model1_block.py +434 -0
- src/models/blocks/model2_block.py +448 -0
- src/models/network/model1.py +196 -0
- src/models/network/model2_joint.py +186 -0
- src/models/network/net_conversation_joint.py +360 -0
- src/train_joint.py +202 -0
- src/training/tain_val.py +88 -0
- src/utils.py +285 -0
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb/
|
| 2 |
+
__pycache__/
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Guilin Hu
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,54 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Proactive Hearing Assistants that Isolate Egocentric Conversations
|
| 2 |
+
|
| 3 |
+
## More Information
|
| 4 |
+
|
| 5 |
+
For more information, please refer to our website: [https://proactivehearing.cs.washington.edu/](https://proactivehearing.cs.washington.edu/).
|
| 6 |
+
|
| 7 |
+
## Abstract
|
| 8 |
+
|
| 9 |
+
We introduce proactive hearing assistants that automatically identify and separate the wearer’s conversation partners, without requiring explicit prompts. Our system operates on egocentric binaural audio and uses the wearer’s self-speech as an anchor, leveraging turn-taking behavior and dialogue dynamics to infer conversational partners and suppress others. To enable real-time, on-device operation, we propose a dual-model architecture: a lightweight streaming model runs every 12.5 ms for low-latency extraction of the conversation partners, while a slower model runs less frequently to capture longer-range conversational dynamics. Results on real-world 2- and 3-speaker conversation test sets, collected with binaural egocentric hardware from 11 participants totaling 6.8 hours, show generalization in identifying and isolating conversational partners in multi-conversation settings. Our work marks a step toward hearing assistants that adapt proactively to conversational dynamics and engagement.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Training and Evaluation
|
| 13 |
+
|
| 14 |
+
### 1. Installing Requirements
|
| 15 |
+
|
| 16 |
+
Before training or evaluating the model, please create an environment and install all dependencies:
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
pip install -r requirements.txt
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
### 2. Model Training
|
| 23 |
+
|
| 24 |
+
To train the model, run:
|
| 25 |
+
|
| 26 |
+
```
|
| 27 |
+
python src/train_joint.py --config <path_to_config> --run_dir <path_to_model_checkpoint>
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
To resume training, make sure that `<path_to_model_checkpoint>` points to the same directory used previously, and rerun the command above.
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
### 3. Model Evaluation
|
| 34 |
+
|
| 35 |
+
To evaluate the model, run:
|
| 36 |
+
|
| 37 |
+
```
|
| 38 |
+
python eval.py <path to testing dataset> <path to model checkpoint> --use_cuda --save
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
## Citation
|
| 43 |
+
|
| 44 |
+
If you use our work, please cite:
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
@inproceedings{hu2025proactive,
|
| 48 |
+
title={Proactive Hearing Assistants that Isolate Egocentric Conversations},
|
| 49 |
+
author={Hu, Guilin and Itani, Malek and Chen, Tuochao and Gollakota, Shyamnath},
|
| 50 |
+
booktitle={Proceedings of the 2025 Conference on Empirical Methods in Natural Language Processing},
|
| 51 |
+
pages={25377--25394},
|
| 52 |
+
year={2025}
|
| 53 |
+
}
|
| 54 |
+
```
|
config/model_config.json
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"project_name": "magic_hear",
|
| 3 |
+
"pl_module": "src.hl_module.joint_train_hl_module_new.PLModule",
|
| 4 |
+
"pl_module_args": {
|
| 5 |
+
"freeze_model1": false,
|
| 6 |
+
"metrics": [
|
| 7 |
+
"snr_i",
|
| 8 |
+
"si_snr_i",
|
| 9 |
+
"si_sdr_i"
|
| 10 |
+
],
|
| 11 |
+
"model": "src.models.network.net_conversation_joint.Net_Conversation",
|
| 12 |
+
"model_params": {
|
| 13 |
+
"model1_block_name": "src.models.blocks.model1_block.GridNetBlock",
|
| 14 |
+
"num_layers_model1": 6,
|
| 15 |
+
"latent_dim_model1": 32,
|
| 16 |
+
"use_speaker_emb_model1": false,
|
| 17 |
+
"use_self_speech_model2": false,
|
| 18 |
+
"one_emb_model1": true,
|
| 19 |
+
"model1_block_params": {
|
| 20 |
+
"emb_ks": 2,
|
| 21 |
+
"emb_hs": 2,
|
| 22 |
+
"hidden_channels": 64,
|
| 23 |
+
"n_head": 4
|
| 24 |
+
},
|
| 25 |
+
"model2_block_name": "src.models.blocks.model2_block.GridNetBlock",
|
| 26 |
+
"num_layers_model2": 6,
|
| 27 |
+
"latent_dim_model2": 32,
|
| 28 |
+
"lstm_fold_chunk": 80,
|
| 29 |
+
"model2_block_params": {
|
| 30 |
+
"emb_ks": 1,
|
| 31 |
+
"emb_hs": 1,
|
| 32 |
+
"hidden_channels": 64,
|
| 33 |
+
"n_head": 4,
|
| 34 |
+
"use_attention": false
|
| 35 |
+
},
|
| 36 |
+
"stft_chunk_size": 200,
|
| 37 |
+
"stft_pad_size": 32,
|
| 38 |
+
"stft_back_pad": 32,
|
| 39 |
+
"num_input_channels": 1,
|
| 40 |
+
"num_output_channels": 1,
|
| 41 |
+
"num_sources": 1,
|
| 42 |
+
"use_sp_feats": false,
|
| 43 |
+
"use_first_ln": true,
|
| 44 |
+
"n_imics": 1,
|
| 45 |
+
"window": "rect",
|
| 46 |
+
"E": 2
|
| 47 |
+
},
|
| 48 |
+
"loss": "src.losses.SNRLP.SNRLPLoss",
|
| 49 |
+
"loss_params": {
|
| 50 |
+
"snr_loss_name": "snr",
|
| 51 |
+
"neg_weight": 100
|
| 52 |
+
},
|
| 53 |
+
"optimizer": "torch.optim.AdamW",
|
| 54 |
+
"optimizer_params": {
|
| 55 |
+
"lr": 2e-3
|
| 56 |
+
},
|
| 57 |
+
"scheduler": "torch.optim.lr_scheduler.ReduceLROnPlateau",
|
| 58 |
+
"scheduler_params": {
|
| 59 |
+
"mode": "min",
|
| 60 |
+
"patience": 4,
|
| 61 |
+
"factor": 0.5,
|
| 62 |
+
"min_lr": 1e-6
|
| 63 |
+
},
|
| 64 |
+
"sr": 16000,
|
| 65 |
+
"grad_clip": 1,
|
| 66 |
+
"use_dp": true
|
| 67 |
+
},
|
| 68 |
+
"train_dataset": "src.datasets.joint_training_dataset.Dataset",
|
| 69 |
+
"train_data_args": {
|
| 70 |
+
"input_dir": [],
|
| 71 |
+
"output_conversation": 1,
|
| 72 |
+
"batch_size": 4,
|
| 73 |
+
"clean_embed": true,
|
| 74 |
+
"random_audio_length": 160000,
|
| 75 |
+
"required_first_speaker_as_self_speech": true,
|
| 76 |
+
"spk_emb_exist": false
|
| 77 |
+
},
|
| 78 |
+
"val_dataset": "src.datasets.joint_training_dataset.Dataset",
|
| 79 |
+
"val_data_args": {
|
| 80 |
+
"input_dir": [],
|
| 81 |
+
"output_conversation": 1,
|
| 82 |
+
"batch_size": 4,
|
| 83 |
+
"clean_embed": true,
|
| 84 |
+
"random_audio_length": 160000,
|
| 85 |
+
"required_first_speaker_as_self_speech": true,
|
| 86 |
+
"spk_emb_exist": false
|
| 87 |
+
},
|
| 88 |
+
"epochs": 130,
|
| 89 |
+
"batch_size": 4,
|
| 90 |
+
"eval_batch_size": 4,
|
| 91 |
+
"num_workers": 12
|
| 92 |
+
}
|
eval.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.metrics.metrics import Metrics
|
| 2 |
+
import src.utils as utils
|
| 3 |
+
import argparse
|
| 4 |
+
import os, json, glob
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import torchaudio
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import copy
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torchmetrics.functional import signal_noise_ratio as snr
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def mod_pad(x, chunk_size, pad):
    """Zero-pad x so its last dim is a multiple of chunk_size, then apply pad.

    Returns the padded tensor and the number of samples that were appended to
    reach the chunk boundary (before the extra front/back `pad` is applied).
    """
    remainder = x.shape[-1] % chunk_size
    extra = chunk_size - remainder if remainder else 0

    padded = F.pad(F.pad(x, (0, extra)), pad)
    return padded, extra
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class LayerNormPermuted(nn.LayerNorm):
    """LayerNorm over the channel dim of a channel-first [B, C, T, F] tensor.

    nn.LayerNorm normalizes trailing dims, so the input is moved to
    channel-last, normalized, and moved back.
    """

    def forward(self, x):
        """Normalize x of shape [B, C, T, F] over C; shape is preserved."""
        channels_last = x.permute(0, 2, 3, 1)  # [B, T, F, C]
        normed = super().forward(channels_last)
        return normed.permute(0, 3, 1, 2)  # back to [B, C, T, F]
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def save_audio_file_torch(file_path, wavform, sample_rate=16000, rescale=False):
    """Write wavform (channels, samples) to file_path via torchaudio.

    When `rescale` is set, the signal is normalized so its absolute peak sits
    at 0.9 to avoid clipping. Bug fix: the original divided by
    torch.max(wavform), which ignores negative peaks — a predominantly
    negative signal would be sign-flipped and/or still clip. Use the absolute
    peak instead.
    """
    if rescale:
        wavform = wavform / torch.max(torch.abs(wavform)) * 0.9
    torchaudio.save(file_path, wavform, sample_rate)
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_mixture_and_gt(curr_dir, rng, SHIFT_VALUE=0, noise_audio_list=[]):
    """Load one test-case directory and assemble (inputs, targets, metadata).

    Reads the wearer's self-speech, sums partner speech, scales interference
    to a random SNR in [-10, 10] dB, and optionally adds background noise.

    NOTE(review): `noise_audio_list=[]` is a mutable default; it is only read
    here, so it is harmless, but confirm no caller relies on mutation.
    `SHIFT_VALUE` is accepted but unused in this function.
    """
    metadata2 = utils.read_json(os.path.join(curr_dir, "metadata.json"))
    diags = metadata2["target_dialogue"]

    # Self-speech anchor: prefer "self_speech.wav", fall back to the original
    # recording. NOTE(review): if neither file exists, self_speech is never
    # assigned and the next statement raises NameError — confirm every test
    # case ships one of the two files.
    if os.path.exists(os.path.join(curr_dir, "self_speech.wav")):
        self_speech = utils.read_audio_file_torch(os.path.join(curr_dir, "self_speech.wav"), 1)
    elif os.path.exists(os.path.join(curr_dir, "self_speech_original.wav")):
        self_speech = utils.read_audio_file_torch(os.path.join(curr_dir, "self_speech_original.wav"), 1)

    other_speech = torch.zeros_like(self_speech)

    # Sum every conversation partner: one target_speech{i}.wav per dialogue
    # entry except the wearer (hence len(diags) - 1).
    for i in range(len(diags) - 1):
        wav = utils.read_audio_file_torch(os.path.join(curr_dir, f"target_speech{i}.wav"), 1)
        other_speech += wav

    # Interference: either a single combined file, or two per-speaker files.
    # (The "intereference" spelling matches the on-disk file names; do not fix.)
    if os.path.exists(os.path.join(curr_dir, f"intereference.wav")):
        interfere = utils.read_audio_file_torch(os.path.join(curr_dir, f"intereference.wav"), 1)
    else:
        interfere = torch.zeros_like(self_speech)
        interfere += utils.read_audio_file_torch(os.path.join(curr_dir, f"intereference0.wav"), 1)
        interfere += utils.read_audio_file_torch(os.path.join(curr_dir, f"intereference1.wav"), 1)

    # Ground truth = wearer + partners; interference scaled to a random SNR.
    gt = self_speech + other_speech
    tgt_snr = rng.uniform(-10, 10)
    interfere = scale_noise_to_snr(gt, interfere, tgt_snr)

    mixture = gt + interfere

    # Optional background noise at a random gain drawn from [0, 1).
    if noise_audio_list != []:
        print("added noise")
        noise_audio = noise_sample(noise_audio_list, mixture.shape[-1], rng)
        wham_scale = rng.uniform(0, 1)
        mixture += noise_audio * wham_scale

    # Speaker embedding, if precomputed; zeros otherwise.
    # NOTE(review): embed.pt appears to hold a numpy array of length 256
    # (a speaker d-vector, presumably) — confirm against the model's input.
    embed_path = os.path.join(curr_dir, "embed.pt")
    if os.path.exists(embed_path):
        embed = torch.load(embed_path, weights_only=False)
        embed = torch.from_numpy(embed)
    else:
        embed = torch.zeros(256)

    L = mixture.shape[-1]  # NOTE(review): unused local

    # Jointly peak-normalize if the mixture would clip, so SNR relationships
    # between mixture/self/gt are preserved.
    peak = np.abs(mixture).max()
    if peak > 1:
        mixture /= peak
        self_speech /= peak
        gt /= peak

    inputs = {
        "mixture": mixture.float(),
        "embed": embed.float(),
        "self_speech": self_speech[0:1, :].float(),
    }

    # NOTE(review): "self"/"other" are numpy arrays while "target" remains a
    # torch tensor — callers (see main) handle both types.
    targets = {
        "self": self_speech[0:1, :].numpy(),
        "other": other_speech[0:1, :].numpy(),
        "target": gt[0:1, :].float(),
    }

    return inputs, targets, metadata2
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def scale_utterance(audio, timestamp, rng, db_change=7):
    """Randomly rescale individual utterances of `audio` in place.

    For each (start, end) span in `timestamp`, with probability 0.3 a gain
    drawn uniformly from [-db_change, +db_change] dB is applied to the span.
    Returns the (mutated) input tensor.
    """
    for begin, finish in timestamp:
        # Same rng consumption order as before: one draw per span, plus one
        # more only when the span is selected.
        if rng.uniform(0, 1) >= 0.3:
            continue
        gain_db = rng.uniform(-db_change, db_change)
        audio[..., begin:finish] *= 10 ** (gain_db / 20)

    return audio
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def get_snr(target, mixture, EPS=1e-9):
    """Return the mean SNR (dB) of `mixture` w.r.t. `target` over channels.

    NOTE: EPS is kept for interface compatibility but is not used; the
    torchmetrics implementation handles numerical stability itself.
    """
    per_channel = snr(mixture, target)
    return per_channel.mean()
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def scale_noise_to_snr(target_speech: torch.Tensor, noise: torch.Tensor, target_snr: float):
    """Scale `noise` so that target_speech + k*noise hits `target_snr` dB.

    Scaling the noise amplitude by k shifts the SNR by -20*log10(k), so the
    gain needed to move from the measured SNR to the target is
    10 ** ((measured - target) / 20).
    """
    measured_snr = get_snr(target_speech, noise + target_speech)

    gain = 10 ** ((measured_snr - target_snr) / 20)

    return gain * noise
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def run_testcase(model, inputs, device) -> np.ndarray:
    """Run `model` on one prepared example and return the separated audio.

    NOTE(review): mutates `inputs` in place (moves tensors to `device`, adds
    a batch dim and start/end indices) — callers should not reuse the dict
    afterwards expecting the original tensors.
    """
    with torch.inference_mode():
        # Keep only the first channel and add a batch dimension.
        inputs["mixture"] = inputs["mixture"][0:1, ...].unsqueeze(0).to(device)
        inputs["embed"] = inputs["embed"].unsqueeze(0).to(device)
        inputs["self_speech"] = inputs["self_speech"][0:1, ...].unsqueeze(0).to(device)

        # Process the full clip in one pass (no streaming chunks here).
        inputs["start_idx"] = 0
        inputs["end_idx"] = inputs["mixture"].shape[-1]
        outputs = model(inputs)

        # Drop the batch dim; move to CPU for numpy-based metric code.
        output_target = outputs["output"].squeeze(0)

        final_output = output_target.cpu().numpy()

        return final_output
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def get_timestamp_mask(timestamps, mask_shape):
    """Return a float mask of `mask_shape` set to 1 inside each (start, end) span.

    Spans index the last dimension; everything outside them stays 0.
    """
    mask = torch.zeros(mask_shape)
    for begin, finish in timestamps:
        mask[..., begin:finish] = 1

    return mask
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def noise_sample(noise_file_list, audio_length, rng: np.random.RandomState):
    """Concatenate randomly chosen noise files into `audio_length` samples.

    Files are drawn (with replacement) via `rng`, truncated to their first
    channel, and resampled to 16 kHz when needed. The concatenation is then
    cut to exactly `audio_length` samples.

    NOTE: target sample rate is hardcoded; the original comment assumed
    48 kHz noise sources and a 16 kHz target.
    """
    target_sr = 16000

    pieces = []
    collected = 0
    while collected <= audio_length:
        chosen_file = rng.choice(noise_file_list)
        source_sr = torchaudio.info(chosen_file).sample_rate

        wav, _ = torchaudio.load(chosen_file)
        wav = wav[0:1, ...]  # first channel only

        if source_sr != target_sr:
            wav = torchaudio.transforms.Resample(orig_freq=source_sr, new_freq=target_sr)(wav)

        pieces.append(wav)
        collected += wav.shape[-1]

    result = torch.cat(pieces, dim=1)[..., :audio_length]

    assert result.shape[1] == audio_length

    return result
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def main(args: argparse.Namespace):
    """Evaluate a trained model over up to 200 test-case directories.

    For each case, builds a mixture, runs the model, and records total /
    self-speech / other-speaker SI-SDR for both the raw mixture and the
    prediction. Results go to a per-dataset CSV; audio is optionally saved.
    """
    device = "cuda" if args.use_cuda else "cpu"

    # Load model
    model = utils.load_torch_pretrained(args.run_dir).model
    model_name = args.run_dir.split("/")[-1]
    model = model.to(device)
    model.eval()

    # Initialize metrics
    # NOTE(review): this local `snr` shadows the module-level torchmetrics
    # `snr` import; `snr` and `snr_i` are created but never used below.
    snr = Metrics("snr")
    snr_i = Metrics("snr_i")

    si_sdr = Metrics("si_sdr")

    records = []

    # Collect optional background-noise files (e.g. WHAM!).
    noise_audio_list = []
    if args.noise_dir is not None:
        noise_audio_sublist = glob.glob(os.path.join(args.noise_dir, "*.wav"))
        if not noise_audio_sublist:
            print("no noise file found")
        noise_audio_list.extend(noise_audio_sublist)

    # Test cases live in zero-padded subdirectories 00000..00199; missing
    # directories are skipped. Seeding the RNG with the case index makes
    # mixture construction reproducible per case.
    for i in range(0, 200):
        rng = np.random.RandomState(i)
        dataset_name = os.path.basename(args.test_dir)
        curr_dir = os.path.join(args.test_dir, "{:05d}".format(i))

        meta_dir = os.path.join(curr_dir, "metadata.json")

        if not os.path.exists(meta_dir):
            continue

        inputs, targets, metadata = get_mixture_and_gt(curr_dir, rng, noise_audio_list=noise_audio_list)

        if inputs is None:
            continue

        self_timestamps = metadata["target_dialogue"][0]["timestamp"]

        target_speech = targets["target"].cpu().numpy()
        row = {"test_case_index": i}
        mixture = inputs["mixture"].cpu().numpy()

        self_speech = inputs["self_speech"].squeeze(0).cpu().numpy()

        # Keep only the first channel before inference / metric computation.
        inputs["mixture"] = inputs["mixture"][0:1, ...]
        target_speech = target_speech[0:1, ...]

        # NOTE: run_testcase mutates `inputs` in place (device move, batching).
        output_target = run_testcase(model, inputs, device)

        # Mask selecting the wearer's own speech regions; the first second
        # (args.sr samples) is zeroed out, presumably as model warm-up —
        # TODO confirm.
        self_timestamps = metadata["target_dialogue"][0]["timestamp"]
        self_mask = get_timestamp_mask(self_timestamps, target_speech.shape)
        self_mask[..., : args.sr] = 0

        if mixture.ndim == 1:
            mixture = mixture[np.newaxis, ...]

        # Overall SI-SDR of the unprocessed mixture vs. the model output.
        total_input_sisdr = si_sdr(est=mixture[0:1], gt=target_speech, mix=mixture[0:1]).item()
        total_output_sisdr = si_sdr(est=output_target, gt=target_speech, mix=mixture[0:1]).item()

        row[f"sisdr_input_total"] = total_input_sisdr
        row[f"sisdr_output_total"] = total_output_sisdr

        # self

        # SI-SDR restricted to the wearer's own speech regions.
        self_sisdr_mix = si_sdr(
            est=self_mask * mixture[:1], gt=self_mask * target_speech, mix=self_mask * mixture[:1]
        ).item()
        self_sisdr_pred = si_sdr(
            est=self_mask * output_target, gt=self_mask * target_speech, mix=self_mask * mixture[:1]
        ).item()

        row[f"sisdr_mix_self"] = self_sisdr_mix
        row[f"sisdr_pred_self"] = self_sisdr_pred

        # ======other speaker======

        # Union of all conversation partners' speech regions (entries 1..N).
        other_timestamps = metadata["target_dialogue"][1]["timestamp"]
        if len(metadata["target_dialogue"]) > 2:
            for j in range(2, len(metadata["target_dialogue"])):
                timestamp = metadata["target_dialogue"][j]["timestamp"]
                other_timestamps = other_timestamps + timestamp

        other_mask = get_timestamp_mask(other_timestamps, target_speech.shape)
        other_mask[..., : args.sr] = 0

        other_sisdr_mix = si_sdr(
            est=other_mask * mixture[:1], gt=other_mask * target_speech, mix=other_mask * mixture[:1]
        ).item()
        other_sisdr_pred = si_sdr(
            est=other_mask * output_target, gt=other_mask * target_speech, mix=other_mask * mixture[:1]
        ).item()

        row[f"sisdr_mix_other"] = other_sisdr_mix
        row[f"sisdr_pred_other"] = other_sisdr_pred

        print(i)
        records.append(row)

        # Output layout mirrors whether background noise was mixed in.
        if noise_audio_list != []:
            save_folder = f"./result_{dataset_name}_noise/{model_name}/{i}"
        else:
            save_folder = f"./result_{dataset_name}/{model_name}/{i}"
        os.makedirs(save_folder, exist_ok=True)

        if type(self_speech) == np.ndarray:
            self_speech = torch.from_numpy(self_speech)

        if self_speech.dim() == 1:
            self_speech = self_speech.unsqueeze(0)

        if args.save:
            save_audio_file_torch(
                f"{save_folder}/mix.wav", torch.from_numpy(mixture[0:1]), sample_rate=args.sr, rescale=False
            )
            save_audio_file_torch(f"{save_folder}/self.wav", self_speech, sample_rate=args.sr, rescale=False)
            save_audio_file_torch(
                f"{save_folder}/output_target.wav", torch.from_numpy(output_target), sample_rate=args.sr, rescale=False
            )
            save_audio_file_torch(
                f"{save_folder}/target_speech.wav", torch.from_numpy(target_speech), sample_rate=args.sr, rescale=False
            )

    results_df = pd.DataFrame.from_records(records)

    # Put the case index first for readability.
    columns = ["test_case_index"] + [col for col in results_df.columns if col != "test_case_index"]
    results_df = results_df[columns]

    if noise_audio_list != []:
        results_csv_path = f"./result_{dataset_name}_noise/{model_name}_multi.csv"
    else:
        results_csv_path = f"./result_{dataset_name}/{model_name}_multi.csv"
    results_df.to_csv(results_csv_path, index=False)
| 330 |
+
|
| 331 |
+
|
| 332 |
+
if __name__ == "__main__":
    # CLI: two positional paths (test set, checkpoint) plus optional flags.
    parser = argparse.ArgumentParser()
    parser.add_argument("test_dir", type=str, help="Path to test dataset")

    parser.add_argument("run_dir", type=str, help="Path to model run checkpoint")

    parser.add_argument("--sr", type=int, default=16000, help="Project sampling rate")

    parser.add_argument("--noise_dir", type=str, default=None, help="Wham noise directory")

    parser.add_argument("--use_cuda", action="store_true", help="Whether to use cuda")

    parser.add_argument("--save", action="store_true", help="Whether to save output audio")

    main(parser.parse_args())
|
requirements.txt
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.3.1
|
| 2 |
+
aiohappyeyeballs==2.6.1
|
| 3 |
+
aiohttp==3.11.16
|
| 4 |
+
aiosignal==1.3.2
|
| 5 |
+
annotated-types==0.7.0
|
| 6 |
+
antlr4-python3-runtime==4.9.3
|
| 7 |
+
asteroid==0.7.0
|
| 8 |
+
asteroid-filterbanks==0.4.0
|
| 9 |
+
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
|
| 10 |
+
async-timeout==5.0.1
|
| 11 |
+
attrs==25.3.0
|
| 12 |
+
audioread==3.0.1
|
| 13 |
+
auraloss==0.4.0
|
| 14 |
+
beautifulsoup4==4.13.4
|
| 15 |
+
cached-property==2.0.1
|
| 16 |
+
certifi==2025.1.31
|
| 17 |
+
cffi==1.17.1
|
| 18 |
+
cftime==1.6.4.post1
|
| 19 |
+
charset-normalizer==3.4.1
|
| 20 |
+
ci_sdr==0.0.2
|
| 21 |
+
click==8.1.8
|
| 22 |
+
coloredlogs==15.0.1
|
| 23 |
+
comm @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_comm_1753453984/work
|
| 24 |
+
ConfigArgParse==1.7
|
| 25 |
+
contourpy==1.3.0
|
| 26 |
+
ctc_segmentation==1.7.4
|
| 27 |
+
cycler==0.12.1
|
| 28 |
+
Cython==3.0.12
|
| 29 |
+
DateTime==5.5
|
| 30 |
+
debugpy @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_debugpy_1752827114/work
|
| 31 |
+
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
|
| 32 |
+
Distance==0.1.3
|
| 33 |
+
docker-pycreds==0.4.0
|
| 34 |
+
editdistance==0.8.1
|
| 35 |
+
einops==0.8.1
|
| 36 |
+
espnet==202412
|
| 37 |
+
espnet-tts-frontend==0.0.3
|
| 38 |
+
eval_type_backport==0.2.2
|
| 39 |
+
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
|
| 40 |
+
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
|
| 41 |
+
fast_bss_eval==0.1.3
|
| 42 |
+
filelock==3.18.0
|
| 43 |
+
flatbuffers==25.2.10
|
| 44 |
+
fonttools==4.57.0
|
| 45 |
+
frozenlist==1.5.0
|
| 46 |
+
fsspec==2025.3.2
|
| 47 |
+
g2p-en==2.1.0
|
| 48 |
+
gdown==5.2.0
|
| 49 |
+
gitdb==4.0.12
|
| 50 |
+
GitPython==3.1.44
|
| 51 |
+
grpcio==1.74.0
|
| 52 |
+
h5py==3.13.0
|
| 53 |
+
huggingface-hub==0.30.2
|
| 54 |
+
humanfriendly==10.0
|
| 55 |
+
hydra-core==1.3.2
|
| 56 |
+
HyperPyYAML==1.2.2
|
| 57 |
+
idna==3.10
|
| 58 |
+
importlib-metadata==4.13.0
|
| 59 |
+
importlib_resources==6.5.2
|
| 60 |
+
inflect==7.5.0
|
| 61 |
+
intervaltree==3.1.0
|
| 62 |
+
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1753749834440/work
|
| 63 |
+
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1701831663892/work
|
| 64 |
+
jaconv==0.4.0
|
| 65 |
+
jamo==0.4.1
|
| 66 |
+
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
|
| 67 |
+
Jinja2==3.1.6
|
| 68 |
+
jiwer==4.0.0
|
| 69 |
+
joblib==1.4.2
|
| 70 |
+
julius==0.2.7
|
| 71 |
+
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
|
| 72 |
+
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1748333051527/work
|
| 73 |
+
kaldiio==2.18.1
|
| 74 |
+
kiwisolver==1.4.7
|
| 75 |
+
lazy_loader==0.4
|
| 76 |
+
librosa==0.9.2
|
| 77 |
+
lightning-utilities==0.14.3
|
| 78 |
+
llvmlite==0.43.0
|
| 79 |
+
Markdown==3.9
|
| 80 |
+
MarkupSafe==3.0.2
|
| 81 |
+
matplotlib==3.9.4
|
| 82 |
+
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
|
| 83 |
+
mir_eval==0.8.2
|
| 84 |
+
more-itertools==10.6.0
|
| 85 |
+
mpmath==1.3.0
|
| 86 |
+
msgpack==1.1.0
|
| 87 |
+
multidict==6.4.3
|
| 88 |
+
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
|
| 89 |
+
netCDF4==1.7.2
|
| 90 |
+
networkx==3.2.1
|
| 91 |
+
nltk==3.9.1
|
| 92 |
+
noisereduce==3.0.3
|
| 93 |
+
numba==0.60.0
|
| 94 |
+
numpy==1.23.5
|
| 95 |
+
nvidia-cublas-cu12==12.4.5.8
|
| 96 |
+
nvidia-cuda-cupti-cu12==12.4.127
|
| 97 |
+
nvidia-cuda-nvrtc-cu12==12.4.127
|
| 98 |
+
nvidia-cuda-runtime-cu12==12.4.127
|
| 99 |
+
nvidia-cudnn-cu12==9.1.0.70
|
| 100 |
+
nvidia-cufft-cu12==11.2.1.3
|
| 101 |
+
nvidia-curand-cu12==10.3.5.147
|
| 102 |
+
nvidia-cusolver-cu12==11.6.1.9
|
| 103 |
+
nvidia-cusparse-cu12==12.3.1.170
|
| 104 |
+
nvidia-cusparselt-cu12==0.6.2
|
| 105 |
+
nvidia-nccl-cu12==2.21.5
|
| 106 |
+
nvidia-nvjitlink-cu12==12.4.127
|
| 107 |
+
nvidia-nvtx-cu12==12.4.127
|
| 108 |
+
omegaconf==2.3.0
|
| 109 |
+
onnxruntime==1.19.2
|
| 110 |
+
openai-whisper==20250625
|
| 111 |
+
opt_einsum==3.4.0
|
| 112 |
+
packaging==24.2
|
| 113 |
+
pandas==2.2.3
|
| 114 |
+
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
|
| 115 |
+
pb-bss-eval==0.0.2
|
| 116 |
+
pesq==0.0.4
|
| 117 |
+
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
|
| 118 |
+
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
|
| 119 |
+
pillow==11.2.1
|
| 120 |
+
platformdirs==4.3.7
|
| 121 |
+
pooch==1.8.2
|
| 122 |
+
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
|
| 123 |
+
propcache==0.3.1
|
| 124 |
+
protobuf==5.29.4
|
| 125 |
+
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1740663125313/work
|
| 126 |
+
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
|
| 127 |
+
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
|
| 128 |
+
pybind11==2.13.6
|
| 129 |
+
pycparser==2.22
|
| 130 |
+
pydantic==2.11.3
|
| 131 |
+
pydantic_core==2.33.1
|
| 132 |
+
pydub==0.25.1
|
| 133 |
+
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1750615794071/work
|
| 134 |
+
pyparsing==3.2.3
|
| 135 |
+
pypinyin==0.44.0
|
| 136 |
+
pyroomacoustics==0.8.3
|
| 137 |
+
PySocks==1.7.1
|
| 138 |
+
pystoi==0.4.1
|
| 139 |
+
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dateutil_1751104122/work
|
| 140 |
+
python-sofa==0.2.0
|
| 141 |
+
pytorch-lightning==2.5.1
|
| 142 |
+
pytorch-ranger==0.1.1
|
| 143 |
+
pytz==2025.2
|
| 144 |
+
pyworld==0.3.5
|
| 145 |
+
PyYAML==6.0.2
|
| 146 |
+
pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1749898437650/work
|
| 147 |
+
RapidFuzz==3.13.0
|
| 148 |
+
regex==2024.11.6
|
| 149 |
+
requests==2.32.3
|
| 150 |
+
resampy==0.4.3
|
| 151 |
+
Resemblyzer==0.1.4
|
| 152 |
+
ruamel.yaml==0.18.15
|
| 153 |
+
ruamel.yaml.clib==0.2.12
|
| 154 |
+
safetensors==0.5.3
|
| 155 |
+
scikit-learn==1.6.1
|
| 156 |
+
scipy==1.13.1
|
| 157 |
+
sentencepiece==0.1.97
|
| 158 |
+
sentry-sdk==2.26.0
|
| 159 |
+
setproctitle==1.3.5
|
| 160 |
+
silero-vad==5.1.2
|
| 161 |
+
six @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_six_1753199211/work
|
| 162 |
+
smmap==5.0.2
|
| 163 |
+
sortedcontainers==2.4.0
|
| 164 |
+
soundfile==0.13.1
|
| 165 |
+
soupsieve==2.7
|
| 166 |
+
sox==1.5.0
|
| 167 |
+
soxbindings==1.2.3
|
| 168 |
+
soxr==0.5.0.post1
|
| 169 |
+
stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
|
| 170 |
+
sympy==1.13.1
|
| 171 |
+
tensorboard==2.20.0
|
| 172 |
+
tensorboard-data-server==0.7.2
|
| 173 |
+
tensorboardX==2.6.4
|
| 174 |
+
threadpoolctl==3.6.0
|
| 175 |
+
tiktoken==0.9.0
|
| 176 |
+
tokenizers==0.21.1
|
| 177 |
+
torch==2.6.0
|
| 178 |
+
torch-complex==0.4.4
|
| 179 |
+
torch-optimizer==0.1.0
|
| 180 |
+
torch-stoi==0.2.3
|
| 181 |
+
torchaudio==2.6.0
|
| 182 |
+
torchmetrics==0.11.4
|
| 183 |
+
torchvision==0.21.0
|
| 184 |
+
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1748003328568/work
|
| 185 |
+
tqdm==4.67.1
|
| 186 |
+
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
|
| 187 |
+
transformers==4.51.3
|
| 188 |
+
triton==3.2.0
|
| 189 |
+
typeguard==4.4.2
|
| 190 |
+
typing==3.7.4.3
|
| 191 |
+
typing-inspection==0.4.0
|
| 192 |
+
typing_extensions==4.13.2
|
| 193 |
+
tzdata==2025.2
|
| 194 |
+
Unidecode==1.3.8
|
| 195 |
+
urllib3==2.4.0
|
| 196 |
+
wandb==0.19.9
|
| 197 |
+
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
|
| 198 |
+
webrtcvad==2.0.10
|
| 199 |
+
Werkzeug==3.1.3
|
| 200 |
+
yarl==1.19.0
|
| 201 |
+
zipp==3.21.0
|
| 202 |
+
zope.interface==7.2
|
src/datasets/joint_training_dataset.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Torch dataset object for synthetically rendered
|
| 3 |
+
spatial data
|
| 4 |
+
"""
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
from typing import Tuple
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
import os, glob
|
| 13 |
+
|
| 14 |
+
import src.utils as utils
|
| 15 |
+
from .noise import WhitePinkBrownAugmentation
|
| 16 |
+
import torchaudio
|
| 17 |
+
from torchmetrics.functional import signal_noise_ratio as snr
|
| 18 |
+
from torch.utils.data._utils.collate import default_collate
|
| 19 |
+
|
| 20 |
+
MAX_LEN = 50
|
| 21 |
+
|
| 22 |
+
def save_audio_file_torch(file_path, wavform, sample_rate=16000, rescale=False):
    """Write a waveform tensor to ``file_path`` via torchaudio.

    Args:
        file_path: destination path for the audio file.
        wavform: (channels, samples) tensor to save.
        sample_rate: sampling rate written to the file header.
        rescale: if True, peak-normalize to 0.9 before saving.

    Fix: peak normalization now uses the absolute peak. The previous
    ``wavform / torch.max(wavform)`` used the signed maximum, which can
    blow up or flip the sign of a signal whose largest-magnitude sample
    is negative.
    """
    if rescale:
        peak = torch.max(torch.abs(wavform))
        wavform = wavform / peak * 0.9
    torchaudio.save(file_path, wavform, sample_rate)
|
| 26 |
+
|
| 27 |
+
def perturb_amplitude_db(audio, db_change=10):
    """Apply a random gain drawn uniformly from [-db_change, db_change] dB.

    Note: draws from numpy's *global* RNG (``np.random``), so the result
    depends on the global seed state.
    """
    gain_db = np.random.uniform(-db_change, db_change)
    return audio * (10 ** (gain_db / 20))
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def scale_to_tgt_pwr(audio: np.ndarray, timestamp, tgt_pwr_dB: float, EPS=1e-9):
    """Scale ``audio`` so the mean power over the active (timestamped)
    segments equals ``tgt_pwr_dB`` (dB). Returns the scaled audio.

    timestamp: iterable of (start, end) sample indices marking speech activity.
    """
    # Gather the active regions, clamping each window to the signal bounds.
    active_parts = [
        audio[..., max(0, s):min(audio.size(-1), e)] for s, e in timestamp
    ]
    active = torch.cat(active_parts, dim=-1)

    # Current mean power of the active regions, in dB.
    mean_pwr_dB = 10 * torch.log10(torch.mean(active ** 2) + EPS)
    gain = 10 ** ((tgt_pwr_dB - mean_pwr_dB) / 20)

    scaled = gain * audio

    # Sanity check: rescaled active power must land within 0.1 dB of target.
    check_dB = 10 * torch.log10(torch.mean((gain * active) ** 2) + EPS)
    assert torch.abs(tgt_pwr_dB - check_dB) < 0.1

    return scaled
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def scale_utterance(audio, timestamp, rng, db_change=7):
    """Randomly re-scale individual utterances, in place.

    For each (start, end) window, with probability 0.3 a gain drawn uniformly
    from [-db_change, db_change] dB is applied to that window. The mutated
    ``audio`` tensor is returned for convenience.
    """
    for start, end in timestamp:
        # Same rng draw order as always: gate first, then (only if the gate
        # fires) the gain for this utterance.
        if rng.uniform(0, 1) < 0.3:
            gain_db = rng.uniform(-db_change, db_change)
            audio[..., start:end] *= 10 ** (gain_db / 20)

    return audio
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_snr(target, mixture, EPS=1e-9):
    """Average SNR (dB) of ``mixture`` against ``target`` across all channels.

    Note: ``EPS`` is accepted for interface compatibility but unused here;
    torchmetrics' ``signal_noise_ratio`` handles stabilization internally.
    """
    channel_snrs = snr(mixture, target)
    return channel_snrs.mean()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def scale_noise_to_snr(target_speech: torch.Tensor, noise: torch.Tensor, target_snr: float):
    """Rescale a BINAURAL noise signal so the mixture reaches ``target_snr`` dB.

    Derivation: scaling the noise by k lowers the channel-averaged SNR by
    20*log10(k), i.e.
        SNR_tgt = avg_snr_initial - 20*log10(k)
    so the required gain is k = 10 ** ((avg_snr_initial - SNR_tgt) / 20).
    """
    snr_now = get_snr(target_speech, noise + target_speech)
    gain = 10 ** ((snr_now - target_snr) / 20)
    return gain * noise
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def custom_collate_fn(batch):
    """Collate (inputs, targets) pairs while keeping 'self_timestamp' ragged.

    batch: list of (inputs_dict, targets_dict) tuples.
    Every inputs entry except 'self_timestamp' is stacked with
    ``default_collate``; 'self_timestamp' stays a plain list of per-sample
    timestamp lists because its length varies between samples. Targets are
    fixed-length and collated normally.
    """
    inputs = [pair[0] for pair in batch]
    targets = [pair[1] for pair in batch]

    collated_inputs = {
        key: ([d[key] for d in inputs] if key == 'self_timestamp'
              else default_collate([d[key] for d in inputs]))
        for key in inputs[0]
    }

    return collated_inputs, default_collate(targets)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class Dataset(torch.utils.data.Dataset):
    """
    Dataset of mixed waveforms and their corresponding ground-truth waveforms.

    Each datapoint is a pair of dicts:
      inputs  - 'mixture' (1 x T), 'embed' (speaker embedding),
                'self_speech' (1 x T), plus a random crop window
                ('start_idx_list' / 'end_idx_list').
      targets - 'target' (1 x T) ground-truth speech.

    Each scenario is a folder containing metadata.json and the rendered wav
    files; one datapoint is produced per scenario per epoch.
    """

    def __init__(self, input_dir, n_mics=1, sr=8000,
                 sig_len=30, downsample=1,
                 split='val', output_conversation=0,
                 batch_size=8,
                 clean_embed=False,
                 noise_dir=None,
                 random_audio_length=800,
                 required_first_speaker_as_self_speech=True,
                 spk_emb_exist=True,
                 amplitude_aug_range=0,
                 noise_amplitude_aug_range=7,
                 utter_db_aug=7,
                 input_mean="L",
                 min_snr=-10,
                 max_snr=10,
                 original_val=False,
                 apply_timestamp_aug=False,
                 snr_control=True
                 ):
        """
        input_dir: iterable of root directories; scenario folders are the
            numerically-named subfolders containing metadata.json (and
            embed.pt when spk_emb_exist is True).
        noise_dir: optional iterable of folders holding .wav background noise.
        input_mean: "L"/"R" picks one stereo channel; True averages channels.
        min_snr/max_snr: SNR range (dB) used when snr_control is True.
        """
        super().__init__()

        # Collect scenario folders that have the required files.
        self.dirs = []
        self.spk_emb_exist = spk_emb_exist
        for _dir in input_dir:
            dir_list = sorted(list(Path(_dir).glob('[0-9]*')))
            for dest in dir_list:
                meta_path = os.path.join(dest, 'metadata.json')
                embed_path = os.path.join(dest, 'embed.pt')

                if self.spk_emb_exist and os.path.exists(meta_path) and os.path.exists(embed_path):
                    self.dirs.append(dest)
                elif not self.spk_emb_exist and os.path.exists(meta_path):
                    self.dirs.append(dest)

        # Collect background-noise files (e.g. WHAM).
        self.noise_dirs = []
        if noise_dir is not None:
            for sub_dir in noise_dir:
                noise_audio_list = glob.glob(os.path.join(sub_dir, '*.wav'))
                # Fix: this used to test `if not noise_dir:`, which can never
                # be empty inside this loop; warn per-subfolder instead.
                if not noise_audio_list:
                    print("no noise file found")
                self.noise_dirs.extend(noise_audio_list)

        self.clean_embed = clean_embed
        self.n_mics = n_mics
        self.sig_len = int(sig_len * sr / downsample)
        self.sr = sr
        self.downsample = downsample
        self.scales = [-3, 3]
        self.output_conversation = output_conversation
        self.apply_timestamp_aug = apply_timestamp_aug

        # Dataset length is truncated to a whole number of batches.
        self.batch_size = batch_size
        self.split = split
        print(self.split, (len(self.dirs) // batch_size) * batch_size)

        self.random_audio_length = random_audio_length
        self.required_first_speaker_as_self_speech = required_first_speaker_as_self_speech

        self.amplitude_aug_range = amplitude_aug_range
        self.noise_amplitude_aug_range = noise_amplitude_aug_range

        self.pwr_thresh = -60
        self.min_snr = min_snr
        self.max_snr = max_snr
        self.utter_db_aug = utter_db_aug
        self.input_mean = input_mean
        self.original_val = original_val
        self.snr_control = snr_control

    def __len__(self) -> int:
        # Whole batches only: the trailing partial batch is dropped.
        return (len(self.dirs) // self.batch_size) * self.batch_size

    def noise_sample(self, noise_file_list, audio_length, rng: np.random.RandomState):
        """Draw random noise files, resample to 16 kHz, and concatenate until
        at least ``audio_length`` samples exist; return exactly that many.

        NOTE: hardcoded target sample rate of 16 kHz (source files are
        typically 48 kHz).
        """
        target_sr = 16000

        acc_len = 0
        concatenated_audio = None
        while acc_len <= audio_length:
            noise_file = rng.choice(noise_file_list)
            info = torchaudio.info(noise_file)
            noise_sr = info.sample_rate

            noise_wav, _ = torchaudio.load(noise_file)
            # Channel selection mirrors input_mean: "L"/"R" picks a channel,
            # True averages down to mono.
            if noise_wav.shape[0] > 1 and self.input_mean == "L":
                noise_wav = noise_wav[0:1, ...]
            elif noise_wav.shape[0] > 1 and self.input_mean == "R":
                noise_wav = noise_wav[1:2, ...]
            elif noise_wav.shape[0] > 1 and self.input_mean == True:
                noise_wav = torch.mean(noise_wav, dim=0)
                noise_wav = noise_wav.unsqueeze(0)

            if noise_sr != target_sr:
                resampler = torchaudio.transforms.Resample(orig_freq=noise_sr, new_freq=target_sr)
                noise_wav = resampler(noise_wav)

            if concatenated_audio is None:
                concatenated_audio = noise_wav
            else:
                concatenated_audio = torch.cat((concatenated_audio, noise_wav), dim=1)

            acc_len = concatenated_audio.shape[-1]

        concatenated_audio = concatenated_audio[..., :audio_length]

        assert concatenated_audio.shape[1] == audio_length

        return concatenated_audio

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns (inputs, targets) dicts for one scenario; see class docstring.
        """
        # Training re-seeds every access so augmentation varies per epoch;
        # validation seeds with idx so each sample is deterministic.
        if self.split == 'train':
            seed = idx + np.random.randint(1000000)
        else:
            seed = idx
        rng = np.random.RandomState(seed)

        curr_dir = self.dirs[idx % len(self.dirs)]
        return self.get_mixture_and_gt(curr_dir, rng)

    def diffuse_speech_pattern(self, audio: torch.Tensor, timestamps: list, rng: np.random.RandomState, beta=8000):
        """Randomly redistribute the silent gaps between utterances.

        The total silence budget is (approximately) preserved while the
        individual gap lengths are jittered, then utterances are re-laid-out
        over the new gaps.
        """
        # Gap before the first utterance, between consecutive ones, and after
        # the last one.
        zero_segments = np.array(
            [timestamps[0][0]]
            + [timestamps[i + 1][0] - timestamps[i][1] for i in range(len(timestamps) - 1)]
            + [audio.shape[-1] - timestamps[-1][1]]
        )
        total_zeros = sum(zero_segments)

        # Jitter the gaps. NOTE(review): rng.normal is called without a size,
        # so ONE scalar offset is added to every gap — confirm whether
        # independent per-gap noise was intended.
        noise = rng.normal(loc=0, scale=beta)
        zero_segments = zero_segments + noise

        # Ensure all gaps stay positive.
        zero_segments[zero_segments <= 0] = 1

        # Renormalize so total silence matches the original budget.
        zero_segments = zero_segments / zero_segments.sum()
        zero_segments = zero_segments * total_zeros

        # Floor so we never exceed the audio size.
        zero_segments = np.floor(zero_segments).astype(np.int32)

        assert zero_segments.sum() <= total_zeros

        # Re-place every utterance after its (new) leading gap.
        new_audio = torch.zeros_like(audio)
        start_index = 0
        for z, (s, e) in zip(zero_segments[:-1], timestamps):
            start_index += z
            new_audio[..., start_index:start_index + (e - s)] = audio[..., s:e]
            start_index += (e - s)

        return new_audio

    def process_audio(self, audio, timestamp, rng, utter_db_aug, tgt_pwr_dB):
        """Optionally jitter utterance timing, then normalize to tgt_pwr_dB
        and apply per-utterance gain augmentation. Empty timestamp lists pass
        the audio through unchanged (after the optional timing jitter)."""
        if self.apply_timestamp_aug:
            audio = self.diffuse_speech_pattern(audio, timestamp, rng, beta=16000)

        if timestamp == []:
            return audio
        audio = scale_to_tgt_pwr(audio, timestamp, tgt_pwr_dB)
        audio = scale_utterance(audio, timestamp, rng, utter_db_aug)
        return audio

    def get_mixture_and_gt(self, curr_dir, rng):
        """Load one scenario folder and assemble (inputs, targets).

        Pipeline: load self speech (plus optional dry 'original' version),
        interference tracks, and other conversation partners; apply
        per-utterance gain augmentation; optionally add recorded noise and
        synthetic colored noise; optionally control the mixture SNR; peak
        normalize.
        """
        metadata2 = utils.read_json(os.path.join(curr_dir, 'metadata.json'))

        # --- self speech -------------------------------------------------
        self_speech = utils.read_audio_file_torch(os.path.join(curr_dir, 'self_speech.wav'), 1, self.input_mean)
        self_speech_original = None
        if os.path.exists(os.path.join(curr_dir, 'self_speech_original.wav')):
            self_speech_original = utils.read_audio_file_torch(os.path.join(curr_dir, 'self_speech_original.wav'), 1, self.input_mean)

        self_timestamp = metadata2['target_dialogue'][0]['timestamp']

        if self_speech_original is not None:
            # Concatenate along the channel dim so both versions receive the
            # exact same per-utterance gains.
            concat_self_speech = torch.cat([self_speech, self_speech_original], dim=0)
            utterance_adj = scale_utterance(concat_self_speech, self_timestamp, rng, self.utter_db_aug)
            self_speech = utterance_adj[0:1, ...]
            self_speech_original = utterance_adj[1:2, ...]
        else:
            self_speech = scale_utterance(self_speech, self_timestamp, rng, self.utter_db_aug)

        # --- interference speech ----------------------------------------
        # Both the historical misspelling 'intereference*' and the corrected
        # 'interference*' filenames are supported.
        if os.path.exists(os.path.join(curr_dir, 'intereference.wav')):
            interfere = utils.read_audio_file_torch(os.path.join(curr_dir, 'intereference.wav'), 1, self.input_mean)
        else:
            interfers = metadata2["interference"]
            interfere = torch.zeros_like(self_speech)
            if os.path.exists(os.path.join(curr_dir, 'intereference0.wav')):
                prefix = 'intereference'
            elif os.path.exists(os.path.join(curr_dir, 'interference0.wav')):
                prefix = 'interference'
            else:
                prefix = None
            if prefix is not None:
                for i in range(len(interfers)):
                    current_inter = utils.read_audio_file_torch(
                        os.path.join(curr_dir, f'{prefix}{i}.wav'), 1, self.input_mean)
                    inter_timestamp = metadata2['interference'][i]['timestamp']
                    current_inter = scale_utterance(current_inter, inter_timestamp, rng, self.utter_db_aug)
                    interfere += current_inter
        # (An unused local `scale` was previously set here; removed as dead code.)

        # --- other conversation partners --------------------------------
        other_speech = torch.zeros_like(self_speech)
        if self.output_conversation:
            diags = metadata2["target_dialogue"]
            for i in range(len(diags) - 1):
                if os.path.exists(os.path.join(curr_dir, f'target_speech{i}.wav')):
                    wav = utils.read_audio_file_torch(os.path.join(curr_dir, f'target_speech{i}.wav'), 1, self.input_mean)
                elif os.path.exists(os.path.join(curr_dir, f'other_speech{i}.wav')):
                    wav = utils.read_audio_file_torch(os.path.join(curr_dir, f'other_speech{i}.wav'), 1, self.input_mean)
                else:
                    raise Exception("no audio file to load")
                other_timestamp = metadata2['target_dialogue'][i + 1]['timestamp']
                wav = scale_utterance(wav, other_timestamp, rng, self.utter_db_aug)
                other_speech += wav

        # --- recorded background noise (e.g. WHAM) ----------------------
        # NOTE(review): uses the global `random` module rather than `rng`, so
        # this branch is not reproducible per-sample — confirm intent.
        if self.noise_dirs != [] and random.random() < 0.3:
            audio_length = interfere.shape[1]
            noise = self.noise_sample(self.noise_dirs, audio_length, rng)
            wham_scale = rng.uniform(0, 1)
            interfere += noise * wham_scale

        # Ground truth is the dry self speech (when available) plus partners.
        if self_speech_original is not None:
            gt = self_speech_original + other_speech
        else:
            gt = self_speech + other_speech

        mixture = gt + interfere

        if self.snr_control == True:
            # Re-scale the non-target part so the mixture hits a random SNR
            # in [min_snr, max_snr] dB relative to the ground truth.
            tgt_snr = rng.uniform(self.min_snr, self.max_snr)
            noise = scale_noise_to_snr(gt, mixture - gt, tgt_snr)
            mixture = noise + gt

        noise_augmentor = WhitePinkBrownAugmentation(
            max_white_level=1e-2,
            max_pink_level=5e-2,
            max_brown_level=5e-2
        )

        # NOTE(review): global `random` again — see note above.
        if self.split == "train" and random.random() < 0.3:
            mixture, gt = noise_augmentor(mixture, gt, rng)

        # --- speaker embedding ------------------------------------------
        embed_path = os.path.join(curr_dir, 'embed.pt')
        if self.spk_emb_exist:
            # embed.pt holds a numpy array (loaded with weights_only=False).
            embed = torch.load(embed_path, weights_only=False)
            embed = torch.from_numpy(embed)
        else:
            embed = torch.zeros(256)

        # --- random crop window ------------------------------------------
        # NOTE(review): raises if the clip is shorter than random_audio_length
        # samples — confirm upstream data guarantees this.
        input_length = self_speech.shape[1]
        start_idx = rng.randint(input_length - self.random_audio_length)
        end_idx = start_idx + self.random_audio_length

        # ==== peak normalization ====
        peak = torch.abs(mixture).max()
        if peak > 1:
            mixture /= peak
            gt /= peak
            self_speech /= peak

        inputs = {
            'mixture': mixture.float(),
            'embed': embed.float(),
            'self_speech': self_speech[0:1, :].float(),
            'start_idx_list': start_idx,
            'end_idx_list': end_idx
        }

        targets = {
            'target': gt[0:1, :].float()
        }

        return inputs, targets
|
src/datasets/noise.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def generate_white_noise(noise_shape, max_level, rng: np.random.RandomState):
    """Gaussian white noise of ``noise_shape`` with a random amplitude.

    The amplitude is drawn uniformly from [0, max_level); the samples are
    drawn from rng, so results are reproducible under the caller's seed.
    """
    level = max_level * rng.rand()
    samples = rng.normal(0, 1, size=noise_shape)
    return level * torch.from_numpy(samples).float()
|
| 13 |
+
|
| 14 |
+
def generate_pink_noise(noise_shape, max_level, rng: np.random.RandomState):
    """Pink (1/f) noise of ``noise_shape`` with a random amplitude.

    The amplitude is drawn uniformly from [0, max_level).

    Fix: the spectrum was previously generated with ``random_state=0``, so
    every call produced the *identical* pink-noise waveform (only the
    amplitude varied) and the caller's rng was ignored for the waveform.
    The spectrum is now drawn from ``rng`` so samples differ between calls
    and remain reproducible under the caller's seed.
    """
    pink_noise_level = max_level * rng.rand()

    pink_noise = powerlaw_psd_gaussian(1, noise_shape, random_state=rng)
    pink_noise = pink_noise_level * torch.from_numpy(pink_noise).float()

    return pink_noise
|
| 24 |
+
|
| 25 |
+
def generate_brown_noise(noise_shape, max_level, rng: np.random.RandomState):
    """Brown (1/f^2) noise of ``noise_shape`` with a random amplitude.

    The amplitude is drawn uniformly from [0, max_level).

    Fix: the spectrum was previously generated with ``random_state=0``, so
    every call produced the *identical* brown-noise waveform (only the
    amplitude varied) and the caller's rng was ignored for the waveform.
    The spectrum is now drawn from ``rng`` so samples differ between calls
    and remain reproducible under the caller's seed.
    """
    brown_noise_level = max_level * rng.rand()

    brown_noise = powerlaw_psd_gaussian(2, noise_shape, random_state=rng)
    brown_noise = brown_noise_level * torch.from_numpy(brown_noise).float()

    return brown_noise
|
| 35 |
+
|
| 36 |
+
"""Generate colored noise."""
|
| 37 |
+
|
| 38 |
+
from numpy import sqrt, newaxis, integer
|
| 39 |
+
from numpy.fft import irfft, rfftfreq
|
| 40 |
+
from numpy.random import default_rng, Generator, RandomState
|
| 41 |
+
from numpy import sum as npsum
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def powerlaw_psd_gaussian(exponent, size, fmin=0, random_state=None):
    """Gaussian (1/f)**beta noise.

    Based on the algorithm in:
    Timmer, J. and Koenig, M.:
    On generating power law noise.
    Astron. Astrophys. 300, 707-710 (1995)

    Normalised to unit variance

    Parameters:
    -----------

    exponent : float
        The power-spectrum of the generated noise is proportional to

        S(f) = (1 / f)**beta
        flicker / pink noise:   exponent beta = 1
        brown noise:            exponent beta = 2

        Furthermore, the autocorrelation decays proportional to lag**-gamma
        with gamma = 1 - beta for 0 < beta < 1.
        There may be finite-size issues for beta close to one.

    size : int or iterable
        The output has the given shape, and the desired power spectrum in
        the last coordinate. That is, the last dimension is taken as time,
        and all other components are independent.

    fmin : float, optional
        Low-frequency cutoff.
        Default: 0 corresponds to original paper.

        The power-spectrum below fmin is flat. fmin is defined relative
        to a unit sampling rate (see numpy's rfftfreq). For convenience,
        the passed value is mapped to max(fmin, 1/samples) internally
        since 1/samples is the lowest possible finite frequency in the
        sample. The largest possible value is fmin = 0.5, the Nyquist
        frequency. The output for this value is white noise.

    random_state : int, numpy.integer, numpy.random.Generator,
        numpy.random.RandomState, optional
        Optionally sets the state of NumPy's underlying random number
        generator. Integer-compatible values or None are passed to
        np.random.default_rng. np.random.RandomState or np.random.Generator
        are used directly. Default: None.

    Returns
    -------
    out : array
        The samples.

    Examples:
    ---------

    # generate 1/f noise == pink noise == flicker noise
    >>> import colorednoise as cn
    >>> y = cn.powerlaw_psd_gaussian(1, 5)
    """

    # Make sure size is a list so we can iterate it and assign to it.
    try:
        size = list(size)
    except TypeError:
        size = [size]

    # The number of samples in each time series
    samples = size[-1]

    # Calculate frequencies (we assume a sample rate of one).
    # Use fft functions for real output (-> hermitian spectrum)
    f = rfftfreq(samples)

    # Validate / normalise fmin: clamp to the lowest representable finite
    # frequency 1/samples; reject anything outside [0, 0.5].
    if 0 <= fmin <= 0.5:
        fmin = max(fmin, 1./samples)  # Low frequency cutoff
    else:
        raise ValueError("fmin must be chosen between 0 and 0.5.")

    # Build scaling factors for all frequencies: flat below the cutoff,
    # (1/f)^(beta/2) above it (amplitude, hence exponent/2).
    s_scale = f
    ix = npsum(s_scale < fmin)  # Index of the cutoff
    if ix and ix < len(s_scale):
        s_scale[:ix] = s_scale[ix]
    s_scale = s_scale**(-exponent/2.)

    # Calculate theoretical output standard deviation from scaling
    # (DC excluded; the top bin counts half for even-length signals).
    w = s_scale[1:].copy()
    w[-1] *= (1 + (samples % 2)) / 2.  # correct f = +-0.5
    sigma = 2 * sqrt(npsum(w**2)) / samples

    # Adjust size to generate one Fourier component per frequency
    size[-1] = len(f)

    # Add empty dimension(s) to broadcast s_scale along last
    # dimension of generated random power + phase (below)
    dims_to_add = len(size) - 1
    s_scale = s_scale[(newaxis,) * dims_to_add + (Ellipsis,)]

    # prepare random number generator
    normal_dist = _get_normal_distribution(random_state)

    # Generate scaled random power + phase (real and imaginary parts)
    sr = normal_dist(scale=s_scale, size=size)
    si = normal_dist(scale=s_scale, size=size)

    # If the signal length is even, frequencies +/- 0.5 are equal
    # so the coefficient must be real.
    if not (samples % 2):
        si[..., -1] = 0
        sr[..., -1] *= sqrt(2)    # Fix magnitude

    # Regardless of signal length, the DC component must be real
    si[..., 0] = 0
    sr[..., 0] *= sqrt(2)    # Fix magnitude

    # Combine power + corrected phase to Fourier components
    s = sr + 1J * si

    # Transform to real time series & scale to unit variance
    y = irfft(s, n=samples, axis=-1) / sigma

    return y
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _get_normal_distribution(random_state):
|
| 171 |
+
normal_dist = None
|
| 172 |
+
if isinstance(random_state, (integer, int)) or random_state is None:
|
| 173 |
+
random_state = default_rng(random_state)
|
| 174 |
+
normal_dist = random_state.normal
|
| 175 |
+
elif isinstance(random_state, (Generator, RandomState)):
|
| 176 |
+
normal_dist = random_state.normal
|
| 177 |
+
else:
|
| 178 |
+
raise ValueError(
|
| 179 |
+
"random_state must be one of integer, numpy.random.Generator, or None"
|
| 180 |
+
"numpy.random.Randomstate"
|
| 181 |
+
)
|
| 182 |
+
return normal_dist
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class WhitePinkBrownAugmentation:
    """Additive colored-noise augmentation.

    Adds a random mix of white, pink, and brown noise to the input audio;
    the ground-truth audio is returned unchanged.
    """

    def __init__(self, max_white_level=1e-3, max_pink_level=5e-3, max_brown_level=5e-3):
        """
        max_white_level: upper bound on the white-noise level (each call draws
            the actual level uniformly from [0, max_white_level))
        max_pink_level: upper bound on the pink-noise level
        max_brown_level: upper bound on the brown-noise level
        """
        self.max_white_level = max_white_level
        self.max_pink_level = max_pink_level
        self.max_brown_level = max_brown_level

    def __call__(self, audio_data, gt_audio, rng: np.random.RandomState):
        # Each generator draws its own amplitude from rng and synthesizes
        # noise with the same shape as the input mixture.
        wn = generate_white_noise(audio_data.shape, self.max_white_level, rng)
        pn = generate_pink_noise(audio_data.shape, self.max_pink_level, rng)
        bn = generate_brown_noise(audio_data.shape, self.max_brown_level, rng)
        augmented_audio = audio_data + (wn + pn + bn)

        return augmented_audio, gt_audio
|
src/hl_module/joint_train_hl_module_new.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
import wandb
|
| 7 |
+
import torch
|
| 8 |
+
from numpy import mean
|
| 9 |
+
from src.metrics.metrics import Metrics
|
| 10 |
+
import src.utils as utils
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class FakeModel(nn.Module):
    """Minimal nn.Module wrapper exposing the wrapped network as ``model``.

    Lets a checkpoint whose state-dict keys carry a ``model.`` prefix be
    loaded onto a plain network via ``load_state_dict``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class PLModule(object):
|
| 21 |
+
    def __init__(
        self,
        model,
        model_params,
        sr,
        optimizer,
        optimizer_params,
        scheduler=None,
        scheduler_params=None,
        loss=None,
        loss_params=None,
        metrics=[],  # NOTE(review): mutable default; only iterated here, but None would be safer
        slow_model_ckpt=None,
        prev_ckpt=None,
        grad_clip=None,
        use_dp=True,
        val_log_interval=10,  # Unused, only kept for compatibility TODO: Remove
        samples_per_speaker_number=3,
        freeze_model1=False,
    ):
        """Training harness for the joint (slow "model1" + fast) network.

        model / model_params: dotted import path of the network class and its kwargs.
        sr: audio sample rate.
        optimizer / optimizer_params: dotted import path of the optimizer and its kwargs.
        scheduler / scheduler_params: optional scheduler spec (see ``init_scheduler``).
        loss / loss_params: dotted import path of the loss class and its kwargs.
        metrics: metric names, each wrapped in a ``Metrics`` instance.
        slow_model_ckpt: checkpoint providing initial weights for model1 only.
        prev_ckpt: checkpoint of the complete joint model to resume from
            (takes priority over slow_model_ckpt).
        grad_clip: max gradient norm, or None to disable clipping.
        use_dp: wrap the network in ``nn.DataParallel``.
        freeze_model1: freeze model1 and exclude its params from the optimizer.
        """
        # Instantiate the network from its dotted import path
        self.model = utils.import_attr(model)(**model_params)

        self.use_dp = use_dp
        if use_dp:
            self.model = nn.DataParallel(self.model)

        self.sr = sr

        # Log a val sample every this many intervals
        # self.val_log_interval = val_log_interval
        self.samples_per_speaker_number = samples_per_speaker_number

        # Initialize metrics
        self.metrics = [Metrics(metric) for metric in metrics]

        # Metric values, keyed by epoch -> metric name (see log_metric)
        self.metric_values = {}

        # Dataset statistics accumulated via log_statistic
        self.statistics = {}

        # Assign metric to monitor, and how to judge different models based on it
        # i.e. How do we define the best model (Here, we minimize val loss)
        self.monitor = "val/loss"
        self.monitor_mode = "min"

        # Mode, either train or val
        self.mode = None

        self.val_samples = {}
        self.train_samples = {}

        self.input_snr_calculated = False
        self.input_snr = []
        self.snr_metric = Metrics("snr")

        # Initialize loss function
        self.loss_fn = utils.import_attr(loss)(**loss_params)

        # Initialize weights if checkpoint is provided

        # prev ckpt is for the checkpoint of the complete joint model (fast+slow) you want to train from
        if prev_ckpt is not None:
            if prev_ckpt.endswith(".ckpt"):
                # Lightning-style checkpoint: weights live under "state_dict"
                # with a "model." prefix, hence the FakeModel indirection.
                print("load prev model", prev_ckpt)
                state = torch.load(prev_ckpt)["state_dict"]
                # print(state.keys())
                print(state["current_epoch"])
                if self.use_dp:
                    _model = self.model.module
                else:
                    _model = self.model

                mdl = FakeModel(_model)
                mdl.load_state_dict(state)
                # NOTE(review): re-wraps in DataParallel regardless of use_dp — confirm intended
                self.model = nn.DataParallel(mdl.model)
            else:
                # Checkpoint produced by dump_state: weights under "model"
                print("load prev model", prev_ckpt)

                state = torch.load(prev_ckpt)
                print(state["current_epoch"])
                state = state["model"]
                if self.use_dp:
                    self.model.module.load_state_dict(state)
                else:
                    self.model.load_state_dict(state)

        # init ckpt stands for the slow model's initial weights checkpoint path
        elif slow_model_ckpt is not None:
            print(f"Loading model 1 weights from checkpoint: {slow_model_ckpt}")
            model1_ckpt = torch.load(slow_model_ckpt)
            print("current epoch is {}".format(model1_ckpt["current_epoch"]))

            # Keep only the "tce_model."-prefixed weights, stripped of the prefix
            model1_state_dict = {
                key.replace("tce_model.", ""): value
                for key, value in model1_ckpt["model"].items()
                if key.startswith("tce_model.")
            }

            if self.use_dp:
                self.model.module.model1.load_state_dict(model1_state_dict, strict=False)
            else:
                self.model.model1.load_state_dict(model1_state_dict, strict=False)

        else:
            print("Loading model from scratch, no slow model init ckpt or joint model init ckpt")

        # whether freeze slow model during training
        self.freeze = freeze_model1
        if freeze_model1:
            self.freeze_model1()
            # Only optimize parameters that still require gradients
            params_to_optimize = filter(lambda p: p.requires_grad, self.model.parameters())
            # Initialize optimizer
            self.optimizer = utils.import_attr(optimizer)(params_to_optimize, **optimizer_params)
            self.optim_name = optimizer
            self.opt_params = optimizer_params
        else:
            # Initialize optimizer
            self.optimizer = utils.import_attr(optimizer)(self.model.parameters(), **optimizer_params)
            self.optim_name = optimizer
            self.opt_params = optimizer_params

        # Grad clip
        self.grad_clip = grad_clip

        if self.grad_clip is not None:
            print(f"USING GRAD CLIP: {self.grad_clip}")
        else:
            # Deliberately loud: grad clipping is expected in this project
            print("ERROR! NOT USING GRAD CLIP" * 100)

        # Initialize scheduler
        self.scheduler = self.init_scheduler(scheduler, scheduler_params)
        self.scheduler_name = scheduler
        self.scheduler_params = scheduler_params

        self.epoch = 0
|
| 157 |
+
|
| 158 |
+
def freeze_model1(self):
|
| 159 |
+
"""Freezes the weights of model1."""
|
| 160 |
+
print("Freezing model1 weights")
|
| 161 |
+
model1 = self.model.module.model1 if self.use_dp else self.model.model1
|
| 162 |
+
for param in model1.parameters():
|
| 163 |
+
param.requires_grad = False
|
| 164 |
+
print("Model1 weights frozen.")
|
| 165 |
+
|
| 166 |
+
def load_state(self, path, map_location=None):
|
| 167 |
+
state = torch.load(path, map_location=map_location)
|
| 168 |
+
|
| 169 |
+
if self.use_dp:
|
| 170 |
+
self.model.module.load_state_dict(state["model"])
|
| 171 |
+
else:
|
| 172 |
+
self.model.load_state_dict(state["model"])
|
| 173 |
+
|
| 174 |
+
# Re-initialize optimizer
|
| 175 |
+
if not self.freeze:
|
| 176 |
+
self.optimizer = utils.import_attr(self.optim_name)(self.model.parameters(), **self.opt_params)
|
| 177 |
+
else:
|
| 178 |
+
params_to_optimize = filter(lambda p: p.requires_grad, self.model.parameters())
|
| 179 |
+
self.optimizer = utils.import_attr(self.optim_name)(params_to_optimize, **self.opt_params)
|
| 180 |
+
|
| 181 |
+
# Re-initialize scheduler (Order might be important?)
|
| 182 |
+
if self.scheduler is not None:
|
| 183 |
+
self.scheduler = self.init_scheduler(self.scheduler_name, self.scheduler_params)
|
| 184 |
+
|
| 185 |
+
self.optimizer.load_state_dict(state["optimizer"])
|
| 186 |
+
|
| 187 |
+
if self.scheduler is not None:
|
| 188 |
+
self.scheduler.load_state_dict(state["scheduler"])
|
| 189 |
+
|
| 190 |
+
self.epoch = state["current_epoch"]
|
| 191 |
+
print("Load model from epoch", self.epoch)
|
| 192 |
+
self.metric_values = state["metric_values"]
|
| 193 |
+
|
| 194 |
+
if "statistics" in self.statistics:
|
| 195 |
+
self.statistics = state["statistics"]
|
| 196 |
+
|
| 197 |
+
def dump_state(self, path):
|
| 198 |
+
if self.use_dp:
|
| 199 |
+
_model = self.model.module
|
| 200 |
+
else:
|
| 201 |
+
_model = self.model
|
| 202 |
+
|
| 203 |
+
state = dict(
|
| 204 |
+
model=_model.state_dict(),
|
| 205 |
+
optimizer=self.optimizer.state_dict(),
|
| 206 |
+
current_epoch=self.epoch,
|
| 207 |
+
metric_values=self.metric_values,
|
| 208 |
+
statistics=self.statistics,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
if self.scheduler is not None:
|
| 212 |
+
state["scheduler"] = self.scheduler.state_dict()
|
| 213 |
+
print("save to " + path)
|
| 214 |
+
torch.save(state, path)
|
| 215 |
+
|
| 216 |
+
def get_current_lr(self):
|
| 217 |
+
for param_group in self.optimizer.param_groups:
|
| 218 |
+
return param_group["lr"]
|
| 219 |
+
|
| 220 |
+
def on_epoch_start(self):
|
| 221 |
+
print()
|
| 222 |
+
print("=" * 25, "STARTING EPOCH", self.epoch, "=" * 25)
|
| 223 |
+
print()
|
| 224 |
+
|
| 225 |
+
def get_avg_metric_at_epoch(self, metric, epoch=None):
|
| 226 |
+
if epoch is None:
|
| 227 |
+
epoch = self.epoch
|
| 228 |
+
|
| 229 |
+
return self.metric_values[epoch][metric]["epoch"] / self.metric_values[epoch][metric]["num_elements"]
|
| 230 |
+
|
| 231 |
+
    def on_epoch_end(self, best_path, wandb_run):
        """Checkpoint if best, print/log epoch metrics, step the scheduler,
        and advance the epoch counter.

        best_path: where to save the checkpoint when this epoch is the best so far.
        wandb_run: active wandb run used for logging.
        """
        assert self.epoch + 1 == len(
            self.metric_values
        ), "Current epoch must be equal to length of metrics (0-indexed)"

        monitor_metric_last = self.get_avg_metric_at_epoch(self.monitor)

        # Go over all epochs: current model is "best" only if no earlier epoch
        # beat it on the monitored metric (direction set by monitor_mode)
        save = True
        for epoch in range(len(self.metric_values) - 1):
            monitor_metric_at_epoch = self.get_avg_metric_at_epoch(self.monitor, epoch)

            if self.monitor_mode == "max":
                # If there is any model with monitor larger than current, then
                # this is not the best model
                if monitor_metric_last < monitor_metric_at_epoch:
                    save = False
                    break

            if self.monitor_mode == "min":
                # If there is any model with monitor smaller than current, then
                # this is not the best model
                if monitor_metric_last > monitor_metric_at_epoch:
                    save = False
                    break

        # If this is best, save it
        if save:
            print("Current checkpoint is the best! Saving it...")
            self.dump_state(best_path)

        val_loss = self.get_avg_metric_at_epoch("val/loss")
        val_snr_i = self.get_avg_metric_at_epoch("val/snr_i")
        val_si_snr_i = self.get_avg_metric_at_epoch("val/si_snr_i")

        print(f"Val loss: {val_loss:.02f}")
        print(f"Val SNRi: {val_snr_i:.02f}dB")
        print(f"Val SI-SDRi: {val_si_snr_i:.02f}dB")

        # Log stuff on wandb; commit=False batches everything into one step
        wandb_run.log({"lr-Adam": self.get_current_lr()}, commit=False, step=self.epoch + 1)

        for metric in self.metric_values[self.epoch]:
            wandb_run.log({metric: self.get_avg_metric_at_epoch(metric)}, commit=False, step=self.epoch + 1)

        # Each statistic is logged once, then marked so it isn't re-logged
        for statistic in self.statistics:
            if not self.statistics[statistic]["logged"]:
                data = self.statistics[statistic]["data"]
                reduction = self.statistics[statistic]["reduction"]
                if reduction == "mean":
                    val = mean(data)
                elif reduction == "sum":
                    val = sum(data)
                elif reduction == "histogram":
                    # wandb.Table wants one row per value
                    data = [[d] for d in data]
                    table = wandb.Table(data=data, columns=[statistic])
                    val = wandb.plot.histogram(table, statistic, title=statistic)
                else:
                    assert 0, f"Unknown reduction {reduction}."
                wandb_run.log({statistic: val}, commit=False)
                self.statistics[statistic]["logged"] = True

        # Final commit flushes all the deferred logs for this step
        wandb_run.log({"epoch": self.epoch}, commit=True, step=self.epoch + 1)

        if self.scheduler is not None:
            if type(self.scheduler) == torch.optim.lr_scheduler.ReduceLROnPlateau:
                # Plateau scheduler needs the monitored metric value
                self.scheduler.step(monitor_metric_last)
            else:
                self.scheduler.step()

        self.epoch += 1
|
| 303 |
+
|
| 304 |
+
def log_statistic(self, name, value, reduction="mean"):
|
| 305 |
+
if name not in self.statistics:
|
| 306 |
+
self.statistics[name] = dict(logged=False, data=[], reduction=reduction)
|
| 307 |
+
|
| 308 |
+
self.statistics[name]["data"].append(value)
|
| 309 |
+
|
| 310 |
+
def log_metric(self, name, value, batch_size=1, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True):
|
| 311 |
+
"""
|
| 312 |
+
Logs a metric
|
| 313 |
+
value must be the AVERAGE value across the batch
|
| 314 |
+
Must provide batch size for accurate average computation
|
| 315 |
+
"""
|
| 316 |
+
|
| 317 |
+
epoch_str = self.epoch
|
| 318 |
+
if epoch_str not in self.metric_values:
|
| 319 |
+
self.metric_values[epoch_str] = {}
|
| 320 |
+
|
| 321 |
+
if name not in self.metric_values[epoch_str]:
|
| 322 |
+
self.metric_values[epoch_str][name] = dict(step=None, epoch=None)
|
| 323 |
+
|
| 324 |
+
if type(value) == torch.Tensor:
|
| 325 |
+
value = value.item()
|
| 326 |
+
|
| 327 |
+
if on_step:
|
| 328 |
+
if self.metric_values[epoch_str][name]["step"] is None:
|
| 329 |
+
self.metric_values[epoch_str][name]["step"] = []
|
| 330 |
+
|
| 331 |
+
self.metric_values[epoch_str][name]["step"].append(value)
|
| 332 |
+
|
| 333 |
+
if on_epoch:
|
| 334 |
+
if self.metric_values[epoch_str][name]["epoch"] is None:
|
| 335 |
+
self.metric_values[epoch_str][name]["epoch"] = 0
|
| 336 |
+
self.metric_values[epoch_str][name]["num_elements"] = 0
|
| 337 |
+
|
| 338 |
+
self.metric_values[epoch_str][name]["epoch"] += value * batch_size
|
| 339 |
+
self.metric_values[epoch_str][name]["num_elements"] += batch_size
|
| 340 |
+
|
| 341 |
+
    def val_naive(self, batch, batch_idx):
        """Debug forward pass that reports CUDA memory consumed by inference.

        The before/after measurements bracket the forward call, so statement
        order is significant here.
        """
        inputs, targets = batch
        a = torch.cuda.memory_allocated(inputs["mixture"].device)
        outputs = self.model(inputs)
        b = torch.cuda.memory_allocated(inputs["mixture"].device)
        # Memory delta in megabytes
        print("Infer consume M", (b - a) / 1e6)

        return outputs
|
| 349 |
+
|
| 350 |
+
    def train_naive(self, batch, batch_idx):
        """Debug training step that reports CUDA memory usage per phase.

        Runs forward, loss, backward and optimizer step while printing
        allocated-memory deltas (GB) around each stage; statement order
        is significant for the measurements.
        """
        self.reset_grad()
        inputs, targets = batch
        a = torch.cuda.memory_allocated(inputs["mixture"].device)
        # print("a", a/1e9 )
        outputs = self.model(inputs)

        est = outputs["output"]
        gt = targets["target"]

        # Compute loss
        loss = self.loss_fn(est=est, gt=gt).mean()
        b = torch.cuda.memory_allocated(inputs["mixture"].device)

        # NOTE(review): retain_graph=True keeps the autograd graph alive after
        # backward; looks unnecessary for a single backward pass — confirm.
        loss.backward(retain_graph=True)
        c = torch.cuda.memory_allocated(inputs["mixture"].device)

        self.backprop()
        d = torch.cuda.memory_allocated(inputs["mixture"].device)

        # Deltas: forward, forward+backward, optimizer step, baseline (GB)
        print("Training consume G", (b - a) / 1e9, (c - a) / 1e9, (d - c) / 1e9, a / 1e9)
        return outputs
|
| 372 |
+
|
| 373 |
+
def silence_audio(self, input, timestamp):
|
| 374 |
+
output_audio = input.clone()
|
| 375 |
+
for start, end in timestamp:
|
| 376 |
+
output_audio[start:end] = 0.0
|
| 377 |
+
|
| 378 |
+
return output_audio
|
| 379 |
+
|
| 380 |
+
    def _step(self, batch, batch_idx, step="train"):
        """Shared forward/loss/metric-logging step for train and val.

        batch: (inputs, targets) dicts; inputs must contain "mixture",
            "start_idx_list" and "end_idx_list".
        step: "train" or "val"; selects the metric-name prefix and which
            metrics are skipped.
        Returns (loss, sample) where sample holds mixture/output/target
        tensors for wandb visualization.
        """
        inputs, targets = batch
        batch_size = inputs["mixture"].shape[0]

        # NOTE(review): only element 0 of the index lists is used and shared
        # for the whole batch — confirm batches are homogeneous in range.
        start_idx = inputs["start_idx_list"][0].item()
        end_idx = inputs["end_idx_list"][0].item()
        inputs["start_idx"] = start_idx
        inputs["end_idx"] = end_idx

        outputs = self.model(inputs)
        est = outputs["output"].clone()

        if "audio_range" in outputs:
            # Model reports per-sample valid ranges; slice gt/mix/self-speech
            # to those ranges before computing loss and metrics.
            audio_range = outputs["audio_range"]
            start_indices = audio_range[:, 0]  # Shape: [batch]
            end_indices = audio_range[:, 1]
            sliced_gt = []
            sliced_mix = []
            sliced_self = []
            # masked_est_list=[]

            gt_clone = targets["target"].clone()
            mix_clone = inputs["mixture"][:, 0:1].clone()
            full_self_speech_clone = inputs["self_speech"].clone()

            for index in range(est.size(0)):
                start = start_indices[index].item()
                end = end_indices[index].item()

                sliced_gt.append(gt_clone[index, :, start:end])
                sliced_mix.append(mix_clone[index, :, start:end])
                sliced_self.append(full_self_speech_clone[index, :, start:end])

            # Stack the sliced audio to form the final tensor
            gt = torch.stack(sliced_gt, dim=0)
            mix = torch.stack(sliced_mix, dim=0)
            self_speech_final = torch.stack(sliced_self, dim=0)

        else:
            # NOTE(review): this branch reads self_speech from targets, the
            # branch above reads it from inputs — confirm both are populated.
            mix = inputs["mixture"][:, 0:1].clone()
            gt = targets["target"].clone()
            self_speech_final = targets["self_speech"].clone()

        # Compute loss
        loss = self.loss_fn(est=est, gt=gt).mean()

        # Detached copy so metric computation can't touch the autograd graph
        est_detached = est.detach().clone()

        with torch.no_grad():
            # Log loss
            self.log_metric(
                f"{step}/loss",
                loss.item(),
                batch_size=batch_size,
                on_step=(step == "train"),
                on_epoch=True,
                prog_bar=True,
                sync_dist=True,
            )

            # Log metrics (PESQ/STOI are too slow for the training loop)
            for metric in self.metrics:
                if step == "train" and (metric.name == "PESQ" or metric.name == "STOI"):
                    continue
                metric_val = metric(est=est_detached, gt=gt, mix=mix, self_speech=self_speech_final)
                for i in range(batch_size):
                    # if gt is all zero, cannot compute metric
                    if torch.all(gt[i] == 0):
                        # print(f"Skipping sample {i} in batch because gt is all zeros.")
                        continue
                    val = metric_val[i].item()
                    self.log_metric(
                        f"{step}/{metric.name}",
                        val,
                        batch_size=1,
                        on_step=False,
                        on_epoch=True,
                        prog_bar=True,
                        sync_dist=True,
                    )

            # Create collection of things to show in a sample on wandb
            sample = {
                "mixture": mix,
                "output": est_detached,
                "target": gt,
            }

        return loss, sample
|
| 469 |
+
|
| 470 |
+
def train(self):
|
| 471 |
+
self.model.train()
|
| 472 |
+
self.mode = "train"
|
| 473 |
+
|
| 474 |
+
def eval(self):
|
| 475 |
+
self.model.eval()
|
| 476 |
+
self.mode = "val"
|
| 477 |
+
|
| 478 |
+
def training_step(self, batch, batch_idx):
|
| 479 |
+
loss, sample = self._step(batch, batch_idx, step="train")
|
| 480 |
+
|
| 481 |
+
target = sample["target"]
|
| 482 |
+
|
| 483 |
+
return loss, target.shape[0]
|
| 484 |
+
|
| 485 |
+
def validation_step(self, batch, batch_idx):
|
| 486 |
+
loss, sample = self._step(batch, batch_idx, step="val")
|
| 487 |
+
|
| 488 |
+
target = sample["target"]
|
| 489 |
+
|
| 490 |
+
return loss, target.shape[0]
|
| 491 |
+
|
| 492 |
+
    def reset_grad(self):
        """Clear accumulated gradients before the next backward pass."""
        self.optimizer.zero_grad()
|
| 494 |
+
|
| 495 |
+
def backprop(self):
|
| 496 |
+
# print("BACKPROP")
|
| 497 |
+
# print(self.grad_clip)
|
| 498 |
+
# Gradient clipping
|
| 499 |
+
if self.grad_clip is not None:
|
| 500 |
+
# print("Clipping grad norm")
|
| 501 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
|
| 502 |
+
|
| 503 |
+
self.optimizer.step()
|
| 504 |
+
|
| 505 |
+
def configure_optimizers(self):
|
| 506 |
+
if self.scheduler is not None:
|
| 507 |
+
# For reduce LR on plateau, we need to provide more information
|
| 508 |
+
if type(self.scheduler) == torch.optim.lr_scheduler.ReduceLROnPlateau:
|
| 509 |
+
scheduler_cfg = {
|
| 510 |
+
"scheduler": self.scheduler,
|
| 511 |
+
"interval": "epoch",
|
| 512 |
+
"frequency": 1,
|
| 513 |
+
"monitor": self.monitor,
|
| 514 |
+
"strict": False,
|
| 515 |
+
}
|
| 516 |
+
else:
|
| 517 |
+
scheduler_cfg = self.scheduler
|
| 518 |
+
return [self.optimizer], [scheduler_cfg]
|
| 519 |
+
else:
|
| 520 |
+
return self.optimizer
|
| 521 |
+
|
| 522 |
+
    def init_scheduler(self, scheduler, scheduler_params):
        """Construct an LR scheduler from its import path or a "sequential" chain.

        scheduler: dotted import path of a scheduler class, the string
            "sequential" for a SequentialLR chain, or None for no scheduler.
        scheduler_params: kwargs for the scheduler class; for "sequential", a
            list of dicts with keys "name", "params" and "epochs".
        Returns the constructed scheduler, or None when ``scheduler`` is None.
        """
        if scheduler is not None:
            if scheduler == "sequential":
                schedulers = []
                milestones = []
                for scheduler_param in scheduler_params:
                    sched = utils.import_attr(scheduler_param["name"])(self.optimizer, **scheduler_param["params"])
                    schedulers.append(sched)
                    milestones.append(scheduler_param["epochs"])

                # Cumulative sum for milestones (epochs are given per-stage)
                for i in range(1, len(milestones)):
                    milestones[i] = milestones[i - 1] + milestones[i]

                # Remove last milestone as it is implied by num epochs
                milestones.pop()

                scheduler = torch.optim.lr_scheduler.SequentialLR(self.optimizer, schedulers, milestones)
            else:
                scheduler = utils.import_attr(scheduler)(self.optimizer, **scheduler_params)

        return scheduler
|
src/losses/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
src/losses/SNRLP.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
from src.losses.SNRLosses import SNRLosses
|
| 6 |
+
from src.losses.LogPowerLoss import LogPowerLoss
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SNRLPLoss(nn.Module):
    """Per-sample loss that applies an SNR-family loss to non-silent targets
    and a weighted L1 penalty to all-silent (negative) targets."""

    def __init__(self, snr_loss_name="snr", neg_weight=1) -> None:
        super().__init__()
        # SNR-family loss for samples whose target contains signal
        self.snr_loss = SNRLosses(snr_loss_name)
        #self.lp_loss = LogPowerLoss()
        # Penalty for silent-target samples
        self.lp_loss = nn.L1Loss()#LogPowerLoss()
        # Scale factor applied to the negative (silent-target) term
        self.neg_weight = neg_weight

    def forward(self, est: torch.Tensor, gt: torch.Tensor, **kwargs):
        """
        input: (B, C, T) (B, C, T)
        Returns: per-sample loss tensor of shape (B,).
        """
        neg_loss = 0
        pos_loss = 0

        comp_loss = torch.zeros((est.shape[0]), device=est.device)
        # True where a sample's target is entirely zero (negative sample)
        mask = (torch.max(torch.max(torch.abs(gt), dim=2)[0], dim=1)[0] == 0)
        #print("mask", mask)
        # If there's at least one negative sample
        if any(mask):
            est_neg, gt_neg = est[mask], gt[mask]
            # NOTE(review): nn.L1Loss reduces to a single scalar, so the same
            # mean value is broadcast to every masked sample — confirm intended.
            neg_loss = self.lp_loss(est_neg, gt_neg)
            comp_loss[mask] = neg_loss * self.neg_weight

        # If there's at least one positive sample
        if any((~ mask)):
            est_pos, gt_pos = est[~mask], gt[~mask]
            pos_loss = self.snr_loss(est_pos, gt_pos)

            # Compute_joint_loss
            comp_loss[~mask] = pos_loss

        return comp_loss
|
src/metrics/metrics.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from torchaudio.functional import resample
|
| 5 |
+
|
| 6 |
+
from torchmetrics.functional import(
|
| 7 |
+
scale_invariant_signal_distortion_ratio as si_sdr,
|
| 8 |
+
scale_invariant_signal_noise_ratio as si_snr,
|
| 9 |
+
signal_noise_ratio as snr)
|
| 10 |
+
|
| 11 |
+
from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility as STOI
|
| 12 |
+
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality as PESQ
|
| 13 |
+
import numpy as np
|
| 14 |
+
import copy
|
| 15 |
+
from src.losses.MultiResoLoss import MultiResoFuseLoss
|
| 16 |
+
from src.losses.Perceptual_Loss import PLCPALoss
|
| 17 |
+
|
| 18 |
+
def compute_decay(est, mix):
    """Average per-channel power drop (dB) from ``mix`` to ``est``.

    est, mix: tensors or ndarrays of shape [*, C, T] (same type for both).
    Returns: tensor of shape [*].
    """
    kind = type(est)
    assert type(mix) == kind, "All arrays must be the same type"
    if kind == np.ndarray:
        est = torch.from_numpy(est)
        mix = torch.from_numpy(mix)

    # Ensure that, no matter what, we do not modify the original arrays
    est, mix = est.clone(), mix.clone()

    power_est = 10 * torch.log10((est ** 2).sum(dim=-1))  # [*, C]
    power_mix = 10 * torch.log10((mix ** 2).sum(dim=-1))

    return (power_mix - power_est).mean(dim=-1)  # [*]
|
| 35 |
+
|
| 36 |
+
class Metrics(nn.Module):
    """Named audio metric/loss with a uniform call signature.

    Each supported name is bound to ``func(est, gt, mix, self_speech)``;
    ``forward`` evaluates it and averages over the channel dimension.
    """

    def __init__(self, name, fs=24000, **kwargs) -> None:
        """
        name: one of the metric names handled below (e.g. 'snr', 'si_sdr_i', 'PESQ').
        fs: sample rate of the incoming audio.
        kwargs: forwarded to loss-style metrics (Multi_Reso_L1, PLCPALoss).
        """
        super().__init__()
        self.fs = fs
        self.func = None
        self.name = name
        if name == 'snr':
            self.func = lambda est, gt, mix, self_speech: snr(preds=est, target=gt)
        elif name == 'snr_i':
            # "_i" variants are improvements: metric(est) - metric(mixture)
            self.func = lambda est, gt, mix, self_speech: snr(preds=est, target=gt) - snr(preds=mix, target=gt)
        elif name == 'si_snr':
            self.func = lambda est, gt, mix, self_speech: si_snr(preds=est, target=gt)
        elif name == 'si_snr_i':
            self.func = lambda est, gt, mix, self_speech: si_snr(preds=est, target=gt) - si_snr(preds=mix, target=gt)
        elif name == 'si_sdr':
            self.func = lambda est, gt, mix, self_speech: si_sdr(preds=est, target=gt)
        elif name == 'si_sdr_i':
            self.func = lambda est, gt, mix, self_speech: si_sdr(preds=est, target=gt) - si_sdr(preds=mix, target=gt)
        elif name == 'si_sdr_i_adj':
            # Adjusted improvement: baseline uses gt + self-speech as the reference
            self.func = lambda est, gt, mix, self_speech: si_sdr(preds=est, target=gt) - si_sdr(preds=mix, target=gt+self_speech)
        elif name == 'STOI':
            self.func = lambda est, gt, mix, self_speech: STOI(preds=est, target=gt, fs=fs)
        elif name == 'PESQ':
            # PESQ (narrow-band) requires 16 kHz, so resample first
            fs_new = 16000
            self.func = lambda est, gt, mix, self_speech: PESQ(preds=resample(est, fs, fs_new), target=resample(gt, fs, fs_new), fs=fs_new, mode="nb")
        elif name == 'Multi_Reso_L1':
            mult_ireso_loss = MultiResoFuseLoss(**kwargs)
            self.func = lambda est, gt, mix, self_speech: mult_ireso_loss(est=est, gt=gt)
        elif name == 'PLCPALoss':
            plcpa = PLCPALoss(**kwargs)
            self.func = lambda est, gt, mix, self_speech: plcpa(est=est, gt=gt)
        else:
            raise NotImplementedError(f"Metric {name} not implemented!")

    def forward(self, est, gt, mix, self_speech=None):
        """
        input: (*, C, T)
        output: (*)  — for 'PLCPALoss', a 3-tuple of (*) tensors.
        """
        types = type(est)
        assert type(gt) == types and type(mix) == types, "All arrays must be the same type"
        if types == np.ndarray:
            est, gt, mix = torch.from_numpy(est), torch.from_numpy(gt), torch.from_numpy(mix)

        # Ensure that, no matter what, we do not modify the original arrays
        est = est.clone()
        gt = gt.clone()
        mix = mix.clone()

        if self_speech is not None:
            if type(self_speech) == np.ndarray:
                self_speech = torch.from_numpy(self_speech)
            self_speech = self_speech.clone()

        # per_channel_metrics = self.func(est=est, gt=gt, mix=mix) # [*, C]
        per_channel_metrics = self.func(est=est, gt=gt, mix=mix, self_speech=self_speech)  # [*, C]

        if self.name == "PLCPALoss":
            # PLCPALoss returns three components; average each over channels
            return per_channel_metrics[0].mean(dim=-1), per_channel_metrics[1].mean(dim=-1), per_channel_metrics[2].mean(dim=-1)
        else:
            return per_channel_metrics.mean(dim=-1)  # [*]
|
src/models/blocks/model1_block.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import time
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
from typing import Dict, List, Optional, Tuple
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from espnet2.torch_utils.get_layer_from_string import get_layer
|
| 10 |
+
from torch.nn import init
|
| 11 |
+
from torch.nn.parameter import Parameter
|
| 12 |
+
import src.utils as utils
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Lambda(nn.Module):
    """Wrap an arbitrary callable as an ``nn.Module``.

    Useful for inserting small functional transforms (e.g. reshapes or
    permutes) into ``nn.Sequential`` pipelines.
    """

    def __init__(self, lambd):
        super().__init__()
        # Generalized: accept any callable. The original
        # ``assert type(lambd) is types.LambdaType`` rejected named
        # functions/partials and was silently stripped under ``python -O``.
        if not callable(lambd):
            raise TypeError(f"Lambda expects a callable, got {type(lambd)!r}")
        self.lambd = lambd

    def forward(self, x):
        # Apply the wrapped callable directly to the input.
        return self.lambd(x)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class LayerNormPermuted(nn.LayerNorm):
    """LayerNorm over the channel axis of a channel-first 4-D tensor.

    Accepts [B, C, T, F], normalizes over C, returns [B, C, T, F].
    """

    def forward(self, x):
        # Move channels last so nn.LayerNorm normalizes them, then restore
        # the channel-first layout expected by the surrounding network.
        channels_last = x.permute(0, 2, 3, 1)   # [B, T, F, C]
        normed = super().forward(channels_last)
        return normed.permute(0, 3, 1, 2)       # [B, C, T, F]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Use native layernorm implementation
|
| 43 |
+
class LayerNormalization4D(nn.Module):
    """Thin wrapper around ``nn.LayerNorm`` applied to the trailing dim.

    ``preserve_outdim`` is stored for interface compatibility; it is not
    consulted in this implementation.
    """

    def __init__(self, C, eps=1e-5, preserve_outdim=False):
        super().__init__()
        self.preserve_outdim = preserve_outdim
        self.norm = nn.LayerNorm(C, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize the last dimension of ``x`` (shape (*, C))."""
        return self.norm(x)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class LayerNormalization4DCF(nn.Module):
    """LayerNorm over a flattened frequency-by-channel (Q * C) axis."""

    def __init__(self, input_dimension, eps=1e-5):
        # input_dimension must be a (n_freqs, n_channels) pair.
        assert len(input_dimension) == 2
        n_freqs, n_chan = input_dimension
        super().__init__()
        self.norm = nn.LayerNorm(n_freqs * n_chan, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` of shape (B, T, Q * C) over its last axis."""
        return self.norm(x)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LayerNormalization4D_old(nn.Module):
    """Hand-rolled channel-wise layer norm for [B, C, T, F] tensors.

    Normalizes over the channel dimension (dim 1) with a learnable
    per-channel affine transform (gamma initialized to 1, beta to 0).
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        affine_shape = (1, input_dimension, 1, 1)
        # Broadcastable per-channel scale and shift.
        self.gamma = Parameter(torch.ones(affine_shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(affine_shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        # Statistics over channels only; eps inside the sqrt for stability.
        mean = x.mean(dim=(1,), keepdim=True)                                       # [B,1,T,F]
        std = torch.sqrt(x.var(dim=(1,), unbiased=False, keepdim=True) + self.eps)  # [B,1,T,F]
        return ((x - mean) / std) * self.gamma + self.beta
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def mod_pad(x, chunk_size, pad):
    """Right-pad the last axis of ``x`` to a multiple of ``chunk_size``.

    After the modulo padding, the extra ``pad`` spec (F.pad format) is
    applied as well. Returns the padded tensor and the number of
    modulo-padding samples that were added.
    """
    remainder = x.shape[-1] % chunk_size
    mod = chunk_size - remainder if remainder else 0
    x = F.pad(x, (0, mod))
    x = F.pad(x, pad)
    return x, mod
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class Attention_STFT_causal(nn.Module):
    """Causal multi-head self-attention across time frames of STFT features.

    Input/output layout is [B, T, Q, C] (batch, time, frequency bins,
    embedding channels). Attention only attends to past frames via a
    lower-triangular mask; ``local_context_len`` optionally restricts
    attention to a band of recent frames.
    """

    def __getitem__(self, key):
        # Allows self["attn_conv_Q"]-style lookup of submodules registered
        # with add_module below.
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        n_freqs,
        approx_qk_dim=512,
        n_head=4,
        activation="prelu",
        eps=1e-5,
        skip_conn=True,
        use_flash_attention=False,
        dim_feedforward=-1,
        local_context_len=-1,
    ):
        super().__init__()
        # Sinusoidal positional code over the flattened (Q * C) feature axis.
        self.position_code = utils.PositionalEncoding(emb_dim * n_freqs, max_len=5000)

        self.skip_conn = skip_conn
        self.n_freqs = n_freqs
        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)  # approx_qk_dim is only approximate
        self.n_head = n_head
        self.V_dim = emb_dim // n_head  # per-head value width
        self.emb_dim = emb_dim
        assert emb_dim % n_head == 0
        E = self.E

        # NOTE(review): stored but never read below — flash attention is not
        # actually wired in; confirm before relying on this flag.
        self.use_flash_attention = use_flash_attention

        # -1 means unlimited (full causal) context.
        self.local_context_len = local_context_len

        # Q/K projections: Linear over C, activation, then fold the heads
        # into the batch axis and normalize over the (Q * E) feature axis.
        self.add_module(
            "attn_conv_Q",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),  # [B, T, Q, HE]
                get_layer(activation)(),
                # [B, T, Q, H, E] -> [B, H, T, Q, E] -> [B * H, T, Q * E]
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),  # (BH, T, Q * E)
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        self.add_module(
            "attn_conv_K",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        # V projection: splits emb_dim evenly across heads instead of E.
        self.add_module(
            "attn_conv_V",
            nn.Sequential(
                nn.Linear(emb_dim, (emb_dim // n_head) * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, (emb_dim // n_head))
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * (emb_dim // n_head))
                ),
                LayerNormalization4DCF((n_freqs, emb_dim // n_head), eps=eps),
            ),
        )

        self.dim_feedforward = dim_feedforward

        if dim_feedforward == -1:
            # Output projection variant: Linear + activation + flatten + norm.
            self.add_module(
                "attn_concat_proj",
                nn.Sequential(
                    nn.Linear(emb_dim, emb_dim),
                    get_layer(activation)(),
                    Lambda(lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])),
                    LayerNormalization4DCF((n_freqs, emb_dim), eps=eps),
                ),
            )
        else:
            # Transformer-style position-wise feed-forward variant.
            self.linear1 = nn.Linear(emb_dim, dim_feedforward)
            self.dropout = nn.Dropout(p=0.1)
            self.activation = nn.ReLU()
            self.linear2 = nn.Linear(dim_feedforward, emb_dim)
            self.dropout2 = nn.Dropout(p=0.1)
            self.norm = LayerNormalization4DCF((n_freqs, emb_dim), eps=eps)

    def _ff_block(self, x):
        # Position-wise feed-forward: Linear -> ReLU -> Dropout -> Linear -> Dropout.
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def get_lookahead_mask(self, seq_len, device):
        """Return a boolean [T, T] mask: True where attention is allowed
        (past/current frames; optionally only the last ``local_context_len``)."""
        if self.local_context_len == -1:
            # Full causal mask: lower-triangular True.
            mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1)

            return mask.detach().to(device)

        else:
            # Band mask: causal AND within local_context_len of the diagonal.
            mask1 = torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1
            mask2 = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=self.local_context_len) == 0
            mask = (mask1 * mask2).transpose(0, 1)

            return mask.detach().to(device)

    def forward(self, batch):
        ### input/output B T F C
        inputs = batch  # kept for the optional residual connection
        B0, T0, Q0, C0 = batch.shape

        # Positional encoding added over the flattened (Q * C) axis.
        pos_code = self.position_code(batch)  # 1, T, embed_dim
        _, T, QC = pos_code.shape
        pos_code = pos_code.reshape(1, T, Q0, C0)
        batch = batch + pos_code

        Q = self["attn_conv_Q"](batch)  # [B', T, Q * C]   (B' = B * n_head)
        K = self["attn_conv_K"](batch)  # [B', T, Q * C]
        V = self["attn_conv_V"](batch)  # [B', T, Q * C]

        emb_dim = Q.shape[-1]

        local_mask = self.get_lookahead_mask(batch.shape[1], batch.device)

        # Scaled dot-product attention with causal masking.
        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)  # [B', T, T]
        attn_mat.masked_fill_(local_mask == 0, -float("Inf"))
        attn_mat = F.softmax(attn_mat, dim=2)  # [B', T, T]

        V = torch.matmul(attn_mat, V)  # [B', T, Q*C]
        V = V.reshape(-1, T0, V.shape[-1])  # [BH, T, Q * C]
        V = V.transpose(1, 2)  # [B', Q * C, T]

        # Un-fold the heads back out of the batch axis.
        batch = V.reshape(B0, self.n_head, self.n_freqs, self.V_dim, T0)  # [B, H, Q, C, T]
        batch = batch.transpose(2, 3)  # [B, H, C, Q, T]
        batch = batch.reshape(B0, self.n_head * self.V_dim, self.n_freqs, T0)  # [B, HC, Q, T]
        batch = batch.permute(0, 3, 2, 1)  # [B, T, Q, C]

        if self.dim_feedforward == -1:
            batch = self["attn_concat_proj"](batch)  # [B, T, Q * C]
        else:
            # Residual feed-forward, then normalize over flattened (Q * C).
            batch = batch + self._ff_block(batch)  # [B, T, Q, C]
            batch = batch.reshape(batch.shape[0], batch.shape[1], batch.shape[2] * batch.shape[3])
            batch = self.norm(batch)
        # Both branches end flattened; restore the [B, T, Q, C] layout.
        batch = batch.reshape(batch.shape[0], batch.shape[1], Q0, C0)  # [B, T, Q, C]

        # Residual connection around the whole attention block.
        if self.skip_conn:
            return batch + inputs
        else:
            return batch
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class GridNetBlock(nn.Module):
    """One TF-GridNet-style block: intra-frame (frequency) BLSTM, chunked
    inter-frame (time) LSTM with pooling, and causal attention over chunks.

    Input and output layout is [B, C, T, Q] (batch, channels, time frames,
    frequency bins); ``forward`` returns ``(out, init_state)``.
    """

    def __getitem__(self, key):
        # Allows self["..."] access to registered submodules.
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        lstm_fold_chunk,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
        pool="mean",
        last=False,
        local_context_len=-1,
    ):
        super().__init__()
        bidirectional = True  # bidirectional within the intra frame lstm

        self.global_atten_causal = True

        # When last=True, forward returns the attention output directly
        # instead of adding it back to the chunked features.
        self.last = last

        self.pool = pool  # "mean" or "max" pooling over each time chunk

        self.lstm_fold_chunk = lstm_fold_chunk  # chunk length for the inter (time) LSTM
        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)  # approx_qk_dim is only approximate

        self.V_dim = emb_dim // n_head
        self.H = hidden_channels
        in_channels = emb_dim * emb_ks  # LSTM input after unfolding emb_ks steps
        self.in_channels = in_channels
        self.n_freqs = n_freqs

        ## intra RNN can be optimized by conv or linear because the frequency length is not very large
        self.intra_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=True)
        self.intra_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs)
        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs

        # inter RNN (unidirectional would be bidirectional=False; here True)
        self.inter_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=bidirectional)
        self.inter_linear = nn.ConvTranspose1d(hidden_channels * (bidirectional + 1), emb_dim, emb_ks, stride=emb_hs)

        # Causal attention over pooled chunk summaries.
        self.pool_atten_causal = Attention_STFT_causal(
            emb_dim=emb_dim,
            n_freqs=n_freqs,
            approx_qk_dim=approx_qk_dim,
            n_head=n_head,
            activation=activation,
            eps=eps,
            local_context_len=local_context_len,
        )

    def _unfold_timedomain(self, x):
        """Split the time axis into chunks of ``lstm_fold_chunk`` frames.

        Input [BQ, C, T] -> output [BQ, Num_chunk, lstm_fold_chunk, C].
        NOTE(review): assumes T is divisible by lstm_fold_chunk — the final
        reshape would fail otherwise; confirm callers guarantee this.
        """
        BQ, C, T = x.shape
        x = torch.split(x, self.lstm_fold_chunk, dim=-1)  # [Num_chunk, BQ, C, 100]
        x = torch.cat(x, dim=0).reshape(-1, BQ, C, self.lstm_fold_chunk)  # [Num_chunk, BQ, C, 100]
        x = x.permute(1, 0, 3, 2)  # [BQ, Num_chunk, 100, C]
        return x

    def forward(self, x, init_state=None):
        """GridNetBlock Forward.

        Args:
            x: [B, C, T, Q]
            init_state: passed through unchanged (kept for interface parity).
        Returns:
            (out, init_state) with out: [B, C, T, Q] (or the attention
            output [B, Q, NUM_CHUNK, C]-shaped tensor when ``last`` is True).
        """
        B, C, old_T, old_Q = x.shape
        # Pad T and Q so the (emb_ks, emb_hs) unfolding covers them exactly.
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # ===========================Intra RNN start================================
        input_ = x
        intra_rnn = self.intra_norm(input_)  # [B, C, T, Q]
        intra_rnn = intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q)  # [BT, C, Q]

        # Unfold the frequency axis into emb_ks-wide patches for the BLSTM.
        intra_rnn = torch.split(intra_rnn, self.emb_ks, dim=-1)  # [Q/I, BT, C, I]
        intra_rnn = torch.stack(intra_rnn, dim=0)
        intra_rnn = intra_rnn.permute(1, 2, 3, 0).flatten(1, 2)  # [BT, CI, Q/I]
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, -1, C*emb_ks]
        self.intra_rnn.flatten_parameters()

        # apply intra frame LSTM, project back to emb_dim, residual-add.
        intra_rnn, _ = self.intra_rnn(intra_rnn)  # [BT, -1, H]
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, H, -1]
        intra_rnn = self.intra_linear(intra_rnn)  # [BT, C, Q]
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()  # [B, C, T, Q]
        intra_rnn = intra_rnn + input_  # [B, C, T, Q]
        intra_rnn = intra_rnn[:, :, :, :old_Q]  # drop the frequency padding
        Q = old_Q
        # ===========================Intra RNN end================================

        # ===========================Inter RNN start================================
        # fold the time domain into chunks
        inter_rnn = self.inter_norm(intra_rnn)  # [B, C, T, F]
        inter_rnn = inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)  # [BF, C, T]

        inter_rnn = self._unfold_timedomain(inter_rnn)  ### BQ, NUM_CHUNK, CHUNK_SIZE, C

        BQ, NUM_CHUNK, CHUNKSIZE, C = inter_rnn.shape

        inter_rnn = inter_rnn.reshape(BQ * NUM_CHUNK, CHUNKSIZE, C)  ### BQ*NUM_CHUNK, CHUNK_SIZE, C
        inter_rnn = inter_rnn.transpose(2, 1)  # [B, C, T]
        input_ = inter_rnn  # saved for the residual add below

        # Unfold each chunk's time axis into emb_ks-wide patches.
        inter_rnn = torch.split(inter_rnn, self.emb_ks, dim=-1)

        inter_rnn = torch.stack(inter_rnn, dim=0)
        inter_rnn = inter_rnn.permute(1, 2, 3, 0)

        BF, C, EO, _T = inter_rnn.shape
        inter_rnn = inter_rnn.reshape(BF, C * EO, _T)

        inter_rnn = inter_rnn.transpose(1, 2)

        self.inter_rnn.flatten_parameters()
        inter_rnn, _ = self.inter_rnn(inter_rnn)  # [BF, -1, H]
        inter_rnn = inter_rnn.transpose(1, 2)  # [BF, H, -1]
        inter_rnn = self.inter_linear(inter_rnn)  # [BF, C, T]
        inter_rnn = inter_rnn + input_  # [BQ* NUM_CHUNK, C, T]

        inter_rnn = inter_rnn.reshape(B, Q, NUM_CHUNK, C, CHUNKSIZE)
        inter_rnn = inter_rnn.permute(0, 1, 2, 4, 3)  # B, Q, NUM_CHUNK, CHUNKSIZE, C

        input_ = inter_rnn  # B, Q, NUM_CHUNK, CHUNKSIZE, C
        # Pool each chunk to a single summary vector for the attention stage.
        if self.pool == "mean":
            inter_rnn = torch.mean(inter_rnn, dim=3)  # B, Q, NUM_CHUNK, C
        elif self.pool == "max":
            inter_rnn, _ = torch.max(inter_rnn, dim=3)  # B, Q, NUM_CHUNK, C
        else:
            raise ValueError("INvalid pool type!")
        # ===========================Inter RNN end================================

        # ===========================attention start================================
        inter_rnn = inter_rnn.transpose(1, 2)  # B, NUM_CHUNK, Q, C
        inter_rnn = self.pool_atten_causal(inter_rnn)  # B T Q C
        inter_rnn = inter_rnn.transpose(1, 2)  # B Q T C

        if self.last == True:
            # Final block: hand the pooled-attention output straight out.
            return inter_rnn, init_state

        else:
            # Broadcast the chunk-level attention back over the chunk frames.
            inter_rnn = inter_rnn.unsqueeze(3)
            inter_rnn = input_ + inter_rnn  # B, Q, NUM_CHUNK, CHUNKSIZE, C

            inter_rnn = inter_rnn.reshape(B, Q, T, C)
            inter_rnn = inter_rnn.permute(0, 3, 2, 1)  # B C T Q
            inter_rnn = inter_rnn[..., :old_T, :]  # drop the time padding
        # ===========================attention end================================

        return inter_rnn, init_state
|
src/models/blocks/model2_block.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import time
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
from typing import Dict, List, Optional, Tuple
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from espnet2.torch_utils.get_layer_from_string import get_layer
|
| 10 |
+
from torch.nn import init
|
| 11 |
+
from torch.nn.parameter import Parameter
|
| 12 |
+
import src.utils as utils
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Lambda(nn.Module):
    """Wraps a callable so it can be dropped into ``nn.Sequential``."""

    def __init__(self, lambd):
        super().__init__()
        # Accept any callable. The original ``assert`` only allowed
        # anonymous lambdas (rejecting named functions) and disappeared
        # under ``python -O``; raise an explicit TypeError instead.
        if not callable(lambd):
            raise TypeError(f"Lambda expects a callable, got {type(lambd)!r}")
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class LayerNormPermuted(nn.LayerNorm):
    """Applies LayerNorm across the channel axis of a [B, C, T, F] tensor."""

    def forward(self, x):
        # Channel-first -> channels-last, normalize with the parent class,
        # then permute back.
        out = super().forward(x.permute(0, 2, 3, 1))  # [B, T, F, C]
        return out.permute(0, 3, 1, 2)                # [B, C, T, F]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Use native layernorm implementation
|
| 43 |
+
class LayerNormalization4D(nn.Module):
    """Last-dimension LayerNorm; ``preserve_outdim`` is kept but unused."""

    def __init__(self, C, eps=1e-5, preserve_outdim=False):
        super().__init__()
        self.norm = nn.LayerNorm(C, eps=eps)
        # Retained for interface compatibility with callers.
        self.preserve_outdim = preserve_outdim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Input of shape (*, C); normalized over the trailing axis."""
        return self.norm(x)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class LayerNormalization4DCF(nn.Module):
    """Joint LayerNorm over the last (Q * C) axis of (B, T, Q * C) tensors."""

    def __init__(self, input_dimension, eps=1e-5):
        # Expect a (frequency, channel) pair.
        assert len(input_dimension) == 2
        freqs, chans = input_dimension
        super().__init__()
        self.norm = nn.LayerNorm(freqs * chans, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LayerNormalization4D_old(nn.Module):
    """Legacy hand-rolled channel LayerNorm for [B, C, T, F] inputs."""

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        shape = (1, input_dimension, 1, 1)
        # Learnable per-channel affine, initialized to identity.
        self.gamma = Parameter(torch.ones(shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        # Normalize over channels (dim 1) only.
        mu = x.mean(dim=(1,), keepdim=True)                                      # [B,1,T,F]
        sigma = torch.sqrt(x.var(dim=(1,), unbiased=False, keepdim=True) + self.eps)
        return ((x - mu) / sigma) * self.gamma + self.beta
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def mod_pad(x, chunk_size, pad):
    """Pad the last axis of ``x`` up to a multiple of ``chunk_size``, then
    apply the extra ``pad`` spec (F.pad format).

    Returns the padded tensor and the number of modulo-padding samples added.
    """
    # (-len) % chunk_size is 0 when already aligned, else the shortfall.
    mod = (-x.shape[-1]) % chunk_size
    x = F.pad(F.pad(x, (0, mod)), pad)
    return x, mod
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class Attention_STFT_causal(nn.Module):
    """Causal multi-head self-attention across time frames of STFT features.

    Input/output layout is [B, T, Q, C] (batch, time, frequency bins,
    embedding channels). Attention attends only to past/current frames
    through a lower-triangular boolean mask.
    """

    def __getitem__(self, key):
        # Allows self["attn_conv_Q"]-style lookup of registered submodules.
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        n_freqs,
        approx_qk_dim=512,
        n_head=4,
        activation="prelu",
        eps=1e-5,
        skip_conn=True,
        use_flash_attention=False,
        dim_feedforward=-1,
    ):
        super().__init__()
        # Sinusoidal positional code over the flattened (Q * C) feature axis.
        self.position_code = utils.PositionalEncoding(emb_dim * n_freqs, max_len=5000)

        self.skip_conn = skip_conn
        self.n_freqs = n_freqs
        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)  # approx_qk_dim is only approximate
        self.n_head = n_head
        self.V_dim = emb_dim // n_head  # per-head value width
        self.emb_dim = emb_dim
        assert emb_dim % n_head == 0
        E = self.E

        # Q/K projections: Linear over C, activation, heads folded into the
        # batch axis, then LayerNorm over the (Q * E) feature axis.
        self.add_module(
            "attn_conv_Q",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),  # [B, T, Q, HE]
                get_layer(activation)(),
                # [B, T, Q, H, E] -> [B, H, T, Q, E] -> [B * H, T, Q * E]
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),  # (BH, T, Q * E)
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        self.add_module(
            "attn_conv_K",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        # V projection splits emb_dim evenly across the heads.
        self.add_module(
            "attn_conv_V",
            nn.Sequential(
                nn.Linear(emb_dim, (emb_dim // n_head) * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, (emb_dim // n_head))
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * (emb_dim // n_head))
                ),
                LayerNormalization4DCF((n_freqs, emb_dim // n_head), eps=eps),
            ),
        )

        self.dim_feedforward = dim_feedforward

        if dim_feedforward == -1:
            # Output projection variant: Linear + activation + flatten + norm.
            self.add_module(
                "attn_concat_proj",
                nn.Sequential(
                    nn.Linear(emb_dim, emb_dim),
                    get_layer(activation)(),
                    Lambda(lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])),
                    LayerNormalization4DCF((n_freqs, emb_dim), eps=eps),
                ),
            )
        else:
            # Transformer-style position-wise feed-forward variant.
            self.linear1 = nn.Linear(emb_dim, dim_feedforward)
            self.dropout = nn.Dropout(p=0.1)
            self.activation = nn.ReLU()
            self.linear2 = nn.Linear(dim_feedforward, emb_dim)
            self.dropout2 = nn.Dropout(p=0.1)
            self.norm = LayerNormalization4DCF((n_freqs, emb_dim), eps=eps)

    def _ff_block(self, x):
        # Position-wise feed-forward: Linear -> ReLU -> Dropout -> Linear -> Dropout.
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def get_lookahead_mask(self, seq_len, device):
        """Creates a causal mask for each sequence which masks future frames.

        Arguments
        ---------
        seq_len: int
            Length of the sequence.
        device: torch.device
            The device on which to create the mask.

        Returns a boolean [seq_len, seq_len] tensor that is True on and
        below the diagonal (frame t may attend to frames <= t). NOTE: unlike
        the float ``-inf`` masks used elsewhere, this implementation returns
        a boolean mask; the caller converts it via ``masked_fill_``.
        """
        mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1)

        return mask.detach().to(device)

    def forward(self, batch):
        ### input/output B T F C
        inputs = batch  # kept for the optional residual connection
        B0, T0, Q0, C0 = batch.shape

        # Positional encoding over the flattened (Q * C) axis.
        pos_code = self.position_code(batch)  # 1, T, embed_dim
        _, T, QC = pos_code.shape
        pos_code = pos_code.reshape(1, T, Q0, C0)
        batch = batch + pos_code

        Q = self["attn_conv_Q"](batch)  # [B', T, Q * C]   (B' = B * n_head)
        K = self["attn_conv_K"](batch)  # [B', T, Q * C]
        V = self["attn_conv_V"](batch)  # [B', T, Q * C]

        emb_dim = Q.shape[-1]

        local_mask = self.get_lookahead_mask(batch.shape[1], batch.device)

        # Scaled dot-product attention with causal masking.
        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)  # [B', T, T]
        attn_mat.masked_fill_(local_mask == 0, -float("Inf"))
        attn_mat = F.softmax(attn_mat, dim=2)  # [B', T, T]

        V = torch.matmul(attn_mat, V)  # [B', T, Q*C]
        V = V.reshape(-1, T0, V.shape[-1])  # [BH, T, Q * C]
        V = V.transpose(1, 2)  # [B', Q * C, T]

        # Un-fold the heads back out of the batch axis.
        batch = V.reshape(B0, self.n_head, self.n_freqs, self.V_dim, T0)  # [B, H, Q, C, T]
        batch = batch.transpose(2, 3)  # [B, H, C, Q, T]
        batch = batch.reshape(B0, self.n_head * self.V_dim, self.n_freqs, T0)  # [B, HC, Q, T]
        batch = batch.permute(0, 3, 2, 1)  # [B, T, Q, C]

        if self.dim_feedforward == -1:
            batch = self["attn_concat_proj"](batch)  # [B, T, Q * C]
        else:
            # Residual feed-forward, then normalize over flattened (Q * C).
            batch = batch + self._ff_block(batch)  # [B, T, Q, C]
            batch = batch.reshape(batch.shape[0], batch.shape[1], batch.shape[2] * batch.shape[3])
            batch = self.norm(batch)
        # Both branches end flattened; restore the [B, T, Q, C] layout.
        batch = batch.reshape(batch.shape[0], batch.shape[1], Q0, C0)  # [B, T, Q, C]

        # Residual connection around the whole attention block.
        if self.skip_conn:
            return batch + inputs
        else:
            return batch
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class GridNetBlock(nn.Module):
    """One TF-GridNet-style block: an intra-frame (frequency-axis) BLSTM, a
    causal inter-frame (time-axis) unidirectional LSTM, and an optional causal
    full-band attention stage, each with a residual connection.

    Layout convention throughout: [B, C, T, Q] = batch, channel (emb_dim),
    time frames, frequency bins.
    """

    def __getitem__(self, key):
        # Allows self["name"] as an alias for attribute access (matches the
        # style used by the attention modules in this file).
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
        pool="mean",
        use_attention=False,
    ):
        """
        Args:
            emb_dim: channel dimension C of the latent TF map.
            emb_ks: unfold kernel size along the frequency axis.
            emb_hs: unfold hop size (stride of the ConvTranspose1d refold).
            n_freqs: number of frequency bins Q.
            hidden_channels: LSTM hidden size.
            n_head, approx_qk_dim, activation, eps: forwarded to the attention stage.
            pool: pooling mode flag (stored; not used in this block's forward).
            use_attention: enable the causal attention stage.
        """
        super().__init__()
        # Inter-frame LSTM is unidirectional so the block stays causal in time.
        bidirectional = False

        self.global_atten_causal = True

        self.pool = pool

        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)  # approx_qk_dim is only approximate

        self.V_dim = emb_dim // n_head
        self.H = hidden_channels
        # LSTM input size after unfolding emb_ks frequency bins into the channel dim.
        in_channels = emb_dim * emb_ks
        self.in_channels = in_channels
        self.n_freqs = n_freqs

        # Intra (frequency-axis) path: bidirectional is fine here because it runs
        # within a single frame, which does not break time causality.
        self.intra_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=True)
        self.intra_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs)
        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs

        # Inter (time-axis) path: unidirectional LSTM + refold back to emb_dim.
        self.inter_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=bidirectional)
        self.inter_linear = nn.ConvTranspose1d(hidden_channels * (bidirectional + 1), emb_dim, emb_ks, stride=emb_hs)

        # Optional causal attention over the whole band.
        self.use_attention = use_attention

        if self.use_attention:
            self.pool_atten_causal = Attention_STFT_causal(
                emb_dim=emb_dim,
                n_freqs=n_freqs,
                approx_qk_dim=approx_qk_dim,
                n_head=n_head,
                activation=activation,
                eps=eps,
            )

    def init_buffers(self, batch_size, device):
        # This block variant carries no streaming state of its own; the caller
        # threads `init_state` through forward() unchanged.
        return None

    def forward(self, x, init_state=None):
        """GridNetBlock Forward.

        Args:
            x: [B, C, T, Q]
            init_state: opaque streaming state; returned unchanged by this variant.
        Returns:
            (out, init_state) with out: [B, C, T, Q] (same shape as the input).
        """
        B, C, old_T, old_Q = x.shape
        # Pad T and Q up so the emb_ks/emb_hs unfold tiles them exactly.
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # ===========================Intra RNN start================================
        input_ = x
        intra_rnn = self.intra_norm(input_)  # [B, C, T, Q]
        intra_rnn = intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q)  # [BT, C, Q]

        # Unfold the frequency axis into chunks of emb_ks bins, folded into channels.
        # NOTE(review): torch.split + stack assumes Q is divisible by emb_ks —
        # guaranteed by the padding above only when emb_hs == emb_ks; confirm.
        intra_rnn = torch.split(intra_rnn, self.emb_ks, dim=-1)  # [Q/I, BT, C, I]
        intra_rnn = torch.stack(intra_rnn, dim=0)
        intra_rnn = intra_rnn.permute(1, 2, 3, 0).flatten(1, 2)  # [BT, CI, Q/I]
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, Q/I, C*emb_ks]
        self.intra_rnn.flatten_parameters()

        # Apply the intra-frame (frequency-axis) BLSTM, refold to full resolution.
        intra_rnn, _ = self.intra_rnn(intra_rnn)  # [BT, -1, 2H]
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, 2H, -1]
        intra_rnn = self.intra_linear(intra_rnn)  # [BT, C, Q]
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()  # [B, C, T, Q]
        intra_rnn = intra_rnn + input_  # residual, [B, C, T, Q]
        # Drop the frequency padding again before the time-axis path.
        intra_rnn = intra_rnn[:, :, :, :old_Q]  # [B, C, T, old_Q]
        Q = old_Q
        # ===========================Intra RNN end================================

        # ===========================Inter RNN start================================
        input_ = intra_rnn

        inter_rnn = self.inter_norm(intra_rnn)  # [B, C, T, Q]
        # One time sequence per (batch, frequency) pair.
        inter_rnn = inter_rnn.transpose(1, 3).reshape(B * Q, T, C)

        self.inter_rnn.flatten_parameters()
        inter_rnn, _ = self.inter_rnn(inter_rnn)  # [B*Q, T, H]
        inter_rnn = inter_rnn.transpose(1, 2)  # [B*Q, H, T]
        inter_rnn = self.inter_linear(inter_rnn)  # [B*Q, C, T']

        _, new_C, new_T = inter_rnn.shape
        inter_rnn = inter_rnn.reshape(B, Q, new_C, new_T)
        inter_rnn = inter_rnn.permute(0, 2, 3, 1)  # back to [B, C, T, Q]
        inter_rnn = inter_rnn + input_  # residual
        # ===========================Inter RNN end================================

        # ===========================attention start================================
        if self.use_attention:
            out = inter_rnn  # keep for the residual, [B, C, T, Q]

            inter_rnn = inter_rnn.permute(0, 2, 3, 1)  # [B, T, Q, C]
            inter_rnn = self.pool_atten_causal(inter_rnn)  # [B, T, Q, C]
            inter_rnn = inter_rnn.permute(0, 3, 1, 2)  # [B, C, T, Q]
            inter_rnn = out + inter_rnn  # residual

        # Trim the time padding introduced at the top.
        inter_rnn = inter_rnn[..., :old_T, :]
        # ===========================attention end================================

        return inter_rnn, init_state
|
src/models/network/model1.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import src.utils as utils
|
| 4 |
+
# from src.models.common.film import FiLM
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class FilmLayer(nn.Module):
|
| 8 |
+
def __init__(self, D, C, nF, groups = 1):
|
| 9 |
+
super().__init__()
|
| 10 |
+
self.D = D # speaker dim 256
|
| 11 |
+
self.C = C # latent dim 16
|
| 12 |
+
self.nF = nF
|
| 13 |
+
self.weight = nn.Conv1d(self.D, self.C * nF, 1, groups = groups)
|
| 14 |
+
self.bias = nn.Conv1d(self.D, self.C * nF, 1, groups = groups)
|
| 15 |
+
|
| 16 |
+
def forward(self, x: torch.Tensor, embedding: torch.Tensor):
|
| 17 |
+
"""
|
| 18 |
+
x: (B, D, F, T)
|
| 19 |
+
embedding: (B, D, F)
|
| 20 |
+
"""
|
| 21 |
+
B, D, _F, T = x.shape
|
| 22 |
+
|
| 23 |
+
w = self.weight(embedding).reshape(B, self.C, _F, 1) # (B, C, F, 1)
|
| 24 |
+
b = self.bias(embedding).reshape(B, self.C, _F, 1) # (B, C, F, 1)
|
| 25 |
+
|
| 26 |
+
return x * w + b
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class LayerNormPermuted(nn.LayerNorm):
    """LayerNorm over the channel axis of a channels-first [B, C, T, F] tensor.

    Wraps nn.LayerNorm (which normalizes the last axis) with a permute to
    channels-last and back.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x: [B, C, T, F]
        Returns:
            normalized tensor, same [B, C, T, F] layout.
        """
        channels_last = x.permute(0, 2, 3, 1)  # [B, T, F, C]
        normalized = super().forward(channels_last)
        return normalized.permute(0, 3, 1, 2)  # [B, C, T, F]
|
| 43 |
+
|
| 44 |
+
class Conv_Emb_Generator(nn.Module):
    """Model-1: encodes an input TF representation into a latent conversation
    embedding.

    Pipeline: a causal-in-time Conv2d front end (lookback carried in an explicit
    buffer for streaming), then `n_layers` dynamically loaded blocks, with
    optional FiLM speaker conditioning either between every pair of blocks
    (`one_emb=False`) or once before block 1 (`one_emb=True`).
    """

    def __init__(
        self,
        block_model_name,
        block_model_params,
        spk_dim=256,
        n_srcs=1,
        n_fft=128,
        latent_dim=16,
        num_inputs=1,
        n_layers=6,
        use_first_ln=True,
        n_imics=1,
        lstm_fold_chunk=400,
        E=2,
        use_speaker_emb=True,
        one_emb=True,
        local_context_len=-1
    ):
        super().__init__()
        self.n_srcs = n_srcs
        self.n_layers = n_layers
        self.num_inputs = num_inputs
        assert n_fft % 2 == 0
        n_freqs = n_fft // 2 + 1
        self.n_freqs = n_freqs
        self.latent_dim = latent_dim

        self.use_speaker_emb = use_speaker_emb
        self.one_emb = one_emb

        # Approximate QK dimension for the attention stage, proportional to F.
        attn_approx_qk_dim = E * n_freqs

        self.n_fft = n_fft

        self.eps = 1.0e-5

        # Kernel 3 along time; padding only along frequency, so time causality
        # is handled by the explicit conv_buf lookback below.
        t_ksize = 3
        self.t_ksize = t_ksize
        ks, padding = (t_ksize, t_ksize), (0, 1)

        self.n_imics = n_imics
        # Without a speaker embedding an extra input channel pair is expected
        # (presumably the self-speech signal — confirm at the caller).
        if not use_speaker_emb:
            self.n_imics = self.n_imics + 1

        module_list = [nn.Conv2d(2 * self.n_imics, latent_dim, ks, padding=padding)]
        if use_first_ln:
            module_list.append(LayerNormPermuted(latent_dim))
        self.conv = nn.Sequential(*module_list)

        # FiLM conditioning layers.
        self.embeds = nn.ModuleList([])

        self.local_context_len = local_context_len

        self.blocks = nn.ModuleList([])
        for _i in range(n_layers - 1):
            self.blocks.append(utils.import_attr(block_model_name)(
                emb_dim=latent_dim, n_freqs=n_freqs, approx_qk_dim=attn_approx_qk_dim,
                lstm_fold_chunk=lstm_fold_chunk, last=False,
                local_context_len=local_context_len, **block_model_params))
        self.blocks.append(utils.import_attr(block_model_name)(
            emb_dim=latent_dim, n_freqs=n_freqs, approx_qk_dim=attn_approx_qk_dim,
            lstm_fold_chunk=lstm_fold_chunk, local_context_len=local_context_len,
            last=True, **block_model_params))

        if self.use_speaker_emb and not self.one_emb:
            # One FiLM layer per gap between consecutive blocks.
            for _i in range(n_layers - 1):
                self.embeds.append(FilmLayer(spk_dim, latent_dim, n_freqs, 1))
        elif self.use_speaker_emb and self.one_emb:
            # A single FiLM layer, applied once before block 1.
            self.embeds.append(FilmLayer(spk_dim, latent_dim, n_freqs, 1))

    def init_buffers(self, batch_size, device):
        """Create streaming state: the conv lookback buffer plus one (lazily
        initialized) buffer slot per block. `deconv_buf` mirrors TSH's state
        dict layout; this model has no deconv stage that reads it."""
        conv_buf = torch.zeros(batch_size, 2 * self.n_imics, self.t_ksize - 1, self.n_freqs,
                               device=device)

        deconv_buf = torch.zeros(batch_size, self.latent_dim, self.t_ksize - 1, self.n_freqs,
                                 device=device)

        block_buffers = {f'buf{i}': None for i in range(len(self.blocks))}

        return dict(conv_buf=conv_buf, deconv_buf=deconv_buf,
                    block_bufs=block_buffers)

    def forward(self, current_input: torch.Tensor, embedding: torch.Tensor, input_state, quantized=False) -> torch.Tensor:
        """
        B: batch, M: mic, F: freq bin, C: real/imag, T: time frame
        D: dimension of the embedding vector
        current_input: (B, CM, T, F)
        embedding: (B, D)  # the FiLM layers index it as (B, D, F) — confirm caller shape
        Returns:
            (conversation_emb, input_state) where conversation_emb is the latent
            [B, latent_dim, T, F] map and input_state carries the updated buffers.
        """
        n_batch, _, n_frames, n_freqs = current_input.shape
        batch = current_input

        if input_state is None:
            input_state = self.init_buffers(current_input.shape[0], current_input.device)

        conv_buf = input_state['conv_buf']
        gridnet_buf = input_state['block_bufs']

        if quantized:
            # Edge/quantized path: zero left-padding instead of carried context.
            batch = nn.functional.pad(batch, (0, 0, self.t_ksize - 1, 0))
        else:
            # Streaming path: prepend the lookback frames from the previous call.
            batch = torch.cat((conv_buf, batch), dim=2)

        conv_buf = batch[:, :, -(self.t_ksize - 1):, :]
        batch = self.conv(batch)  # [B, D, T, F]

        if self.use_speaker_emb:
            if not self.one_emb:
                assert len(self.blocks) == self.n_layers
                assert len(self.embeds) == self.n_layers - 1
                # FiLM between every pair of blocks (none before block 0).
                for ii in range(self.n_layers - 1):
                    batch = batch.transpose(2, 3)
                    if ii > 0:
                        batch = self.embeds[ii - 1](batch, embedding)
                    batch = batch.transpose(2, 3)
                    batch, gridnet_buf[f'buf{ii}'] = self.blocks[ii](batch, gridnet_buf[f'buf{ii}'])

                batch = batch.transpose(2, 3)
                batch = self.embeds[-1](batch, embedding)
                batch = batch.transpose(2, 3)
                batch, gridnet_buf[f'buf{self.n_layers-1}'] = self.blocks[self.n_layers - 1](batch, gridnet_buf[f'buf{self.n_layers-1}'])

            else:
                assert len(self.blocks) == self.n_layers
                assert len(self.embeds) == 1
                # Single FiLM application, before block 1 only.
                for ii in range(self.n_layers):
                    batch = batch.transpose(2, 3)
                    if ii == 1:
                        batch = self.embeds[ii - 1](batch, embedding)
                    batch = batch.transpose(2, 3)
                    batch, gridnet_buf[f'buf{ii}'] = self.blocks[ii](batch, gridnet_buf[f'buf{ii}'])

        else:
            assert len(self.blocks) == self.n_layers
            for ii in range(self.n_layers):
                batch, gridnet_buf[f'buf{ii}'] = self.blocks[ii](batch, gridnet_buf[f'buf{ii}'])

        # BUG FIX: persist the updated conv lookback buffer. Previously only the
        # local name was rebound, so every streaming call reused the initial
        # (zero) buffer, producing frame discontinuities at chunk boundaries.
        # This mirrors TSH.forward in model2_joint.py.
        input_state['conv_buf'] = conv_buf
        input_state['block_bufs'] = gridnet_buf

        conversation_emb = batch

        return conversation_emb, input_state

    def edge_mode(self):
        """Switch every block to its edge/quantized inference mode."""
        for i in range(len(self.blocks)):
            self.blocks[i].edge_mode()
|
| 195 |
+
if __name__ == "__main__":
|
| 196 |
+
pass
|
src/models/network/model2_joint.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import src.utils as utils
|
| 4 |
+
# from src.models.common.film import FiLM
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class FilmLayer(nn.Module):
|
| 8 |
+
def __init__(self, D, C, nF, groups = 1):
|
| 9 |
+
super().__init__()
|
| 10 |
+
self.D = D
|
| 11 |
+
self.C = C
|
| 12 |
+
self.nF = nF
|
| 13 |
+
self.weight = nn.Conv1d(self.D, self.C * nF, 1, groups = groups)
|
| 14 |
+
self.bias = nn.Conv1d(self.D, self.C * nF, 1, groups = groups)
|
| 15 |
+
|
| 16 |
+
def forward(self, x: torch.Tensor, embedding: torch.Tensor):
|
| 17 |
+
"""
|
| 18 |
+
x: (B, D, F, T)
|
| 19 |
+
embedding: (B, D, F)
|
| 20 |
+
"""
|
| 21 |
+
B, D, _F, T = x.shape
|
| 22 |
+
w = self.weight(embedding).reshape(B, self.C, _F, 1) # (B, C, F, 1)
|
| 23 |
+
b = self.bias(embedding).reshape(B, self.C, _F, 1) # (B, C, F, 1)
|
| 24 |
+
|
| 25 |
+
return x * w + b
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class LayerNormPermuted(nn.LayerNorm):
    """Channel-axis LayerNorm for channels-first [B, C, T, F] tensors:
    permute to channels-last, normalize, permute back."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x: [B, C, T, F]
        """
        y = super().forward(x.permute(0, 2, 3, 1))  # normalize over C (last axis)
        return y.permute(0, 3, 1, 2)  # restore [B, C, T, F]
|
| 42 |
+
|
| 43 |
+
class TSH(nn.Module):
    """Model-2: target-speech extraction head.

    Encodes the mixture TF map with a causal Conv2d front end, multiplies in the
    model-1 conversation embedding before block 1, runs a stack of dynamically
    loaded blocks, and decodes back to a real/imag spectrogram through a
    transposed convolution.
    """

    def __init__(
        self,
        block_model_name,
        block_model_params,
        spk_dim=256,
        latent_dim=48,
        n_srcs=1,
        n_fft=128,
        num_inputs=1,
        n_layers=6,
        use_first_ln=True,
        n_imics=1,
        lstm_fold_chunk=400,
        stft_chunk_size=200,
        latent_dim_model1=16,
        use_speaker_emb=True,
        use_self_speech_model2=True
    ):
        super().__init__()
        self.n_srcs = n_srcs
        self.n_layers = n_layers
        self.num_inputs = num_inputs
        assert n_fft % 2 == 0
        n_freqs = n_fft // 2 + 1
        self.n_freqs = n_freqs
        self.latent_dim = latent_dim
        self.lstm_fold_chunk = lstm_fold_chunk
        self.stft_chunk_size = stft_chunk_size

        self.n_fft = n_fft

        self.eps = 1.0e-5

        # Kernel 3 along time; frequency-only padding, time causality handled
        # by the explicit conv_buf lookback in forward().
        t_ksize = 3
        self.t_ksize = t_ksize
        ks, padding = (t_ksize, t_ksize), (0, 1)

        self.n_imics = n_imics

        self.use_self_speech_model2 = use_self_speech_model2

        # Without a speaker embedding, an extra input channel pair (self speech)
        # is stacked onto the mixture input.
        if not use_speaker_emb and use_self_speech_model2:
            self.n_imics = self.n_imics + 1

        module_list = [nn.Conv2d(2 * self.n_imics, latent_dim, ks, padding=padding)]

        if use_first_ln:
            module_list.append(LayerNormPermuted(latent_dim))

        self.conv = nn.Sequential(
            *module_list
        )

        # FiLM layer container (unused in this forward; kept for state-dict /
        # interface parity with model1 — confirm before removing).
        self.embeds = nn.ModuleList([])

        # Process through a stack of blocks
        self.blocks = nn.ModuleList([])
        for _i in range(n_layers):
            self.blocks.append(utils.import_attr(block_model_name)(emb_dim=latent_dim, n_freqs=n_freqs, **block_model_params))

        # Project back to TF-Domain
        self.deconv = nn.ConvTranspose2d(latent_dim, n_srcs * 2, ks, padding=(self.t_ksize - 1, 1))

        self.latent_dim_model1 = latent_dim_model1

        # Optional 1x1 projection from model-1's latent width to this model's.
        # NOTE(review): projection_layer is not applied anywhere in this
        # forward(); presumably the caller projects the embedding — confirm.
        if latent_dim_model1 != latent_dim:
            self.projection_layer = nn.Conv2d(latent_dim_model1, latent_dim, kernel_size=1)

    def init_buffers(self, batch_size, device):
        """Allocate streaming state: conv lookback buffer, a (zero) deconv
        buffer, and one per-block buffer from each block's own init_buffers."""
        conv_buf = torch.zeros(batch_size, 2 * self.n_imics, self.t_ksize - 1, self.n_freqs,
                               device=device)

        # NOTE(review): forward() rebuilds a fresh zero deconv_buf every call and
        # never reads or writes this entry, so the deconv left-context is always
        # zeros — confirm whether persisting it was intended.
        deconv_buf = torch.zeros(batch_size, self.latent_dim, self.t_ksize - 1, self.n_freqs,
                                 device=device)

        block_buffers = {}
        for i in range(len(self.blocks)):
            block_buffers[f'buf{i}'] = self.blocks[i].init_buffers(batch_size, device)

        return dict(conv_buf=conv_buf, deconv_buf=deconv_buf,
                    block_bufs=block_buffers)

    def forward(self, current_input: torch.Tensor, embedding: torch.Tensor, input_state, quantized=False) -> torch.Tensor:
        """
        B: batch, M: mic, F: freq bin, C: real/imag, T: time frame
        D: dimension of the embedding vector
        current_input: (B, CM, T, F)
        embedding: (B, D, F)  # model-1 conversation embedding; transposed and multiplied in before block 1
        output: (B, S, T, C*F)
        """

        n_batch, _, n_frames, n_freqs = current_input.shape
        batch = current_input

        if input_state is None:
            input_state = self.init_buffers(current_input.shape[0], current_input.device)

        conv_buf = input_state['conv_buf']
        gridnet_buf = input_state['block_bufs']

        if quantized:
            # Edge/quantized path: zero left-padding instead of carried context.
            batch = nn.functional.pad(batch, (0, 0, self.t_ksize - 1, 0))
        else:
            # Streaming path: prepend lookback frames from the previous call.
            batch = torch.cat((conv_buf, batch), dim=2)

        conv_buf = batch[:, :, -(self.t_ksize - 1):, :]
        batch = self.conv(batch)  # [B, D, T, F]

        embedding = embedding.transpose(1, 3)

        for ii in range(self.n_layers):
            # Inject the conversation embedding once, before block 1.
            if ii == 1:
                batch = batch * embedding
            batch, gridnet_buf[f'buf{ii}'] = self.blocks[ii](batch, gridnet_buf[f'buf{ii}'])

        # Deconv left-context is always zeros (see NOTE in init_buffers).
        deconv_buf = torch.zeros(n_batch, self.latent_dim, self.t_ksize - 1, self.n_freqs,
                                 device=current_input.device)
        if quantized:
            batch = nn.functional.pad(batch, (0, 0, self.t_ksize - 1, 0))
        else:
            batch = torch.cat((deconv_buf, batch), dim=2)

        batch = self.deconv(batch)  # [B, n_srcs*C, T, F]

        # Interleave real/imag with frequency: [B, S, T, 2F].
        batch = batch.view([n_batch, self.n_srcs, 2, n_frames, n_freqs])  # [B, n_srcs, 2, n_frames, n_freqs]
        batch = batch.transpose(2, 3).reshape(n_batch, self.n_srcs, n_frames, 2 * n_freqs)  # [B, S, T, F]

        # Persist the updated streaming state.
        input_state['conv_buf'] = conv_buf
        input_state['block_bufs'] = gridnet_buf

        return batch, input_state

    def edge_mode(self):
        # Switch every block to its edge/quantized inference mode.
        for i in range(len(self.blocks)):
            self.blocks[i].edge_mode()
|
| 185 |
+
if __name__ == "__main__":
|
| 186 |
+
pass
|
src/models/network/net_conversation_joint.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from .model1 import Conv_Emb_Generator
|
| 5 |
+
from .model2_joint import TSH
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
import copy
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def mod_pad(x, chunk_size, pad):
    """Right-pad the last axis of `x` up to a multiple of `chunk_size`, then
    apply the extra `pad` tuple (F.pad convention, last axis first).

    Returns:
        (padded_x, deficit) where `deficit` is the number of zeros appended to
        reach the chunk multiple (0 if already aligned).
    """
    deficit = (-x.shape[-1]) % chunk_size  # 0 when already a multiple
    x = F.pad(x, (0, deficit))
    x = F.pad(x, pad)
    return x, deficit
|
| 21 |
+
|
| 22 |
+
# A TF-domain network guided by an embedding vector
|
| 23 |
+
# A TF-domain network guided by an embedding vector
class Net_Conversation(nn.Module):
    """Joint streaming pipeline: model-1 (Conv_Emb_Generator) produces a
    conversation embedding which conditions model-2 (TSH) for target-speech
    extraction, with a chunked STFT/iSTFT front and back end."""

    def __init__(self,
                 model1_block_name,
                 model1_block_params,
                 model2_block_name,
                 model2_block_params,
                 stft_chunk_size=64,
                 stft_pad_size=32,
                 stft_back_pad=32,
                 num_input_channels=1,
                 num_output_channels=1,
                 num_sources=1,
                 speaker_embed=256,
                 num_layers_model1=3,
                 num_layers_model2=3,
                 latent_dim_model1=16,
                 latent_dim_model2=32,
                 use_sp_feats=False,
                 use_first_ln=True,
                 n_imics=1,
                 window="hann",
                 lstm_fold_chunk=400,
                 E=2,
                 use_speaker_emb_model1=True,
                 one_emb_model1=True,
                 use_self_speech_model2=True,
                 local_context_len=-1
                 ):
        super(Net_Conversation, self).__init__()

        # Only single-source extraction is supported.
        assert num_sources == 1

        # num input/output channels
        self.nI = num_input_channels
        self.nO = num_output_channels

        # num channels to the TF-network (spatial features add 3 per extra mic).
        num_separator_inputs = self.nI * 2 + use_sp_feats * (3 * (self.nI - 1))

        self.stft_chunk_size = stft_chunk_size
        self.stft_pad_size = stft_pad_size
        self.stft_back_pad = stft_back_pad
        self.n_srcs = num_sources
        self.use_sp_feats = use_sp_feats

        # FFT length = lookback + hop + lookahead padding.
        self.nfft = stft_back_pad + stft_chunk_size + stft_pad_size

        self.nfreqs = self.nfft // 2 + 1

        self.lstm_fold_chunk = lstm_fold_chunk

        # Construct synthesis/analysis windows (rect or hann)
        if window == "hann":
            window_fn = lambda x: np.hanning(x)
        elif window == "rect":
            window_fn = lambda x: np.ones(x)
        else:
            raise ValueError("Invalid window type!")

        if ((stft_pad_size) % stft_chunk_size) == 0:
            # Perfect-reconstruction case: derive a COLA-normalized synthesis
            # window from the tail of the analysis window.
            print("Using perfect STFT windows")
            self.analysis_window = torch.from_numpy(window_fn(self.nfft)).float()

            # eg. inverse STFT synthesis window (covers chunk + lookahead only)
            self.synthesis_window = torch.zeros(stft_pad_size + stft_chunk_size).float()

            A = self.synthesis_window.shape[0]  # synthesis window length
            B = self.stft_chunk_size  # hop
            N = self.analysis_window.shape[0]  # analysis window length

            assert (A % B) == 0
            for i in range(A):
                num = self.analysis_window[N - A + i]

                # Denominator: sum of squared analysis-window taps that overlap
                # this sample across hops (constant-overlap-add normalization).
                denom = 0
                for k in range(A // B):
                    denom += (self.analysis_window[N - A + (i % B) + k * B] ** 2)

                self.synthesis_window[i] = num / denom
        else:
            # No exact COLA inverse: just use the window function directly.
            print("Using imperfect STFT windows")
            self.analysis_window = torch.from_numpy(window_fn(self.nfft)).float()
            self.synthesis_window = torch.from_numpy(window_fn(stft_chunk_size + stft_pad_size)).float()

        # Number of past frames overlap-add needs to keep around.
        self.istft_lookback = 1 + (self.synthesis_window.shape[0] - 1) // self.stft_chunk_size

        # Convert a sample count into a block-chunk count for the blocks' local
        # attention context (-1 = unlimited).
        if local_context_len != -1:
            local_context_len = local_context_len // stft_chunk_size // lstm_fold_chunk

        # Model-1: conversation-embedding generator.
        self.model1 = Conv_Emb_Generator(
            model1_block_name,
            model1_block_params,
            spk_dim=speaker_embed,
            latent_dim=latent_dim_model1,
            n_srcs=num_output_channels * num_sources,
            n_fft=self.nfft,
            num_inputs=num_separator_inputs,
            n_layers=num_layers_model1,
            use_first_ln=use_first_ln,
            n_imics=n_imics,
            lstm_fold_chunk=lstm_fold_chunk,
            E=E,
            use_speaker_emb=use_speaker_emb_model1,
            one_emb=one_emb_model1,
            local_context_len=local_context_len
        )

        self.quantized = False

        self.use_self_speech_model2 = use_self_speech_model2

        # Model-2: target-speech extraction head conditioned on model-1 output.
        self.model2 = TSH(
            model2_block_name,
            model2_block_params,
            spk_dim=speaker_embed,
            latent_dim=latent_dim_model2,
            latent_dim_model1=latent_dim_model1,
            n_srcs=num_output_channels * num_sources,
            n_fft=self.nfft,
            num_inputs=num_separator_inputs,
            n_layers=num_layers_model2,
            use_first_ln=use_first_ln,
            n_imics=n_imics,
            lstm_fold_chunk=lstm_fold_chunk,
            stft_chunk_size=stft_chunk_size,
            use_speaker_emb=use_speaker_emb_model1,
            use_self_speech_model2=use_self_speech_model2
        )

        self.use_speaker_emb_model1 = use_speaker_emb_model1
| 155 |
+
def init_buffers(self, batch_size, device):
|
| 156 |
+
buffers = {}
|
| 157 |
+
|
| 158 |
+
buffers['model1_bufs'] = self.model1.init_buffers(batch_size, device)
|
| 159 |
+
|
| 160 |
+
buffers['model2_bufs'] = self.model2.init_buffers(batch_size, device)
|
| 161 |
+
|
| 162 |
+
buffers['istft_buf'] = torch.zeros(batch_size * self.n_srcs * self.nO,
|
| 163 |
+
self.synthesis_window.shape[0],
|
| 164 |
+
self.istft_lookback, device=device)
|
| 165 |
+
|
| 166 |
+
return buffers
|
| 167 |
+
|
| 168 |
+
# compute STFT
|
| 169 |
+
def extract_features(self, x):
|
| 170 |
+
"""
|
| 171 |
+
x: (B, M, T)
|
| 172 |
+
returns: (B, C*M, T, F)
|
| 173 |
+
"""
|
| 174 |
+
B, M, T = x.shape
|
| 175 |
+
|
| 176 |
+
x = x.reshape(B*M, T)
|
| 177 |
+
x = torch.stft(x, n_fft = self.nfft, hop_length = self.stft_chunk_size,
|
| 178 |
+
win_length = self.nfft, window=self.analysis_window.to(x.device),
|
| 179 |
+
center=False, normalized=False, return_complex=True)
|
| 180 |
+
|
| 181 |
+
x = torch.view_as_real(x) # [B*M, F, T, 2]
|
| 182 |
+
BM, _F, T, C = x.shape
|
| 183 |
+
|
| 184 |
+
x = x.reshape(B, M, _F, T, C) # [B, M, F, T, 2]
|
| 185 |
+
|
| 186 |
+
x = x.permute(0, 4, 1, 3, 2) # [B, 2, M. T, F]
|
| 187 |
+
|
| 188 |
+
x = x.reshape(B, C*M, T, _F)
|
| 189 |
+
|
| 190 |
+
return x
|
| 191 |
+
|
| 192 |
+
    def synthesis(self, x, input_state):
        """
        Streaming inverse STFT via overlap-add.

        x: (B, S, T, C*F) — per-source spectra; the last axis interleaves
           C=2 real/imaginary planes over F frequency bins.
        input_state: dict holding 'istft_buf', the windowed-frame history
           carried between calls (allocated in init_buffers).
        returns: (x, input_state) where x is (B, S, t) time-domain audio and
           input_state contains the updated 'istft_buf'.
        """
        istft_buf = input_state['istft_buf']

        x = x.transpose(2, 3) # [B, S, CF, T]

        B, S, CF, T = x.shape
        X = x.reshape(B*S, CF, T)
        # Split the CF axis back into (real/imag, F) and rebuild complex bins.
        X = X.reshape(B*S, 2, -1, T).permute(0, 2, 3, 1) # [BS, F, T, C]
        X = X[..., 0] + 1j * X[..., 1]

        # Per-frame inverse FFT; keep only the tail covered by the synthesis window.
        x = torch.fft.irfft(X, dim=1) # [BS, iW, T]
        x = x[:, -self.synthesis_window.shape[0]:] # [BS, oW, T]

        # Apply synthesis window
        x = x * self.synthesis_window.unsqueeze(0).unsqueeze(-1).to(x.device)

        oW = self.synthesis_window.shape[0]

        # Concatenate blocks from previous IFFTs
        x = torch.cat([istft_buf, x], dim=-1)
        # NOTE(review): istft_buf.shape[1] is the window length oW, not the
        # lookback count (shape[2]); the retained history is therefore oW
        # frames rather than istft_lookback. Output is still correct because
        # the final slice below reads from the end, but confirm whether
        # shape[2] was intended (it would do less redundant fold work).
        istft_buf = x[..., -istft_buf.shape[1]:] # Update buffer

        # Get full signal: overlap-add all windowed frames at hop stft_chunk_size.
        x = F.fold(x, output_size=(self.stft_chunk_size * x.shape[-1] + (oW - self.stft_chunk_size), 1),
                   kernel_size=(oW, 1), stride=(self.stft_chunk_size, 1)) # [BS, 1, t]

        # Keep exactly T chunks of output, dropping the look-ahead padding.
        x = x[:, :, -T * self.stft_chunk_size - self.stft_pad_size: - self.stft_pad_size]
        x = x.reshape(B, S, -1) # [B, S, t]

        input_state['istft_buf'] = istft_buf

        return x, input_state
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
    def predict_model1(self, x, input_state, speaker_embedding, pad=True):
        """
        Run the slow model (model1) on a mixture chunk and return the
        conversation embedding.

        B: batch
        M: mic
        t: time step (time-domain)
        x: (B, M, t)
        R: real or imaginary
        speaker_embedding: optional speaker embedding tensor, or None; a
            singleton time axis is inserted before it is passed to model1.
        pad: when True, pad x to a multiple of stft_chunk_size with
            (stft_back_pad, stft_pad_size) context on either side.
        Returns (conversation_emb, input_state) with updated 'model1_bufs'.
        """

        mod = 0
        if pad:
            pad_size = (self.stft_back_pad, self.stft_pad_size)
            x, mod = mod_pad(x, chunk_size=self.stft_chunk_size, pad=pad_size)

        # Time-domain to TF-domain
        x = self.extract_features(x) # [B, RM, T, F]

        if speaker_embedding is not None:
            speaker_embedding=speaker_embedding.unsqueeze(2)

        # model1 consumes and returns its own streaming buffers.
        conversation_emb, input_state['model1_bufs'] = self.model1(x, speaker_embedding, input_state['model1_bufs'], self.quantized)

        return conversation_emb, input_state
|
| 253 |
+
|
| 254 |
+
    def predict_model2(self, x, conversation_emb, input_state, pad=True):
        """
        Run the fast model (model2), conditioned on the conversation
        embedding, and synthesize time-domain output.

        B: batch
        M: mic
        t: time step (time-domain)
        x: (B, M, t)
        R: real or imaginary
        conversation_emb: conditioning tensor produced by predict_model1.
        pad: when True, pad x to a chunk multiple; any excess (mod) samples
            added by the padding are trimmed off the output.
        Returns (x, next_state): x is (B, S * M, t) separated audio.
        """
        mod = 0
        if pad:
            pad_size = (self.stft_back_pad, self.stft_pad_size)
            x, mod = mod_pad(x, chunk_size=self.stft_chunk_size, pad=pad_size)

        x = self.extract_features(x)

        x, input_state['model2_bufs']=self.model2(x, conversation_emb, input_state['model2_bufs'], self.quantized)

        # TF-domain to time-domain
        x, next_state = self.synthesis(x, input_state) # [B, S * M, t]

        if mod != 0:
            # Drop the samples that mod_pad appended to reach a chunk multiple.
            x = x[:, :, :-mod]

        return x, next_state
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
    def forward(self, inputs, input_state = None, pad=True):
        """
        Full two-stage forward pass: the slow model (model1) summarises the
        whole mixture into a conversation embedding, which conditions the
        fast model (model2) on a [start_idx, end_idx) slice of the audio.

        inputs: dict with keys 'mixture' (B, M, t), 'start_idx', 'end_idx'
            (sample indices, must differ by a multiple of stft_chunk_size),
            'embed' (speaker embedding) and — when use_speaker_emb_model1 is
            False — 'self_speech' audio to concatenate on the channel axis.
        input_state: streaming buffers from a previous call, or None to
            allocate fresh ones.
        Returns dict with 'output' (B, nO, t'), 'next_state', 'audio_range'.
        """
        x = inputs['mixture']

        start_idx_input=inputs['start_idx']
        end_idx_input=inputs['end_idx']

        assert ((end_idx_input - start_idx_input) % self.stft_chunk_size) == 0

        # Snap start and end to chunk
        start_idx_input = (start_idx_input // self.stft_chunk_size) * self.stft_chunk_size
        end_idx_input = (end_idx_input // self.stft_chunk_size) * self.stft_chunk_size

        B, M, t=x.shape

        # Per-batch record of the processed sample range, returned to the caller.
        audio_range=torch.tensor([start_idx_input, end_idx_input]).to(x.device)
        audio_range = audio_range.unsqueeze(0).repeat(B, 1)

        spk_embed = inputs['embed']
        self_speech=None

        if not self.use_speaker_emb_model1:
            # Without a speaker embedding, model1 is cued by the wearer's own
            # speech, stacked as extra input channels.
            self_speech=inputs['self_speech']

            combined_audio = torch.cat((x, self_speech), dim=1)
            x=combined_audio

        if input_state is None:
            input_state = self.init_buffers(x.shape[0], x.device)

        B, M, t = x.shape

        # enter slow model
        conversation_emb, input_state = self.predict_model1(x, input_state, spk_embed, pad=pad) # [B, F, T, C]

        # slice conv embedding and corresponding audio
        B, _F, T, C = conversation_emb.shape
        conversation_emb = conversation_emb.permute(0, 1, 3, 2) # [B, F, C, T]
        # Shift embeddings one step right so frame T is conditioned on
        # embeddings computed up to T-1 (causality); first frame gets zeros.
        conversation_emb = torch.roll(conversation_emb, 1, dims=-1)
        conversation_emb[..., 0] = 0
        # Repeat each embedding lstm_fold_chunk times so it lines up with the
        # fast model's finer frame rate.
        conversation_emb = conversation_emb.flatten(0,3).unsqueeze(1) # [*, 1]
        multiplier = torch.tile(conversation_emb, (1, self.lstm_fold_chunk)) # [*, L]
        multiplier = multiplier.reshape(B, _F, C, T, self.lstm_fold_chunk).flatten(3,4) # [B, F, C, T*L]
        multiplier = multiplier.permute(0, 1, 3, 2) # [B, F, T*L, C]

        slicing_length=end_idx_input-start_idx_input+self.stft_back_pad+self.stft_pad_size

        # Extend the slice with look-back / look-ahead context, zero-padding
        # where it would run off either end of the recording.
        padded_start=start_idx_input-self.stft_back_pad
        padded_end=end_idx_input+self.stft_pad_size

        pad_left=max(-padded_start, 0)
        pad_right=max(padded_end-t, 0)

        actual_start=max(padded_start, 0)
        actual_end=min(padded_end, t)

        if self.use_self_speech_model2:
            sliced_x=x[:, :, actual_start:actual_end]
        else:
            # Model2 sees the raw mixture without the appended self-speech channels.
            x_no_self_speech=inputs["mixture"]
            sliced_x=x_no_self_speech[:, :, actual_start:actual_end]

        padding = (pad_left, pad_right, 0, 0, 0, 0)

        sliced_x=F.pad(sliced_x, padding, "constant", 0)

        # Convert sample indices to embedding-frame indices.
        converted_start_idx=start_idx_input//self.stft_chunk_size
        converted_end_idx=end_idx_input//self.stft_chunk_size

        sliced_emb=multiplier[:, :, converted_start_idx:converted_end_idx, :]

        assert sliced_x.shape[2]==slicing_length
        assert sliced_emb.shape[2]==(slicing_length-self.stft_back_pad-self.stft_pad_size)//self.stft_chunk_size

        model2_output, input_state = self.predict_model2(sliced_x, sliced_emb, input_state, pad=False)
        model2_output = model2_output.reshape(B, self.n_srcs, self.nO, model2_output.shape[-1])

        # Only the first source's channels are returned as the target output.
        return {'output': model2_output[:, 0], 'next_state': input_state, 'audio_range': audio_range}
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
if __name__ == "__main__":
|
| 360 |
+
pass
|
src/train_joint.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The main training script for training on synthetic data
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils.data
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
import multiprocessing
|
| 13 |
+
import time
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import src.utils as utils
|
| 17 |
+
from src.training.tain_val import train_epoch, test_epoch
|
| 18 |
+
import shutil
|
| 19 |
+
import sys
|
| 20 |
+
|
| 21 |
+
import wandb
|
| 22 |
+
|
| 23 |
+
VAL_SEED = 0
|
| 24 |
+
CURRENT_EPOCH = 0
|
| 25 |
+
|
| 26 |
+
def seed_from_epoch(seed):
    """Seed all RNGs with the base seed offset by the current epoch.

    Used as the DataLoader ``worker_init_fn`` so each epoch draws a
    different (but reproducible) data augmentation stream.
    """
    global CURRENT_EPOCH

    utils.seed_all(seed + CURRENT_EPOCH)
|
| 30 |
+
|
| 31 |
+
def print_metrics(metrics: list):
    """Print average input/output SI-SDR and SI-SDR improvement.

    metrics: list of dicts, each with 'input_si_sdr' and 'si_sdr' keys.
    """
    inputs = np.array([entry['input_si_sdr'] for entry in metrics])
    outputs = np.array([entry['si_sdr'] for entry in metrics])

    print("Average Input SI-SDR: {:03f}, Average Output SI-SDR: {:03f}, Average SI-SDRi: {:03f}".format(
        np.mean(inputs), np.mean(outputs), np.mean(outputs - inputs)))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def train(args: argparse.Namespace):
    """
    Top-level training loop: builds datasets/loaders and the HL module from
    the JSON config, resumes from the last (or best) checkpoint if present,
    then alternates train/validation epochs with wandb logging.

    args: parsed CLI namespace with config, run_dir, best, seed and
        use_nondeterministic_cudnn attributes.
    """
    # Fix random seeds
    utils.seed_all(args.seed)
    # Required by cuBLAS for deterministic behavior on CUDA >= 10.2.
    os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"

    # Turn on deterministic algorithms if specified (Note: slower training).
    if torch.cuda.is_available():
        if args.use_nondeterministic_cudnn:
            torch.backends.cudnn.deterministic = False
        else:
            torch.backends.cudnn.deterministic = True

    # Load experiment description
    with open(args.config, 'rb') as f:
        params = json.load(f)

    # Initialize datasets
    data_train = utils.import_attr(params['train_dataset'])(**params['train_data_args'], split='train')
    data_val = utils.import_attr(params['val_dataset'])(**params['val_data_args'], split='val')

    # Set up the device and workers
    # NOTE(review): use_cuda is hard-coded True; on a CPU-only machine the
    # later .to(device) will fail. Consider torch.cuda.is_available() here.
    use_cuda = True
    device = torch.device('cuda' if use_cuda else 'cpu')
    print("Using device {}".format('cuda' if use_cuda else 'cpu'))

    # Set multiprocessing params
    num_workers = min(multiprocessing.cpu_count(), params['num_workers'])
    kwargs = {
        'num_workers': num_workers,
        # Workers reseed per-epoch so augmentation differs across epochs.
        'worker_init_fn': lambda x: seed_from_epoch(args.seed),
        'pin_memory': False
    } if use_cuda else {}

    # Set up data loaders
    train_loader = torch.utils.data.DataLoader(data_train,
                                               batch_size=params['batch_size'],
                                               shuffle=True,
                                               **kwargs)

    # Validation always uses a fixed seed so metrics are comparable run-to-run.
    kwargs['worker_init_fn'] = lambda x: utils.seed_all(VAL_SEED)
    test_loader = torch.utils.data.DataLoader(data_val,
                                              batch_size=params['eval_batch_size'],
                                              **kwargs)

    # Initialize HL module
    hl_module = utils.import_attr(params['pl_module'])(**params['pl_module_args'])
    hl_module.model.to(device)

    # Get run name from run dir
    run_name = os.path.basename(args.run_dir.rstrip('/'))
    checkpoints_dir = os.path.join(args.run_dir, 'checkpoints')

    # Set up checkpoints
    if not os.path.exists(checkpoints_dir):
        os.makedirs(checkpoints_dir)

    # Copy json
    shutil.copyfile(args.config, os.path.join(args.run_dir, 'config.json'))

    # Check if a model state path exists for this model, if it does, load it
    best_path = os.path.join(checkpoints_dir, 'best.pt')
    state_path = os.path.join(checkpoints_dir, 'last.pt')
    if args.best and os.path.exists(best_path):
        print("load best state path .....")
        hl_module.load_state(best_path)

    elif os.path.exists(state_path):
        print("load state path .....")
        hl_module.load_state(state_path)

    start_epoch = hl_module.epoch

    if "project_name" in params.keys():
        project_name = params["project_name"]
    else:
        project_name = "AcousticBubble"
    # Initialize wandb
    # print(project_name)
    wandb_run = wandb.init(
        project=project_name,
        name=run_name,
        notes='Example of a note',
        tags=['speech', 'audio', 'embedded-systems']
    )

    # Training loop
    try:
        # Go over remaining epochs
        for epoch in range(start_epoch, params['epochs']):
            global CURRENT_EPOCH, VAL_SEED
            CURRENT_EPOCH = epoch
            seed_from_epoch(args.seed)

            hl_module.on_epoch_start()

            current_lr = hl_module.get_current_lr()
            print("CURRENT learning rate: {:0.08f}".format(current_lr))

            print("[TRAINING]")

            # Run testing step

            t1 = time.time()
            train_loss = train_epoch(hl_module, train_loader, device)
            t2 = time.time()
            print(f"Train epoch time: {t2 - t1:02f}s")

            print("\nTrain set: Average Loss: {:.4f}\n".format(train_loss))

            print()
            if np.isnan(train_loss):
                raise ValueError("Got NAN in training")
            utils.seed_all(VAL_SEED)

            # Run testing step

            print("[TESTING]")

            test_loss = test_epoch(hl_module, test_loader, device)

            print("\nTest set: Average Loss: {:.4f}\n".format(test_loss))

            # on_epoch_end saves 'best.pt' when validation improves;
            # dump_state always refreshes 'last.pt' for resuming.
            hl_module.on_epoch_end(best_path, wandb_run)
            hl_module.dump_state(state_path)

            print()
            print("=" * 25, "FINISHED EPOCH", epoch, "=" * 25)
            print()

    except KeyboardInterrupt:
        print("Interrupted")
    # NOTE(review): all other exceptions are printed and swallowed, so the
    # process exits 0 even on failure — confirm this is intentional (e.g.
    # for cluster schedulers); otherwise re-raise after logging.
    except Exception as _:
        import traceback
        traceback.print_exc()
|
| 175 |
+
|
| 176 |
+
if __name__ == '__main__':
    # CLI entry point: parse experiment arguments and launch training.
    parser = argparse.ArgumentParser()
    # Experiment Params
    parser.add_argument('--config', type=str,
                        help='Path to experiment config')

    parser.add_argument('--run_dir', type=str,
                        help='Path to experiment directory')

    parser.add_argument('--best', action='store_true',
                        help="load from best checkpoint instead of last checkpoint")

    # Randomization Params
    parser.add_argument('--seed', type=int, default=10,
                        help='Random seed for reproducibility')
    parser.add_argument('--use_nondeterministic_cudnn',
                        action='store_true',
                        help="If using cuda, chooses whether or not to use \
                              non-deterministic cudDNN algorithms. Training will be\
                              faster, but the final results may differ slighty.")

    # wandb params
    # NOTE(review): --project_name is parsed but train() reads the project
    # name from the JSON config, not from this flag — confirm which wins.
    parser.add_argument('--project_name',
                        type=str,
                        default='AcousticBubble',
                        help='Project name that shows up on wandb')
    train(parser.parse_args())
|
src/training/tain_val.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The main training script for training on synthetic data
|
| 3 |
+
"""
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
import os
|
| 9 |
+
import tqdm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def to_device(batch, device):
    """Recursively move a nested batch structure onto `device`.

    Handles tensors, dicts (mutated in place), and lists/tuples.  Note:
    tuples come back as lists — preserved from the original behavior, since
    downstream code may rely on it.

    Parameters:
        batch: tensor, dict, list/tuple, or any other object (returned as-is).
        device: target torch device.

    Returns:
        The same structure with every tensor moved to `device`.
    """
    # isinstance instead of `type(x) == T` so tensor/dict subclasses
    # (e.g. OrderedDict, parameter tensors) are handled too.
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        for key in batch:
            batch[key] = to_device(batch[key], device)
        return batch
    if isinstance(batch, (list, tuple)):
        return [to_device(item, device) for item in batch]
    # Non-container, non-tensor leaves (ints, strings, None, ...) pass through.
    return batch
|
| 24 |
+
|
| 25 |
+
def test_epoch(hl_module, test_loader, device) -> float:
    """
    Evaluate the network over one pass of `test_loader`.

    Returns the element-weighted average validation loss.
    Fix: the tqdm progress bar is now always closed (previously leaked,
    which garbles terminal output when exceptions interleave bars).
    """
    hl_module.eval()

    test_loss = 0
    num_elements = 0

    num_batches = len(test_loader)
    pbar = tqdm.tqdm(total=num_batches)

    try:
        with torch.no_grad():
            for batch_idx, batch in enumerate(test_loader):
                batch = to_device(batch, device)

                # validation_step returns (mean loss, batch size B).
                loss, B = hl_module.validation_step(batch, batch_idx)
                test_loss += (loss.item() * B)
                num_elements += B

                pbar.set_postfix(loss='%.05f'%(loss.item()) )
                pbar.update()
    finally:
        pbar.close()

    # Weighted mean over all elements seen this epoch.
    return test_loss / num_elements
|
| 50 |
+
|
| 51 |
+
def train_epoch(hl_module, train_loader, device) -> float:
    """
    Train a single epoch and return the element-weighted average loss.

    Fix: the tqdm progress bar is now always closed (previously leaked);
    dead commented-out debug prints removed.
    """
    # Set the model to training.
    hl_module.train()

    # Training loop
    train_loss = 0
    num_elements = 0

    num_batches = len(train_loader)
    pbar = tqdm.tqdm(total=num_batches)

    try:
        for batch_idx, batch in enumerate(train_loader):
            batch = to_device(batch, device)

            # Reset grad
            hl_module.reset_grad()

            # Forward pass: training_step returns (mean loss, batch size B).
            loss, B = hl_module.training_step(batch, batch_idx)

            # Backpropagation
            loss.backward(retain_graph=False)
            hl_module.backprop()

            # Save losses (detach so the graph can be freed immediately).
            loss = loss.detach()
            train_loss += (loss.item() * B)
            num_elements += B

            pbar.set_postfix(loss='%.05f'%(loss.item()) )
            pbar.update()
    finally:
        pbar.close()

    return train_loss / num_elements
|
src/utils.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import importlib
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
import librosa
|
| 7 |
+
import soundfile as sf
|
| 8 |
+
import torch
|
| 9 |
+
import torchaudio
|
| 10 |
+
import math
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class PositionalEncoding(nn.Module):
    """This class implements the absolute sinusoidal positional encoding function.

    PE(pos, 2i)   = sin(pos / 10000^(2i/dmodel))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/dmodel))

    Arguments
    ---------
    input_size: int
        Embedding dimension (must be even).
    max_len : int, optional
        Max length of the input sequences (default 2500).

    Example
    -------
    >>> a = torch.rand((8, 120, 512))
    >>> enc = PositionalEncoding(input_size=a.shape[-1])
    >>> b = enc(a)
    >>> b.shape
    torch.Size([1, 120, 512])
    """

    def __init__(self, input_size, max_len=2500):
        super().__init__()
        if input_size % 2 != 0:
            raise ValueError(f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})")
        self.max_len = max_len

        # Precompute the full (max_len, input_size) table once.
        positions = torch.arange(0, self.max_len).unsqueeze(1).float()
        # Frequencies decay geometrically: exp(-2i * ln(10000) / d).
        frequencies = torch.exp(
            torch.arange(0, input_size, 2).float() * -(math.log(10000.0) / input_size)
        )

        table = torch.zeros(self.max_len, input_size, requires_grad=False)
        table[:, 0::2] = torch.sin(positions * frequencies)
        table[:, 1::2] = torch.cos(positions * frequencies)

        # Registered as a buffer so it moves with the module but is not trained.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        """Return the encodings for the first x.size(1) positions.

        Arguments
        ---------
        x : tensor
            Input feature shape (batch, time, fea); only time is used.

        Returns a detached (1, time, fea) tensor.
        """
        return self.pe[:, : x.size(1)].clone().detach()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def count_parameters(model):
    """
    Count the number of parameters in a PyTorch model.

    Parameters:
        model (torch.nn.Module): The PyTorch model.

    Returns:
        int: Number of parameters in the model.

    Fix: the docstring promised an int but the function returned None;
    the count is now returned (callers that only relied on the printed
    summary are unaffected).
    """
    N_param = sum(p.numel() for p in model.parameters())
    print(f"Model params number {N_param/1e6} M")
    return N_param
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def import_attr(import_path):
    """Resolve a dotted path like "pkg.mod.Name" to the attribute it names."""
    module_path, _, attr_name = import_path.rpartition(".")
    target_module = importlib.import_module(module_path)
    return getattr(target_module, attr_name)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class Params:
    """Class that loads hyperparameters from a json file.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5 # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        # Each top-level JSON key becomes an instance attribute.
        with open(json_path) as src:
            self.__dict__.update(json.load(src))

    def save(self, json_path):
        """Write the current parameters back to a json file."""
        with open(json_path, "w") as dst:
            json.dump(self.__dict__, dst, indent=4)

    def update(self, json_path):
        """Loads parameters from json file"""
        with open(json_path) as src:
            self.__dict__.update(json.load(src))

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']"""
        return self.__dict__
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def load_net_torch(expriment_config, return_params=False):
    """Build the HL module described by a config file, for plain-torch use.

    Checkpoint/DataParallel related module args are forced to None/False so
    the module is constructed without loading any previous state.

    expriment_config: path to the experiment JSON config.
    return_params: when True, also return the raw config dict.
    """
    params = Params(expriment_config)
    # Disable checkpoint loading and DataParallel at construction time.
    params.pl_module_args["slow_model_ckpt"] = None
    params.pl_module_args["use_dp"] = False
    params.pl_module_args["prev_ckpt"] = None
    pl_module = import_attr(params.pl_module)(**params.pl_module_args)

    # Re-read the config as a plain dict for the optional second return value.
    with open(expriment_config) as f:
        params = json.load(f)

    if return_params:
        return pl_module, params
    else:
        return pl_module
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def load_net(expriment_config, return_params=False):
    """Build the HL module described by a config file (DataParallel disabled).

    expriment_config: path to the experiment JSON config.
    return_params: when True, also return the raw config dict.
    """
    params = Params(expriment_config)
    params.pl_module_args["use_dp"] = False
    pl_module = import_attr(params.pl_module)(**params.pl_module_args)

    # Re-read the config as a plain dict for the optional second return value.
    with open(expriment_config) as f:
        params = json.load(f)

    if return_params:
        return pl_module, params
    else:
        return pl_module
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def load_pretrained(run_dir, return_params=False, map_location="cpu", use_last=False):
    """Load a trained module from a run directory.

    Expects `run_dir/config.json` plus `run_dir/checkpoints/best.pt` (or
    `last.pt` when use_last is True).

    Raises FileNotFoundError when the checkpoint is missing.
    """
    config_path = os.path.join(run_dir, "config.json")

    pl_module, params = load_net(config_path, return_params=True)

    # Get all "best" checkpoints
    if use_last:
        name = "last.pt"
    else:
        name = "best.pt"
    ckpt_path = os.path.join(run_dir, f"checkpoints/{name}")

    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Given run ({run_dir}) doesn't have any pretrained checkpoints!")

    print("Loading checkpoint from", ckpt_path)

    # Load checkpoint
    # state_dict = torch.load(ckpt_path, map_location=map_location)['state_dict']
    pl_module.load_state(ckpt_path, map_location)
    print("Loaded module at epoch", pl_module.epoch)

    if return_params:
        return pl_module, params
    else:
        return pl_module
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def load_pretrained_with_last(run_dir, return_params=False, map_location="cpu", use_last=False):
    """Load a trained module from a run directory.

    This was a byte-for-byte duplicate of `load_pretrained`; it now delegates
    to it, so both names keep identical behavior and there is a single place
    to maintain. Kept for backward compatibility with existing callers.
    """
    return load_pretrained(run_dir, return_params=return_params,
                           map_location=map_location, use_last=use_last)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def load_pretrained2(run_dir, return_params=False, map_location="cpu"):
    """Load a trained module from `run_dir/checkpoints/best.pt`.

    Unlike `load_pretrained`, this variant does not check that the
    checkpoint exists and does not pass map_location to load_state
    (the map_location parameter is currently unused).
    """
    config_path = os.path.join(run_dir, "config.json")
    pl_module, params = load_net(config_path, return_params=True)

    ckpt_path = os.path.join(run_dir, "checkpoints", "best.pt")
    print("Loading checkpoint from", ckpt_path)

    # Load checkpoint
    # state_dict = torch.load(ckpt_path, map_location=map_location)['state_dict']
    pl_module.load_state(ckpt_path)

    if return_params:
        return pl_module, params
    else:
        return pl_module
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def load_torch_pretrained(run_dir, return_params=False, map_location="cpu", model_epoch="best"):
    """Load a plain-torch module (via `load_net_torch`) from a run directory.

    model_epoch selects the checkpoint file stem, e.g. "best" or "last".
    Raises FileNotFoundError when `run_dir/checkpoints/<model_epoch>.pt`
    is missing.
    """
    config_path = os.path.join(run_dir, "config.json")

    print(config_path)
    pl_module, params = load_net_torch(config_path, return_params=True)

    # Get all "best" checkpoints
    ckpt_path = os.path.join(run_dir, f"checkpoints/{model_epoch}.pt")

    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Given run ({run_dir}) doesn't have any pretrained checkpoints!")

    print("Loading checkpoint from", ckpt_path)

    # Load checkpoint
    # state_dict = torch.load(ckpt_path, map_location=map_location)['state_dict']
    pl_module.load_state(ckpt_path, map_location)
    print("Loaded module at epoch", pl_module.epoch)

    if return_params:
        return pl_module, params
    else:
        return pl_module
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def read_audio_file(file_path, sr):
    """
    Reads audio file to system memory.

    file_path: path to the audio file.
    sr: target sampling rate (librosa resamples on load).
    Returns the waveform as a numpy array; channels are preserved
    (mono=False), so multi-channel files come back as (channels, samples).
    """
    return librosa.core.load(file_path, mono=False, sr=sr)[0]
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def read_audio_file_torch(file_path, downsample=1, input_mean=False):
    """Load an audio file as a torch tensor, with optional downsampling
    and channel reduction.

    input_mean is a tri-state switch for multi-channel files:
      True -> average channels to mono; "L" -> keep first channel only;
      "R" -> keep second channel only; False -> keep all channels.
    Returns a (channels, samples) tensor.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    if downsample > 1:
        # Resample to sample_rate // downsample (integer division).
        waveform = torchaudio.functional.resample(waveform, sample_rate, sample_rate // downsample)

    # `== True` (not truthiness) is deliberate: "L"/"R" are also truthy
    # and must fall through to the elif branches below.
    if waveform.shape[0] > 1 and input_mean == True:
        waveform = torch.mean(waveform, dim=0)
        waveform = waveform.unsqueeze(0)

    elif waveform.shape[0] > 1 and input_mean == "L":
        waveform = waveform[0:1, ...]

    elif waveform.shape[0] > 1 and input_mean == "R":
        waveform = waveform[1:2, ...]

    return waveform
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def write_audio_file(file_path, data, sr, subtype="PCM_16"):
    """
    Writes audio file to system memory.
    @param file_path: Path of the file to write to
    @param data: Audio signal to write (n_channels x n_samples)
    @param sr: Sampling rate
    @param subtype: soundfile subtype, e.g. "PCM_16" (default)
    """
    # soundfile expects (n_samples, n_channels), hence the transpose.
    sf.write(file_path, data.T, sr, subtype)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def read_json(path):
    """Parse the JSON file at `path` and return the decoded object."""
    with open(path, "rb") as fp:
        raw = fp.read()
    return json.loads(raw)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
import random
|
| 276 |
+
import numpy as np
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def seed_all(seed):
    """Seed Python's `random`, NumPy, and PyTorch RNGs for reproducibility.

    Parameters:
        seed (int): the seed applied to every RNG.

    Fix: uses torch.cuda.manual_seed_all so ALL visible GPUs are seeded
    (the original manual_seed only seeded the current device, leaving
    multi-GPU runs non-reproducible).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|