Spaces:
Runtime error
Runtime error
Base code
Browse files- app.py +62 -0
- default_config.json +60 -0
- requirements.txt +9 -0
- src/__init__.py +0 -0
- src/helpers/__init__.py +0 -0
- src/helpers/utils.py +205 -0
- src/training/__init__.py +0 -0
- src/training/dcc_tf.py +486 -0
- src/training/eval.py +214 -0
- src/training/synthetic_dataset.py +168 -0
- src/training/train.py +311 -0
app.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import wget
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio
|
| 7 |
+
import gradio as gr
|
| 8 |
+
|
| 9 |
+
from src.helpers import utils
|
| 10 |
+
from src.training.dcc_tf import Net as Waveformer
|
| 11 |
+
|
| 12 |
+
TARGETS = [
|
| 13 |
+
"Acoustic_guitar", "Applause", "Bark", "Bass_drum",
|
| 14 |
+
"Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
|
| 15 |
+
"Computer_keyboard", "Cough", "Cowbell", "Double_bass",
|
| 16 |
+
"Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
|
| 17 |
+
"Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
|
| 18 |
+
"Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
|
| 19 |
+
"Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
|
| 20 |
+
"Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
|
| 21 |
+
"Trumpet", "Violin_or_fiddle", "Writing"
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
# Fetch the pretrained assets on first launch. Anything already present in the
# working directory is reused as-is.
_ASSETS = (
    ('default_config.json',
     'https://targetsound.cs.washington.edu/files/default_config.json',
     "Downloading model configuration from %s:"),
    ('default_ckpt.pt',
     'https://targetsound.cs.washington.edu/files/default_ckpt.pt',
     "\nDownloading the checkpoint from %s:"),
)
for _local_name, _url, _banner in _ASSETS:
    if os.path.exists(_local_name):
        continue
    print(_banner % _url)
    wget.download(_url)

# Build the network from the JSON hyper-parameters, restore the released
# weights and switch to inference mode.
params = utils.Params('default_config.json')
model = Waveformer(**params.model_params)
utils.load_checkpoint('default_ckpt.pt', model)
model.eval()
|
| 39 |
+
|
| 40 |
+
def waveformer(audio, label_choices):
    """Run target-sound extraction on one clip coming from the Gradio UI.

    Args:
        audio: ``(sample_rate, samples)`` tuple as delivered by the Gradio
            ``audio`` input; ``samples`` is a numpy array.
        label_choices: class names (entries of ``TARGETS``) ticked in the
            checkbox group.

    Returns:
        ``(sample_rate, samples)`` where ``samples`` is the model output
        flattened to a 1-D numpy array.

    Raises:
        ValueError: if the sample rate is not 44100 Hz, if no label is
            selected, or if a label is not in ``TARGETS``.
    """
    # Read input audio
    fs, mixture = audio
    if fs != 44100:
        raise ValueError("Sampling rate must be 44100 Hz, but got %d" % fs)

    mixture = torch.from_numpy(mixture)
    if mixture.is_floating_point():
        mixture = mixture.to(torch.float32)
    else:
        # Integer PCM (Gradio usually hands over int16 -- TODO confirm for the
        # deployed version): rescale to [-1, 1) float so it matches the dtype
        # of the model weights and the range of its Tanh output.
        full_scale = float(torch.iinfo(mixture.dtype).max) + 1.0
        mixture = mixture.to(torch.float32) / full_scale
    if mixture.dim() == 2:
        # assumes (samples, channels) layout for multi-channel clips -- TODO
        # confirm; down-mix to mono because the input conv has one channel.
        mixture = mixture.mean(dim=1)
    # [B=1, C=1, T]: the network's first layer is Conv1d(in_channels=1).
    mixture = mixture.reshape(1, 1, -1)

    # Construct the multi-hot query vector
    if not label_choices:
        raise ValueError("Select at least one target sound")
    query = torch.zeros(1, len(TARGETS))
    for t in label_choices:
        query[0, TARGETS.index(t)] = 1.

    with torch.no_grad():
        output = model(mixture, query)

    # NOTE(review): presumably the model returns a single mono waveform for
    # this single-item batch; flatten it so Gradio receives a 1-D signal.
    return fs, output.reshape(-1).numpy()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
label_checkbox = gr.CheckboxGroup(choices=TARGETS)
|
| 61 |
+
demo = gr.Interface(fn=waveformer, inputs=['audio', label_checkbox], outputs="audio")
|
| 62 |
+
demo.launch()
|
default_config.json
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model": "src.training.dcc_tf",
|
| 3 |
+
"model_params":
|
| 4 |
+
{
|
| 5 |
+
"label_len": 41,
|
| 6 |
+
"L": 32,
|
| 7 |
+
"enc_dim": 512,
|
| 8 |
+
"num_enc_layers": 10,
|
| 9 |
+
"dec_dim": 256,
|
| 10 |
+
"num_dec_layers": 1,
|
| 11 |
+
"dec_buf_len": 13,
|
| 12 |
+
"dec_chunk_size": 13,
|
| 13 |
+
"out_buf_len": 4,
|
| 14 |
+
"use_pos_enc": "true"
|
| 15 |
+
},
|
| 16 |
+
"train_data":
|
| 17 |
+
{
|
| 18 |
+
"input_dir": "data/FSDSoundScapes",
|
| 19 |
+
"dset": "train",
|
| 20 |
+
"sr": 44100,
|
| 21 |
+
"resample_rate": null,
|
| 22 |
+
"max_num_targets":3
|
| 23 |
+
},
|
| 24 |
+
"val_data":
|
| 25 |
+
{
|
| 26 |
+
"input_dir": "data/FSDSoundScapes",
|
| 27 |
+
"dset": "val",
|
| 28 |
+
"sr": 44100,
|
| 29 |
+
"resample_rate": null,
|
| 30 |
+
"max_num_targets":3
|
| 31 |
+
},
|
| 32 |
+
"test_data":
|
| 33 |
+
{
|
| 34 |
+
"input_dir": "data/FSDSoundScapes",
|
| 35 |
+
"dset": "test",
|
| 36 |
+
"sr": 44100,
|
| 37 |
+
"resample_rate": null,
|
| 38 |
+
"max_num_targets":3
|
| 39 |
+
},
|
| 40 |
+
"optim":
|
| 41 |
+
{
|
| 42 |
+
"lr": 5e-4,
|
| 43 |
+
"weight_decay": 0.0
|
| 44 |
+
},
|
| 45 |
+
"lr_sched":
|
| 46 |
+
{
|
| 47 |
+
"mode": "max",
|
| 48 |
+
"factor": 0.1,
|
| 49 |
+
"patience": 5,
|
| 50 |
+
"min_lr": 5e-6,
|
| 51 |
+
"threshold": 0.1,
|
| 52 |
+
"threshold_mode": "abs"
|
| 53 |
+
},
|
| 54 |
+
"base_metric": "scale_invariant_signal_noise_ratio",
|
| 55 |
+
"fix_lr_epochs": 50,
|
| 56 |
+
"epochs": 150,
|
| 57 |
+
"batch_size": 16,
|
| 58 |
+
"eval_batch_size": 64,
|
| 59 |
+
"n_workers": 16
|
| 60 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### Requirements
|
| 2 |
+
librosa
|
| 3 |
+
torch
|
| 4 |
+
torchaudio
|
| 5 |
+
soundfile
|
| 6 |
+
numpy
|
| 7 |
+
speechbrain
|
| 8 |
+
wget
|
| 9 |
+
|
src/__init__.py
ADDED
|
File without changes
|
src/helpers/__init__.py
ADDED
|
File without changes
|
src/helpers/utils.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A collection of useful helper functions"""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import logging
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.profiler import profile, record_function, ProfilerActivity
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from torchmetrics.functional import(
|
| 11 |
+
scale_invariant_signal_noise_ratio as si_snr,
|
| 12 |
+
signal_noise_ratio as snr,
|
| 13 |
+
signal_distortion_ratio as sdr,
|
| 14 |
+
scale_invariant_signal_distortion_ratio as si_sdr)
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
|
| 17 |
+
class Params():
    """Expose the top-level entries of a JSON file as instance attributes.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # overrides the loaded value in memory
    ```
    """

    def __init__(self, json_path):
        with open(json_path) as json_file:
            loaded = json.load(json_file)
        vars(self).update(loaded)

    def save(self, json_path):
        """Write all current attributes to `json_path` as indented JSON."""
        with open(json_path, 'w') as json_file:
            json.dump(vars(self), json_file, indent=4)

    def update(self, json_path):
        """Merge the entries of another JSON file into this instance.

        Keys that already exist are overwritten; other attributes are kept.
        """
        with open(json_path) as json_file:
            loaded = json.load(json_file)
        vars(self).update(loaded)

    @property
    def dict(self):
        """Dict-like access to the attributes, e.g. `params.dict['learning_rate']`."""
        return vars(self)
|
| 46 |
+
|
| 47 |
+
def save_graph(train_metrics, test_metrics, save_dir):
    """Write the training curves to `save_dir` as a CSV table and a PNG figure.

    Args:
        train_metrics: mapping with a 'loss' entry plus one entry per tracked
            metric (keyed by the metric function's `__name__`); each value is
            a per-epoch sequence.
        test_metrics: same layout as `train_metrics`, for the held-out split.
        save_dir: directory receiving `results.csv` and `results.png`.
    """
    tracked = [snr, si_snr]

    # One column per (split, quantity) pair, loss first.
    table = {'train_loss': train_metrics['loss'],
             'test_loss': test_metrics['loss']}
    for metric in tracked:
        key = metric.__name__
        table["train_" + key] = train_metrics[key]
        table["test_" + key] = test_metrics[key]

    pd.DataFrame(table).to_csv(os.path.join(save_dir, 'results.csv'))

    # 2x3 grid flattened row by row; only the first len(tracked) + 1 axes are
    # drawn on.
    fig, grid = plt.subplots(2, 3, figsize=(15, 10))
    axs = [ax for row in grid for ax in row]

    epochs = range(len(train_metrics['loss']))

    loss_ax = axs[0]
    loss_ax.plot(epochs, train_metrics['loss'], label='train')
    loss_ax.plot(epochs, test_metrics['loss'], label='test')
    loss_ax.set(ylabel='Loss')
    loss_ax.set(xlabel='Epoch')
    loss_ax.set_title('loss', fontweight='bold')
    loss_ax.legend()

    for ax, metric in zip(axs[1:], tracked):
        key = metric.__name__
        ax.plot(epochs, train_metrics[key], label='train')
        ax.plot(epochs, test_metrics[key], label='test')
        ax.set(xlabel='Epoch')
        ax.set_title(key, fontweight='bold')
        ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'results.png'))
    plt.close(fig)
|
| 84 |
+
|
| 85 |
+
def set_logger(log_path):
    """Route root-logger output to both the terminal and the file `log_path`.

    Any handlers previously attached to the root logger are dropped and its
    level is set to INFO. File records carry a timestamp and level; console
    records carry the bare message.

    Example:
    ```
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) where to log
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    root.handlers.clear()

    # File first, console second -- same attachment order as before.
    to_file = logging.FileHandler(log_path)
    to_file.setFormatter(
        logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
    root.addHandler(to_file)

    to_console = logging.StreamHandler()
    to_console.setFormatter(logging.Formatter('%(message)s'))
    root.addHandler(to_console)
|
| 109 |
+
|
| 110 |
+
def load_checkpoint(checkpoint, model, optim=None, lr_sched=None, data_parallel=False):
    """Restore model (and optionally optimizer / scheduler) state from a file.

    Args:
        checkpoint: (string) path of the checkpoint file to read
        model: (torch.nn.Module) model receiving the stored parameters
        optim: optional optimizer; restored from 'optim_state_dict' when given
        lr_sched: optional LR scheduler; restored from 'lr_sched_state_dict'
            when given
        data_parallel: (bool) set when `model` is wrapped so that its
            parameter names carry a 'module.' prefix; the stored (unprefixed)
            keys are prefixed before loading

    Returns:
        Tuple `(epoch, train_metrics, val_metrics)` as stored in the file.

    Raises:
        FileNotFoundError: if `checkpoint` does not exist.
    """
    if not os.path.exists(checkpoint):
        # Raising a plain string is itself a TypeError in Python 3.
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))

    # Deserialize onto the CPU so checkpoints written on a GPU machine load on
    # CPU-only hosts; load_state_dict copies values into the existing
    # parameters, whatever device those live on.
    state_dict = torch.load(checkpoint, map_location='cpu')

    model_state = state_dict['model_state_dict']
    if data_parallel:
        model_state = {'module.' + k: v for k, v in model_state.items()}
    model.load_state_dict(model_state)

    if optim is not None:
        optim.load_state_dict(state_dict['optim_state_dict'])

    if lr_sched is not None:
        lr_sched.load_state_dict(state_dict['lr_sched_state_dict'])

    return state_dict['epoch'], state_dict['train_metrics'], \
        state_dict['val_metrics']
|
| 137 |
+
|
| 138 |
+
def save_checkpoint(checkpoint, epoch, model, optim=None, lr_sched=None,
                    train_metrics=None, val_metrics=None, data_parallel=False):
    """Write model (and optionally optimizer / scheduler) state to a new file.

    Args:
        checkpoint: (string) path of the checkpoint file to create
        epoch: epoch counter stored alongside the weights
        model: (torch.nn.Module) model whose parameters are saved
        optim: optional optimizer whose state is saved (None is stored if
            omitted)
        lr_sched: optional LR scheduler whose state is saved (None is stored
            if omitted)
        train_metrics: training metrics stored under 'train_metrics'
        val_metrics: validation metrics stored under 'val_metrics'
        data_parallel: (bool) set when `model` is wrapped so that its
            parameter names carry a 'module.' prefix; the prefix is removed
            before saving

    Raises:
        FileExistsError: if `checkpoint` already exists (never overwritten).
    """
    if os.path.exists(checkpoint):
        # Raising a plain string is itself a TypeError in Python 3.
        raise FileExistsError("File already exists {}".format(checkpoint))

    model_state_dict = model.state_dict()
    if data_parallel:
        # Strip only a *leading* 'module.'. Keys without the prefix are kept
        # as they are instead of all collapsing to the empty string.
        prefix = 'module.'
        model_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in model_state_dict.items()}

    optim_state_dict = None if optim is None else optim.state_dict()
    lr_sched_state_dict = None if lr_sched is None else lr_sched.state_dict()

    state_dict = {
        'epoch': epoch,
        'model_state_dict': model_state_dict,
        'optim_state_dict': optim_state_dict,
        'lr_sched_state_dict': lr_sched_state_dict,
        'train_metrics': train_metrics,
        'val_metrics': val_metrics
    }

    torch.save(state_dict, checkpoint)
|
| 169 |
+
|
| 170 |
+
def model_size(model):
    """
    Returns the number of trainable parameters of `model`, in millions
    (count / 1e6). Parameters with `requires_grad` unset are not counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable / 1e6
|
| 177 |
+
|
| 178 |
+
def run_time(model, inputs, profiling=False):
    """
    Profiles a single call `model(*inputs)` on the CPU after 100 warm-up
    calls.

    Args:
        model: callable to measure.
        inputs: positional arguments, unpacked into every call.
        profiling: when true, also print the 20 most expensive profiler rows
            sorted by self CPU time.

    Returns:
        The profiler's total self CPU time divided by 1000 (the original
        author reports this value as milliseconds).
    """
    # Warmup
    warmup_calls = 100
    for _call in range(warmup_calls):
        model(*inputs)

    cpu_only = [ProfilerActivity.CPU]
    with profile(activities=cpu_only, record_shapes=True) as prof:
        with record_function("model_inference"):
            model(*inputs)

    # Print profiling results
    if profiling:
        report = prof.key_averages().table(
            sort_by="self_cpu_time_total", row_limit=20)
        print(report)

    # Return runtime in ms
    total_self_cpu = prof.profiler.self_cpu_time_total
    return total_self_cpu / 1000
|
| 198 |
+
|
| 199 |
+
def format_lr_info(optimizer):
    """Summarize every parameter group of `optimizer` in one string.

    Each group contributes ` {group <idx>: params=<count>M lr=<lr>}`, where
    the count is the group's number of elements divided by 1024 ** 2.
    """
    pieces = []
    for idx, group in enumerate(optimizer.param_groups):
        n_elements = sum(p.numel() for p in group['params'])
        pieces.append(" {group %d: params=%.5fM lr=%.1E}" % (
            idx, n_elements / (1024 ** 2), group['lr']))
    return "".join(pieces)
|
| 205 |
+
|
src/training/__init__.py
ADDED
|
File without changes
|
src/training/dcc_tf.py
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.optim as optim
|
| 10 |
+
from torchmetrics.functional import(
|
| 11 |
+
scale_invariant_signal_noise_ratio as si_snr,
|
| 12 |
+
signal_noise_ratio as snr,
|
| 13 |
+
signal_distortion_ratio as sdr,
|
| 14 |
+
scale_invariant_signal_distortion_ratio as si_sdr)
|
| 15 |
+
|
| 16 |
+
from speechbrain.lobes.models.transformer.Transformer import PositionalEncoding
|
| 17 |
+
|
| 18 |
+
def mod_pad(x, chunk_size, pad):
    """Zero-pad the last axis of `x` to a multiple of `chunk_size`.

    Args:
        x: tensor padded along its last dimension.
        chunk_size: the padded length becomes a multiple of this value.
        pad: extra `F.pad` specification applied after the right-padding.

    Returns:
        `(padded, mod)` where `mod` is the number of zeros appended on the
        right to reach the multiple (0 if the length already was one).
    """
    remainder = x.shape[-1] % chunk_size
    mod = chunk_size - remainder if remainder != 0 else 0

    padded = F.pad(F.pad(x, (0, mod)), pad)
    return padded, mod
|
| 29 |
+
|
| 30 |
+
class LayerNormPermuted(nn.LayerNorm):
    """LayerNorm for channel-first sequences: normalizes [B, C, T] over C."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x: [B, C, T]
        Returns:
            [B, C, T], normalized over the channel axis.
        """
        # nn.LayerNorm works on the trailing axis, so go through [B, T, C].
        channels_last = x.permute(0, 2, 1)
        normalized = super().forward(channels_last)
        return normalized.permute(0, 2, 1)
|
| 43 |
+
|
| 44 |
+
class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise separable 1-D convolution: a per-channel (grouped) convolution
    followed by a 1x1 point-wise convolution, each followed by a channel
    LayerNorm and a ReLU.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation):
        super().__init__()

        # Sub-modules are created in the same order as they run, so parameter
        # names (layers.0 ... layers.5) and initialization order are unchanged.
        depthwise = nn.Conv1d(in_channels, in_channels, kernel_size, stride,
                              padding, groups=in_channels, dilation=dilation)
        depthwise_norm = LayerNormPermuted(in_channels)
        depthwise_act = nn.ReLU()
        pointwise = nn.Conv1d(in_channels, out_channels, kernel_size=1,
                              stride=1, padding=0)
        pointwise_norm = LayerNormPermuted(out_channels)
        pointwise_act = nn.ReLU()

        self.layers = nn.Sequential(
            depthwise, depthwise_norm, depthwise_act,
            pointwise, pointwise_norm, pointwise_act,
        )

    def forward(self, x):
        """Apply the depthwise and point-wise stages to `x` ([B, C, T])."""
        return self.layers(x)
|
| 65 |
+
|
| 66 |
+
class DilatedCausalConvEncoder(nn.Module):
    """
    A dilated causal convolution based encoder for encoding
    time domain audio input into latent space.

    Layer `i` uses dilation 2**i and no padding. Causality comes from
    prepending `(kernel_size - 1) * 2**i` frames of past context, kept in a
    single flat context buffer, to the layer input.
    """
    def __init__(self, channels, num_layers, kernel_size=3):
        super(DilatedCausalConvEncoder, self).__init__()
        self.channels = channels
        self.num_layers = num_layers
        self.kernel_size = kernel_size

        # Compute buffer lengths for each layer
        # buf_length[i] = (kernel_size - 1) * dilation[i]
        self.buf_lengths = [(kernel_size - 1) * 2**i
                            for i in range(num_layers)]

        # Compute buffer start indices for each layer (offsets into the flat
        # context buffer, i.e. running sum of the buffer lengths)
        self.buf_indices = [0]
        for i in range(num_layers - 1):
            self.buf_indices.append(
                self.buf_indices[-1] + self.buf_lengths[i])

        # Dilated causal conv layers aggregate previous context to obtain
        # contextful encoded input.
        _dcc_layers = OrderedDict()
        for i in range(num_layers):
            # Use the configured kernel size: the context prepended in
            # forward() is sized from it, so a hard-coded 3 would make the
            # residual addition fail for any other value.
            dcc_layer = DepthwiseSeparableConv(
                channels, channels, kernel_size=kernel_size, stride=1,
                padding=0, dilation=2**i)
            _dcc_layers.update({'dcc_%d' % i: dcc_layer})
        self.dcc_layers = nn.Sequential(_dcc_layers)

    def init_ctx_buf(self, batch_size, device):
        """
        Returns a zero-initialized context buffer for a given batch size, of
        shape [batch_size, channels, (kernel_size - 1) * (2**num_layers - 1)]
        (the sum of all per-layer buffer lengths).
        """
        return torch.zeros(
            (batch_size, self.channels,
             (self.kernel_size - 1) * (2**self.num_layers - 1)),
            device=device)

    def forward(self, x, ctx_buf):
        """
        Encodes input `x` and aggregates contextual information from
        `ctx_buf`.

        Args:
            x: [B, channels, T]
                Input sequence.
            ctx_buf: [B, channels, sum(self.buf_lengths)]
                Flat context buffer as returned by `init_ctx_buf`; layer `i`
                owns the slice starting at `self.buf_indices[i]` with length
                `self.buf_lengths[i]`.
        Returns:
            x: [B, channels, T]
                Encoded sequence.
            ctx_buf:
                The same tensor that was passed in, updated IN PLACE with
                the most recent input frames of every layer.
        """
        for i in range(self.num_layers):
            buf_start_idx = self.buf_indices[i]
            buf_end_idx = buf_start_idx + self.buf_lengths[i]

            # DCC input: concatenation of stored context and current input
            dcc_in = torch.cat(
                (ctx_buf[..., buf_start_idx:buf_end_idx], x), dim=-1)

            # Keep the trailing frames as context for the next call
            ctx_buf[..., buf_start_idx:buf_end_idx] = \
                dcc_in[..., -self.buf_lengths[i]:]

            # Residual connection
            x = x + self.dcc_layers[i](dcc_in)

        return x, ctx_buf
|
| 141 |
+
|
| 142 |
+
class CausalTransformerDecoderLayer(torch.nn.TransformerDecoderLayer):
    """
    Decoder layer that only computes outputs for the last `chunk_size`
    positions of the target sequence.

    Adapted from:
    "https://github.com/alexmt-scale/causal-transformer-decoder/blob/"
    "0caf6ad71c46488f76d89845b0123d2550ef792f/"
    "causal_transformer_decoder/model.py#L77"
    """
    def forward(
        self,
        tgt: Tensor,
        memory: Optional[Tensor] = None,
        chunk_size: int = 1
    ):
        """
        Args:
            tgt: target sequence; the last `chunk_size` positions along
                dim 1 are the queries, the whole sequence provides keys and
                values for self-attention.
            memory: optional sequence for encoder-decoder attention; that
                stage is skipped when it is None.
            chunk_size: number of trailing target positions to decode.

        Returns:
            Tuple `(out, sa_map, ca_map)`: the decoded trailing positions,
            the self-attention weights and the cross-attention weights
            (`ca_map` is None when `memory` is None).
        """
        tgt_last_tok = tgt[:, -chunk_size:, :]

        # self attention part
        tmp_tgt, sa_map = self.self_attn(
            tgt_last_tok,
            tgt,
            tgt,
            attn_mask=None,  # not needed because we only care about the last token
            key_padding_mask=None,
        )
        tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt)
        tgt_last_tok = self.norm1(tgt_last_tok)

        # encoder-decoder attention
        # Defined up front so the return below is valid when the
        # cross-attention stage is skipped (it used to be unbound then).
        ca_map = None
        if memory is not None:
            tmp_tgt, ca_map = self.multihead_attn(
                tgt_last_tok,
                memory,
                memory,
                attn_mask=None,  # Attend to the entire chunk
                key_padding_mask=None,
            )
            tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt)
            tgt_last_tok = self.norm2(tgt_last_tok)

        # final feed-forward network
        tmp_tgt = self.linear2(
            self.dropout(self.activation(self.linear1(tgt_last_tok)))
        )
        tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt)
        tgt_last_tok = self.norm3(tgt_last_tok)
        return tgt_last_tok, sa_map, ca_map
|
| 187 |
+
|
| 188 |
+
class CausalTransformerDecoder(nn.Module):
    """
    A causal transformer decoder which decodes input vectors using
    precisely `ctx_len` past vectors in the sequence, and using no future
    vectors at all.
    """
    def __init__(self, model_dim, ctx_len, chunk_size, num_layers,
                 nhead, use_pos_enc, ff_dim):
        super(CausalTransformerDecoder, self).__init__()
        self.num_layers = num_layers
        self.model_dim = model_dim
        self.ctx_len = ctx_len
        self.chunk_size = chunk_size
        self.nhead = nhead
        self.use_pos_enc = use_pos_enc
        # Sliding window of ctx_len + chunk_size frames, advancing by one
        # chunk at a time.
        self.unfold = nn.Unfold(kernel_size=(ctx_len + chunk_size, 1), stride=chunk_size)
        self.pos_enc = PositionalEncoding(model_dim, max_len=200)
        self.tf_dec_layers = nn.ModuleList([CausalTransformerDecoderLayer(
            d_model=model_dim, nhead=nhead, dim_feedforward=ff_dim,
            batch_first=True) for _ in range(num_layers)])

    def init_ctx_buf(self, batch_size, device):
        """
        Returns a zero context buffer of shape
        [batch_size, num_layers + 1, ctx_len, model_dim]: slot 0 holds the
        context of `mem`, slot i + 1 the input context of decoder layer i.
        """
        return torch.zeros(
            (batch_size, self.num_layers + 1, self.ctx_len, self.model_dim),
            device=device)

    def _causal_unfold(self, x):
        """
        Unfolds the sequence into a batch of windows, one per chunk, each
        prepended with the `ctx_len` previous values.

        Args:
            x: [B, ctx_len + L, C]
        Returns:
            [B * L / chunk_size, ctx_len + chunk_size, C]
        """
        B, T, C = x.shape
        x = x.permute(0, 2, 1)  # [B, C, ctx_len + L]
        x = self.unfold(x.unsqueeze(-1))  # [B, C * (ctx_len + chunk_size), -1]
        x = x.permute(0, 2, 1)
        x = x.reshape(B, -1, C, self.ctx_len + self.chunk_size)
        x = x.reshape(-1, C, self.ctx_len + self.chunk_size)
        x = x.permute(0, 2, 1)
        return x

    def forward(self, tgt, mem, ctx_buf, probe=False):
        """
        Args:
            tgt: [B, model_dim, T]
                Sequence to decode.
            mem: [B, model_dim, T]
                Sequence attended to by the cross-attention stage.
            ctx_buf: [B, num_layers + 1, ctx_len, model_dim]
                Context buffer from `init_ctx_buf`; updated IN PLACE.
            probe: unused in this method.
        Returns:
            tgt: [B, model_dim, T]
                Decoded sequence (padding added to reach a multiple of
                `chunk_size` is removed again).
            ctx_buf:
                The same buffer tensor that was passed in.
        """
        mem, _ = mod_pad(mem, self.chunk_size, (0, 0))
        tgt, mod = mod_pad(tgt, self.chunk_size, (0, 0))

        # Input sequence length
        B, C, T = tgt.shape

        tgt = tgt.permute(0, 2, 1)
        mem = mem.permute(0, 2, 1)

        # Prepend mem with the context
        mem = torch.cat((ctx_buf[:, 0, :, :], mem), dim=1)
        ctx_buf[:, 0, :, :] = mem[:, -self.ctx_len:, :]
        mem_ctx = self._causal_unfold(mem)
        if self.use_pos_enc:
            mem_ctx = mem_ctx + self.pos_enc(mem_ctx)

        # Attention chunk size: required to ensure the model
        # wouldn't trigger an out-of-memory error when working
        # on long sequences.
        K = 1000

        for layer_idx, tf_dec_layer in enumerate(self.tf_dec_layers):
            # Update the tgt with context
            tgt = torch.cat((ctx_buf[:, layer_idx + 1, :, :], tgt), dim=1)
            ctx_buf[:, layer_idx + 1, :, :] = tgt[:, -self.ctx_len:, :]

            # Compute encoded output
            tgt_ctx = self._causal_unfold(tgt)
            if self.use_pos_enc and layer_idx == 0:
                tgt_ctx = tgt_ctx + self.pos_enc(tgt_ctx)
            tgt = torch.zeros_like(tgt_ctx)[:, -self.chunk_size:, :]
            # Run the layer on at most K windows at a time. A dedicated index
            # is used so the layer index above is not shadowed.
            for start in range(0, tgt.shape[0], K):
                stop = start + K
                tgt[start:stop], _sa_map, _ca_map = tf_dec_layer(
                    tgt_ctx[start:stop], mem_ctx[start:stop],
                    self.chunk_size)
            tgt = tgt.reshape(B, T, C)

        tgt = tgt.permute(0, 2, 1)
        if mod != 0:
            tgt = tgt[..., :-mod]

        return tgt, ctx_buf
|
| 282 |
+
|
| 283 |
+
class MaskNet(nn.Module):
    """Mask generator: dilated causal conv encoder + causal transformer decoder."""

    def __init__(self, enc_dim, num_enc_layers, dec_dim, dec_buf_len,
                 dec_chunk_size, num_dec_layers, use_pos_enc, skip_connection, proj):
        super().__init__()
        self.skip_connection = skip_connection
        self.proj = proj

        # Encoder based on dilated causal convolutions.
        self.encoder = DilatedCausalConvEncoder(channels=enc_dim,
                                                num_layers=num_enc_layers)

        # Grouped 1x1 projections between encoder and decoder dimensions:
        # two towards the decoder (encoded input, label-conditioned input)
        # and one back to the encoder width.
        self.proj_e2d_e = nn.Sequential(
            nn.Conv1d(enc_dim, dec_dim, kernel_size=1, stride=1, padding=0,
                      groups=dec_dim),
            nn.ReLU())
        self.proj_e2d_l = nn.Sequential(
            nn.Conv1d(enc_dim, dec_dim, kernel_size=1, stride=1, padding=0,
                      groups=dec_dim),
            nn.ReLU())
        self.proj_d2e = nn.Sequential(
            nn.Conv1d(dec_dim, enc_dim, kernel_size=1, stride=1, padding=0,
                      groups=dec_dim),
            nn.ReLU())

        # Transformer decoder that operates on chunks of size
        # buffer size.
        self.decoder = CausalTransformerDecoder(
            model_dim=dec_dim, ctx_len=dec_buf_len, chunk_size=dec_chunk_size,
            num_layers=num_dec_layers, nhead=8, use_pos_enc=use_pos_enc,
            ff_dim=2 * dec_dim)

    def forward(self, x, l, enc_buf, dec_buf):
        """
        Generates a mask from the input sequence `x` and the label
        embedding `l`.

        Args:
            x: [B, C, T]
                Input sequence fed to the encoder
            l: [B, C]
                Label embedding
            enc_buf:
                Context buffer maintained by the DCC encoder
            dec_buf:
                Context buffer maintained by the transformer decoder
        Returns:
            Tuple `(m, enc_buf, dec_buf)`: the mask and both context buffers
            as returned by the encoder / decoder.
        """
        # Encode the input
        e, enc_buf = self.encoder(x, enc_buf)

        # Label integration: scale every encoded frame by the label embedding
        l = l.unsqueeze(2) * e

        if self.proj:
            # Decode in `dec_dim` dimensions, then project the mask back to
            # encoder dimensions.
            e = self.proj_e2d_e(e)
            m, dec_buf = self.decoder(self.proj_e2d_l(l), e, dec_buf)
            m = self.proj_d2e(m)
        else:
            # Cross-attention directly in encoder dimensions
            m, dec_buf = self.decoder(l, e, dec_buf)

        # Final mask after residual connection
        if self.skip_connection:
            m = l + m

        return m, enc_buf, dec_buf
|
| 353 |
+
|
| 354 |
+
class Net(nn.Module):
    """
    Label-conditioned extraction network: a strided input convolution,
    a label-conditioned mask generator (`MaskNet`) and a transposed
    output convolution. Supports streaming through explicit encoder,
    decoder and output buffers.
    """
    def __init__(self, label_len, L=8,
                 enc_dim=512, num_enc_layers=10,
                 dec_dim=256, dec_buf_len=100, num_dec_layers=2,
                 dec_chunk_size=72, out_buf_len=2,
                 use_pos_enc=True, skip_connection=True, proj=True, lookahead=True):
        super(Net, self).__init__()
        self.L = L
        self.out_buf_len = out_buf_len
        self.enc_dim = enc_dim
        self.lookahead = lookahead

        # Input conv to convert input audio to a latent representation.
        # With lookahead the kernel spans three strides instead of one.
        kernel_size = 3 * L if lookahead else L
        self.in_conv = nn.Sequential(
            nn.Conv1d(in_channels=1,
                      out_channels=enc_dim, kernel_size=kernel_size, stride=L,
                      padding=0, bias=False),
            nn.ReLU())

        # Label embedding layer
        self.label_embedding = nn.Sequential(
            nn.Linear(label_len, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, enc_dim),
            nn.LayerNorm(enc_dim),
            nn.ReLU())

        # Mask generator
        self.mask_gen = MaskNet(
            enc_dim=enc_dim, num_enc_layers=num_enc_layers,
            dec_dim=dec_dim, dec_buf_len=dec_buf_len,
            dec_chunk_size=dec_chunk_size, num_dec_layers=num_dec_layers,
            use_pos_enc=use_pos_enc, skip_connection=skip_connection, proj=proj)

        # Output conv layer. The padding trims the samples produced by the
        # `out_buf_len` buffered frames that are prepended in `forward`.
        self.out_conv = nn.Sequential(
            nn.ConvTranspose1d(
                in_channels=enc_dim, out_channels=1,
                kernel_size=(out_buf_len + 1) * L,
                stride=L,
                padding=out_buf_len * L, bias=False),
            nn.Tanh())

    def init_buffers(self, batch_size, device):
        """
        Creates fresh (enc_buf, dec_buf, out_buf) streaming buffers.
        `out_buf` is a zero tensor of shape
        [batch_size, enc_dim, out_buf_len].
        """
        enc_buf = self.mask_gen.encoder.init_ctx_buf(batch_size, device)
        dec_buf = self.mask_gen.decoder.init_ctx_buf(batch_size, device)
        out_buf = torch.zeros(batch_size, self.enc_dim, self.out_buf_len,
                              device=device)
        return enc_buf, dec_buf, out_buf

    def forward(self, x, label, init_enc_buf=None, init_dec_buf=None,
                init_out_buf=None, pad=True):
        """
        Extracts the audio corresponding to the `label` from the
        mixture `x`.

        Args:
            x: [B, 1, T]
                input audio mixture (the input conv has one channel)
            label: [B, label_len]
                label vector
            init_enc_buf, init_dec_buf, init_out_buf:
                streaming buffers from a previous call; pass all three
                or none of them
            pad: whether to pad `x` with `mod_pad` before the input conv
        Returns:
            out: [B, 1, T']
                extracted audio with sounds corresponding to the `label`.
            When buffers were passed in, returns
            (out, enc_buf, dec_buf, out_buf) instead.
        """
        mod = 0
        if pad:
            pad_size = (self.L, self.L) if self.lookahead else (0, 0)
            x, mod = mod_pad(x, chunk_size=self.L, pad=pad_size)

        init_bufs = (init_enc_buf, init_dec_buf, init_out_buf)
        if any(buf is None for buf in init_bufs):
            assert all(buf is None for buf in init_bufs), \
                "All three buffers have to be initialized, or " \
                "all of them have to be None."
            enc_buf, dec_buf, out_buf = self.init_buffers(
                x.shape[0], x.device)
        else:
            enc_buf, dec_buf, out_buf = init_bufs

        # Generate latent space representation of the input
        x = self.in_conv(x)

        # Generate label embedding
        l = self.label_embedding(label) # [B, label_len] --> [B, channels]

        # Generate mask corresponding to the label
        m, enc_buf, dec_buf = self.mask_gen(x, l, enc_buf, dec_buf)

        # Apply mask and decode
        x = x * m
        x = torch.cat((out_buf, x), dim=-1)
        # Keep the trailing `out_buf_len` frames for the next call. An
        # explicit start index is used because `x[..., -0:]` would select
        # the whole sequence when `out_buf_len` is 0.
        buf_start = max(x.shape[-1] - self.out_buf_len, 0)
        out_buf = x[..., buf_start:]
        x = self.out_conv(x)

        # Remove mod padding, if present.
        if mod != 0:
            x = x[:, :, :-mod]

        if init_enc_buf is None:
            return x
        else:
            return x, enc_buf, dec_buf, out_buf
|
| 461 |
+
|
| 462 |
+
# Define optimizer, loss and metrics
|
| 463 |
+
|
| 464 |
+
def optimizer(model, data_parallel=False, **kwargs):
    """
    Builds an Adam optimizer over all parameters of `model`.
    `kwargs` are forwarded to `optim.Adam`; `data_parallel` is accepted
    but not used here.
    """
    params = model.parameters()
    return optim.Adam(params, **kwargs)
|
| 466 |
+
|
| 467 |
+
def loss(pred, tgt):
    """
    Training loss: negative weighted sum of the mean SNR (weight 0.9)
    and the mean SI-SNR (weight 0.1) of `pred` against `tgt`.
    """
    mean_snr = snr(pred, tgt).mean()
    mean_si_snr = si_snr(pred, tgt).mean()
    return -0.9 * mean_snr - 0.1 * mean_si_snr
|
| 469 |
+
|
| 470 |
+
def metrics(mixed, output, gt):
    """
    Computes per-sample metric improvements of `output` over the input
    mixture, for SNR and SI-SNR.

    Returns a dict keyed by each metric function's `__name__`, holding
    one float per batch element: metric(output, gt) - metric(mixed, gt).
    """
    # Trim the mixture's second dimension to that of `gt` so both are
    # compared against the target with the same shape.
    reference_mix = mixed[:, :gt.shape[1], :]

    def improvement(metric_fn):
        # One value per batch element
        return [
            (metric_fn(est, ref) - metric_fn(mix, ref)).cpu().item()
            for mix, ref, est in zip(reference_mix, gt, output)]

    return {fn.__name__: improvement(fn) for fn in (snr, si_snr)}
|
src/training/eval.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test script to evaluate the model.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import importlib
|
| 7 |
+
import multiprocessing
|
| 8 |
+
import os, glob
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 16 |
+
from torch.profiler import profile, record_function, ProfilerActivity
|
| 17 |
+
from tqdm import tqdm # pylint: disable=unused-import
|
| 18 |
+
from torchmetrics.functional import(
|
| 19 |
+
scale_invariant_signal_noise_ratio as si_snr,
|
| 20 |
+
signal_noise_ratio as snr,
|
| 21 |
+
signal_distortion_ratio as sdr,
|
| 22 |
+
scale_invariant_signal_distortion_ratio as si_sdr)
|
| 23 |
+
|
| 24 |
+
from src.helpers import utils
|
| 25 |
+
from src.training.synthetic_dataset import FSDSoundScapesDataset, tensorboard_add_metrics
|
| 26 |
+
from src.training.synthetic_dataset import tensorboard_add_sample
|
| 27 |
+
|
| 28 |
+
def test_epoch(model: nn.Module, device: torch.device,
               test_loader: torch.utils.data.dataloader.DataLoader,
               n_items: int, loss_fn, metrics_fn,
               profiling: bool = False, epoch: int = 0,
               writer: SummaryWriter = None, data_params = None) -> dict:
    """
    Evaluate the network.

    Runs `model` over `test_loader` without gradients, stopping after
    `n_items` batches when `n_items` is not None. Every forward pass is
    wrapped in the CPU profiler because its total self CPU time is
    reported as the 'runtime' metric; the profiler table is only logged
    when `profiling` is set.

    Returns:
        dict mapping each metric name (including 'loss' and 'runtime')
        to its mean over all collected values.
    """
    model.eval()
    metrics = {}

    with torch.no_grad():
        for batch_idx, (mixed, label, gt) in \
                enumerate(tqdm(test_loader, desc='Test', ncols=100)):
            mixed = mixed.to(device)
            label = label.to(device)
            gt = gt.to(device)

            # Run through the model
            with profile(activities=[ProfilerActivity.CPU],
                         record_shapes=True) as prof:
                with record_function("model_inference"):
                    output = model(mixed, label)
            if profiling:
                logging.info(
                    prof.key_averages().table(sort_by="self_cpu_time_total",
                                              row_limit=20))

            # Compute loss
            loss = loss_fn(output, gt)

            # Compute metrics
            metrics_batch = metrics_fn(mixed, output, gt)
            metrics_batch['loss'] = [loss.item()]
            metrics_batch['runtime'] = [prof.profiler.self_cpu_time_total/1000]
            # Accumulate into separate lists so the per-batch dict handed
            # to tensorboard below is never aliased by the running totals.
            for k, vals in metrics_batch.items():
                metrics.setdefault(k, []).extend(vals)

            if writer is not None:
                if batch_idx == 0:
                    tensorboard_add_sample(
                        writer, tag='Test',
                        sample=(mixed[:8], label[:8], gt[:8], output[:8]),
                        step=epoch, params=data_params)
                tensorboard_add_metrics(
                    writer, tag='Test', metrics=metrics_batch, label=label,
                    step=epoch)

            if n_items is not None and batch_idx == (n_items - 1):
                break

    avg_metrics = {k: np.mean(v) for k, v in metrics.items()}
    avg_metrics_str = "Test:"
    for m in avg_metrics.keys():
        avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
    logging.info(avg_metrics_str)

    return avg_metrics
|
| 89 |
+
|
| 90 |
+
def evaluate(network, args: argparse.Namespace):
    """
    Evaluate the model on a given dataset.

    Builds the test dataset and loader from `args`, instantiates
    `network.Net`, loads weights (the epoch with the best validation
    `args.base_metric` when `args.pretrain_path == "best"`) and runs
    `test_epoch`.

    Returns:
        The averaged metrics dict from `test_epoch`, or None when the
        evaluation was interrupted or raised (the traceback is printed).
    Raises:
        FileNotFoundError: `pretrain_path` is "best" but `args.exp_dir`
            contains no `*.pt` checkpoint.
    """

    # Load dataset
    data_test = FSDSoundScapesDataset(**args.test_data)
    logging.info("Loaded test dataset at %s containing %d elements" %
                 (args.test_data['input_dir'], len(data_test)))

    # Set up the device and workers.
    use_cuda = args.use_cuda and torch.cuda.is_available()
    if use_cuda:
        gpu_ids = args.gpu_ids if args.gpu_ids is not None\
                        else range(torch.cuda.device_count())
        device_ids = list(gpu_ids)
        data_parallel = len(device_ids) > 1
        device = 'cuda:%d' % device_ids[0]
        torch.cuda.set_device(device_ids[0])
        logging.info("Using CUDA devices: %s" % str(device_ids))
    else:
        data_parallel = False
        device = torch.device('cpu')
        logging.info("Using device: CPU")

    # Set multiprocessing params
    num_workers = min(multiprocessing.cpu_count(), args.n_workers)
    kwargs = {
        'num_workers': num_workers,
        'pin_memory': True
    } if use_cuda else {}

    # Set up data loader
    test_loader = torch.utils.data.DataLoader(data_test,
                                              batch_size=args.eval_batch_size,
                                              **kwargs)

    # Set up model
    model = network.Net(**args.model_params)
    if use_cuda and data_parallel:
        model = nn.DataParallel(model, device_ids=device_ids)
        logging.info("Using data parallel model")
    model.to(device)

    # Load weights
    if args.pretrain_path == "best":
        ckpts = glob.glob(os.path.join(args.exp_dir, '*.pt'))
        if not ckpts:
            raise FileNotFoundError(
                "No checkpoints (*.pt) found in %s" % args.exp_dir)
        # Checkpoints are named '<epoch>.pt'; the latest one holds the
        # validation history of every epoch so far.
        ckpts.sort(
            key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
        # Only the metric lists are read here, so map to CPU: this works on
        # hosts without the GPU the checkpoint was saved from.
        val_metrics = torch.load(
            ckpts[-1], map_location='cpu')['val_metrics'][args.base_metric]
        best_epoch = max(range(len(val_metrics)), key=val_metrics.__getitem__)
        args.pretrain_path = os.path.join(args.exp_dir, '%d.pt' % best_epoch)
        logging.info(
            "Found 'best' validation %s=%.02f at %s" %
            (args.base_metric, val_metrics[best_epoch], args.pretrain_path))
    if args.pretrain_path != "":
        utils.load_checkpoint(
            args.pretrain_path, model, data_parallel=data_parallel)
        logging.info("Loaded pretrain weights from %s" % args.pretrain_path)

    # Evaluate
    try:
        return test_epoch(
            model, device, test_loader, args.n_items, network.loss,
            network.metrics, args.profiling)
    except KeyboardInterrupt:
        print("Interrupted")
    except Exception as _: # pylint: disable=broad-except
        import traceback # pylint: disable=import-outside-toplevel
        traceback.print_exc()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
if __name__ == '__main__':
    # Command-line entry point: evaluates each experiment directory and
    # optionally writes one CSV row per experiment.
    parser = argparse.ArgumentParser()
    # Data Params
    parser.add_argument('experiments', nargs='+', type=str,
                        default=None,
                        help="List of experiments to evaluate. "
                        "Provide only one experiment when providing "
                        "pretrained path. If pretrianed path is not "
                        "provided, epoch with best validation metric "
                        "is used for evaluation.")
    parser.add_argument('--results', type=str, default="",
                        help="Path to the CSV file to store results.")

    # System params
    parser.add_argument('--n_items', type=int, default=None,
                        help="Number of items to test.")
    parser.add_argument('--pretrain_path', type=str, default="best",
                        help="Path to pretrained weights")
    parser.add_argument('--profiling', dest='profiling', action='store_true',
                        help="Enable or disable profiling.")
    parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
                        help="Whether to use cuda")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for training. "
                        "Eg., --gpu_ids 2 4. All GPUs are used by default.")
    args = parser.parse_args()

    results = []

    for exp_dir in args.experiments:
        # Per-experiment copy so config values never leak between runs
        eval_args = argparse.Namespace(**vars(args))
        eval_args.exp_dir = exp_dir

        utils.set_logger(os.path.join(exp_dir, 'eval.log'))
        logging.info("Evaluating %s ..." % exp_dir)

        # Load model and training params
        params = utils.Params(os.path.join(exp_dir, 'config.json'))
        for k, v in params.__dict__.items():
            vars(eval_args)[k] = v

        network = importlib.import_module(eval_args.model)
        logging.info("Imported the model from '%s'." % eval_args.model)

        curr_res = evaluate(network, eval_args)
        if curr_res is None:
            # `evaluate` returns None after an interrupt or a caught
            # exception; skip the row instead of failing on the item
            # assignment below and losing the earlier results.
            logging.error(
                "Evaluation of %s produced no results; skipping." % exp_dir)
        else:
            curr_res['experiment'] = os.path.basename(exp_dir)
            results.append(curr_res)

        del eval_args

    if args.results != "":
        print("Writing results to %s" % args.results)
        pd.DataFrame(results).to_csv(args.results, index=False)
|
src/training/synthetic_dataset.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Torch dataset object for synthetically rendered spatial data.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import random
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
import scaper
|
| 15 |
+
import torch
|
| 16 |
+
import torchaudio
|
| 17 |
+
import torchaudio.transforms as AT
|
| 18 |
+
from random import randrange
|
| 19 |
+
|
| 20 |
+
class FSDSoundScapesDataset(torch.utils.data.Dataset):  # type: ignore
    """
    Base class for FSD Sound Scapes dataset.

    Each item is regenerated from a Scaper JAMS file and returned as
    (mixture, label_vector, gt), where `gt` is the sum of the isolated
    events selected as targets.
    """

    _labels = [
        "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
        "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
        "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
        "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
        "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
        "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
        "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
        "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
        "Trumpet", "Violin_or_fiddle", "Writing"]

    def __init__(self, input_dir, dset='', sr=None,
                 resample_rate=None, max_num_targets=1):
        assert dset in ['train', 'val', 'test'], \
            "`dset` must be one of ['train', 'val', 'test']"
        self.dset = dset
        self.max_num_targets = max_num_targets
        self.fg_dir = os.path.join(input_dir, 'FSDKaggle2018/%s' % dset)
        # Train/val use the development backgrounds, test the evaluation set
        if dset in ['train', 'val']:
            self.bg_dir = os.path.join(
                input_dir,
                'TAU-acoustic-sounds/'
                'TAU-urban-acoustic-scenes-2019-development')
        else:
            self.bg_dir = os.path.join(
                input_dir,
                'TAU-acoustic-sounds/'
                'TAU-urban-acoustic-scenes-2019-evaluation')
        logging.info("Loading %s dataset: fg_dir=%s bg_dir=%s" %
                     (dset, self.fg_dir, self.bg_dir))

        self.samples = sorted(list(
            Path(os.path.join(input_dir, 'jams', dset)).glob('[0-9]*')))

        # Read the sampling rate from the first sample and check it against
        # the rate the caller expects.
        jamsfile = os.path.join(self.samples[0], 'mixture.jams')
        _, jams, _, _ = scaper.generate_from_jams(
            jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
        _sr = jams['annotations'][0]['sandbox']['scaper']['sr']
        assert _sr == sr, "Sampling rate provided does not match the data"

        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
            self.sr = resample_rate
        else:
            self.resampler = lambda a: a
            self.sr = sr

    def _get_label_vector(self, labels):
        """
        Generates a multi-hot vector corresponding to `labels`.
        """
        vector = torch.zeros(len(FSDSoundScapesDataset._labels))

        for label in labels:
            idx = FSDSoundScapesDataset._labels.index(label)
            assert vector[idx] == 0, "Repeated labels"
            vector[idx] = 1

        return vector

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Returns (mixture, label_vector, gt) for sample `idx`.
        """
        sample_path = self.samples[idx]
        jamsfile = os.path.join(sample_path, 'mixture.jams')

        mixture, _jams, ann_list, event_audio_list = scaper.generate_from_jams(
            jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
        # 0th event is background, so skip it when pairing with annotations
        isolated_events = {
            e[2]: a for e, a in zip(ann_list, event_audio_list[1:])}
        gt_events = list(pd.read_csv(
            os.path.join(sample_path, 'gt_events.csv'), sep='\t')['label'])

        mixture = torch.from_numpy(mixture).permute(1, 0)
        mixture = self.resampler(mixture.to(torch.float))

        if self.dset == 'train':
            # Never ask for more targets than the scene contains;
            # `random.sample` raises ValueError otherwise.
            max_k = min(self.max_num_targets, len(gt_events))
            labels = random.sample(gt_events, randrange(1, max_k + 1))
        elif self.dset == 'val':
            labels = gt_events[:idx%self.max_num_targets+1]
        elif self.dset == 'test':
            labels = gt_events[:self.max_num_targets]
        label_vector = self._get_label_vector(labels)

        # Ground truth is the sum of the isolated target events
        gt = torch.zeros_like(
            torch.from_numpy(event_audio_list[1]).permute(1, 0))
        for l in labels:
            gt = gt + torch.from_numpy(isolated_events[l]).permute(1, 0)
        gt = self.resampler(gt.to(torch.float))

        return mixture, label_vector, gt #, jams
|
| 119 |
+
|
| 120 |
+
def tensorboard_add_sample(writer, tag, sample, step, params):
    """
    Logs a batch sample (mixture, labels, ground truth, output) to
    tensorboard as audio clips plus one waveform figure per batch element.
    Audio is resampled to at most 16 kHz before logging.
    """
    sr = params['resample_rate']
    if sr is None:
        sr = params['sr']
    resample_rate = 16000 if sr > 16000 else sr

    m, l, gt, o = sample
    m = torchaudio.functional.resample(m, sr, resample_rate).cpu()
    gt = torchaudio.functional.resample(gt, sr, resample_rate).cpu()
    o = torchaudio.functional.resample(o, sr, resample_rate).cpu()

    def _add_audio(a, audio_tag, axis, plt_title):
        # One curve and one audio clip per channel of `a`
        for ch_idx, ch in enumerate(a):
            axis.plot(ch, label='mic %d' % ch_idx)
            writer.add_audio(
                '%s/mic %d' % (audio_tag, ch_idx), ch.unsqueeze(0), step,
                resample_rate)
        axis.set_title(plt_title)
        axis.legend()

    for b in range(m.shape[0]):
        # Names of every label that is set for this batch element
        label = [
            FSDSoundScapesDataset._labels[i]
            for i, flag in enumerate(l[b, :]) if flag == 1]

        # Add waveforms
        rows = 3 # input, output, gt
        fig = plt.figure(figsize=(10, 2 * rows))
        axes = fig.subplots(rows, 1, sharex=True)
        prefix = '%s/sample_%d' % (tag, b)
        _add_audio(m[b], prefix + '/0_input', axes[0], "Mixed")
        _add_audio(o[b], prefix + '/1_output', axes[1], "Output (%s)" % label)
        _add_audio(gt[b], prefix + '/2_gt', axes[2], "GT (%s)" % label)
        writer.add_figure(prefix + '/waveform', fig, step)
|
| 157 |
+
|
| 158 |
+
def tensorboard_add_metrics(writer, tag, metrics, label, step):
    """
    Add metrics to tensorboard: one histogram of all SI-SNR values, plus
    one histogram entry per sample under the name of its argmax label.
    """
    si_snr_vals = np.asarray(metrics['scale_invariant_signal_noise_ratio'])

    writer.add_histogram('%s/%s' % (tag, 'SI-SNRi'), si_snr_vals, step)

    all_labels = FSDSoundScapesDataset._labels
    names = [all_labels[torch.argmax(row)] for row in label]
    for name, val in zip(names, si_snr_vals):
        writer.add_histogram('%s/%s' % (tag, name), val, step)
|
src/training/train.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The main training script for training on synthetic data
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import multiprocessing
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import random
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import torch.optim as optim
|
| 17 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 18 |
+
from tqdm import tqdm # pylint: disable=unused-import
|
| 19 |
+
from torchmetrics.functional import(
|
| 20 |
+
scale_invariant_signal_noise_ratio as si_snr,
|
| 21 |
+
signal_noise_ratio as snr,
|
| 22 |
+
signal_distortion_ratio as sdr,
|
| 23 |
+
scale_invariant_signal_distortion_ratio as si_sdr)
|
| 24 |
+
|
| 25 |
+
from src.helpers import utils
|
| 26 |
+
from src.training.eval import test_epoch
|
| 27 |
+
from src.training.synthetic_dataset import FSDSoundScapesDataset as Dataset
|
| 28 |
+
from src.training.synthetic_dataset import tensorboard_add_sample
|
| 29 |
+
|
| 30 |
+
def train_epoch(model: nn.Module, device: torch.device,
                optimizer: optim.Optimizer,
                train_loader: torch.utils.data.dataloader.DataLoader,
                n_items: int, epoch: int = 0,
                writer: SummaryWriter = None, data_params = None) -> dict:

    """
    Train a single epoch.

    Iterates over `train_loader`, stopping after `n_items` batches when
    `n_items` is not None, and updates `model` with gradient-norm
    clipping at 0.5.

    Returns:
        dict mapping each metric name (plus 'loss') to its mean over the
        batches that were processed.
    """
    # NOTE(review): `network` is not defined in this function; it is
    # presumably a module-level global bound before this is called.

    # Set the model to training.
    model.train()

    # Training loop
    losses = []
    metrics = {}

    with tqdm(total=len(train_loader), desc='Train', ncols=100) as t:
        for batch_idx, (mixed, label, gt) in enumerate(train_loader):
            mixed = mixed.to(device)
            label = label.to(device)
            gt = gt.to(device)

            # Reset grad
            optimizer.zero_grad()

            # Run through the model
            output = model(mixed, label)

            # Compute loss
            loss = network.loss(output, gt)

            losses.append(loss.item())

            # Backpropagation
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

            # Update the weights
            optimizer.step()

            metrics_batch = network.metrics(mixed.detach(), output.detach(),
                                            gt.detach())
            for k, vals in metrics_batch.items():
                metrics.setdefault(k, []).extend(vals)

            if writer is not None and batch_idx == 0:
                tensorboard_add_sample(
                    writer, tag='Train',
                    sample=(mixed.detach()[:8], label.detach()[:8],
                            gt.detach()[:8], output.detach()[:8]),
                    step=epoch, params=data_params)

            # Show current loss in the progress meter
            t.set_postfix(loss='%.05f'%loss.item())
            t.update()

            # Stop once `n_items` batches were processed. `batch_idx` is
            # zero-based, so comparing it to `n_items` directly would run
            # one batch too many.
            if n_items is not None and batch_idx + 1 >= n_items:
                break

    avg_metrics = {k: np.mean(v) for k, v in metrics.items()}
    avg_metrics['loss'] = np.mean(losses)
    avg_metrics_str = "Train:"
    for m in avg_metrics.keys():
        avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
    logging.info(avg_metrics_str)

    return avg_metrics
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def train(args: argparse.Namespace):
|
| 105 |
+
"""
|
| 106 |
+
Train the network.
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
# Load dataset
|
| 110 |
+
data_train = Dataset(**args.train_data)
|
| 111 |
+
logging.info("Loaded train dataset at %s containing %d elements" %
|
| 112 |
+
(args.train_data['input_dir'], len(data_train)))
|
| 113 |
+
data_val = Dataset(**args.val_data)
|
| 114 |
+
logging.info("Loaded test dataset at %s containing %d elements" %
|
| 115 |
+
(args.val_data['input_dir'], len(data_val)))
|
| 116 |
+
|
| 117 |
+
# Set up the device and workers.
|
| 118 |
+
use_cuda = args.use_cuda and torch.cuda.is_available()
|
| 119 |
+
if use_cuda:
|
| 120 |
+
gpu_ids = args.gpu_ids if args.gpu_ids is not None\
|
| 121 |
+
else range(torch.cuda.device_count())
|
| 122 |
+
device_ids = [_ for _ in gpu_ids]
|
| 123 |
+
data_parallel = len(device_ids) > 1
|
| 124 |
+
device = 'cuda:%d' % device_ids[0]
|
| 125 |
+
torch.cuda.set_device(device_ids[0])
|
| 126 |
+
logging.info("Using CUDA devices: %s" % str(device_ids))
|
| 127 |
+
else:
|
| 128 |
+
data_parallel = False
|
| 129 |
+
device = torch.device('cpu')
|
| 130 |
+
logging.info("Using device: CPU")
|
| 131 |
+
|
| 132 |
+
# Set multiprocessing params
|
| 133 |
+
num_workers = min(multiprocessing.cpu_count(), args.n_workers)
|
| 134 |
+
kwargs = {
|
| 135 |
+
'num_workers': num_workers,
|
| 136 |
+
'pin_memory': True
|
| 137 |
+
} if use_cuda else {}
|
| 138 |
+
|
| 139 |
+
# Set up data loaders
|
| 140 |
+
#print(args.batch_size, args.eval_batch_size)
|
| 141 |
+
train_loader = torch.utils.data.DataLoader(data_train,
|
| 142 |
+
batch_size=args.batch_size,
|
| 143 |
+
shuffle=True, **kwargs)
|
| 144 |
+
val_loader = torch.utils.data.DataLoader(data_val,
|
| 145 |
+
batch_size=args.eval_batch_size,
|
| 146 |
+
**kwargs)
|
| 147 |
+
|
| 148 |
+
# Set up model
|
| 149 |
+
model = network.Net(**args.model_params)
|
| 150 |
+
|
| 151 |
+
# Add graph to tensorboard with example train samples
|
| 152 |
+
# _mixed, _label, _ = next(iter(val_loader))
|
| 153 |
+
# args.writer.add_graph(model, (_mixed, _label))
|
| 154 |
+
|
| 155 |
+
if use_cuda and data_parallel:
|
| 156 |
+
model = nn.DataParallel(model, device_ids=device_ids)
|
| 157 |
+
logging.info("Using data parallel model")
|
| 158 |
+
model.to(device)
|
| 159 |
+
|
| 160 |
+
# Set up the optimizer
|
| 161 |
+
logging.info("Initializing optimizer with %s" % str(args.optim))
|
| 162 |
+
optimizer = network.optimizer(model, **args.optim, data_parallel=data_parallel)
|
| 163 |
+
logging.info('Learning rates initialized to:' + utils.format_lr_info(optimizer))
|
| 164 |
+
|
| 165 |
+
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 166 |
+
optimizer, **args.lr_sched)
|
| 167 |
+
logging.info("Initialized LR scheduler with params: fix_lr_epochs=%d %s"
|
| 168 |
+
% (args.fix_lr_epochs, str(args.lr_sched)))
|
| 169 |
+
|
| 170 |
+
base_metric = args.base_metric
|
| 171 |
+
train_metrics = {}
|
| 172 |
+
val_metrics = {}
|
| 173 |
+
|
| 174 |
+
# Load the model if `args.start_epoch` is greater than 0. This will load the
|
| 175 |
+
# model from epoch = `args.start_epoch - 1`
|
| 176 |
+
assert args.start_epoch >=0, "start_epoch must be greater than 0."
|
| 177 |
+
if args.start_epoch > 0:
|
| 178 |
+
checkpoint_path = os.path.join(args.exp_dir,
|
| 179 |
+
'%d.pt' % (args.start_epoch - 1))
|
| 180 |
+
_, train_metrics, val_metrics = utils.load_checkpoint(
|
| 181 |
+
checkpoint_path, model, optim=optimizer, lr_sched=lr_scheduler,
|
| 182 |
+
data_parallel=data_parallel)
|
| 183 |
+
logging.info("Loaded checkpoint from %s" % checkpoint_path)
|
| 184 |
+
logging.info("Learning rates restored to:" + utils.format_lr_info(optimizer))
|
| 185 |
+
|
| 186 |
+
# Training loop
|
| 187 |
+
try:
|
| 188 |
+
torch.autograd.set_detect_anomaly(args.detect_anomaly)
|
| 189 |
+
for epoch in range(args.start_epoch, args.epochs + 1):
|
| 190 |
+
logging.info("Epoch %d:" % epoch)
|
| 191 |
+
checkpoint_file = os.path.join(args.exp_dir, '%d.pt' % epoch)
|
| 192 |
+
assert not os.path.exists(checkpoint_file), \
|
| 193 |
+
"Checkpoint file %s already exists" % checkpoint_file
|
| 194 |
+
#print("---- begin trianivg")
|
| 195 |
+
curr_train_metrics = train_epoch(model, device, optimizer,
|
| 196 |
+
train_loader, args.n_train_items,
|
| 197 |
+
epoch=epoch, writer=args.writer,
|
| 198 |
+
data_params=args.train_data)
|
| 199 |
+
#raise KeyboardInterrupt
|
| 200 |
+
curr_test_metrics = test_epoch(model, device, val_loader,
|
| 201 |
+
args.n_test_items, network.loss,
|
| 202 |
+
network.metrics, epoch=epoch,
|
| 203 |
+
writer=args.writer,
|
| 204 |
+
data_params=args.val_data)
|
| 205 |
+
# LR scheduler
|
| 206 |
+
if epoch >= args.fix_lr_epochs:
|
| 207 |
+
lr_scheduler.step(curr_test_metrics[base_metric])
|
| 208 |
+
logging.info(
|
| 209 |
+
"LR after scheduling step: %s" %
|
| 210 |
+
[_['lr'] for _ in optimizer.param_groups])
|
| 211 |
+
|
| 212 |
+
# Write metrics to tensorboard
|
| 213 |
+
args.writer.add_scalars('Train', curr_train_metrics, epoch)
|
| 214 |
+
args.writer.add_scalars('Val', curr_test_metrics, epoch)
|
| 215 |
+
args.writer.flush()
|
| 216 |
+
|
| 217 |
+
for k in curr_train_metrics.keys():
|
| 218 |
+
if not k in train_metrics:
|
| 219 |
+
train_metrics[k] = [curr_train_metrics[k]]
|
| 220 |
+
else:
|
| 221 |
+
train_metrics[k].append(curr_train_metrics[k])
|
| 222 |
+
|
| 223 |
+
for k in curr_test_metrics.keys():
|
| 224 |
+
if not k in val_metrics:
|
| 225 |
+
val_metrics[k] = [curr_test_metrics[k]]
|
| 226 |
+
else:
|
| 227 |
+
val_metrics[k].append(curr_test_metrics[k])
|
| 228 |
+
|
| 229 |
+
if max(val_metrics[base_metric]) == val_metrics[base_metric][-1]:
|
| 230 |
+
logging.info("Found best validation %s!" % base_metric)
|
| 231 |
+
|
| 232 |
+
utils.save_checkpoint(
|
| 233 |
+
checkpoint_file, epoch, model, optimizer, lr_scheduler,
|
| 234 |
+
train_metrics, val_metrics, data_parallel)
|
| 235 |
+
logging.info("Saved checkpoint at %s" % checkpoint_file)
|
| 236 |
+
|
| 237 |
+
utils.save_graph(train_metrics, val_metrics, args.exp_dir)
|
| 238 |
+
|
| 239 |
+
return train_metrics, val_metrics
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
except KeyboardInterrupt:
|
| 243 |
+
print("Interrupted")
|
| 244 |
+
except Exception as _: # pylint: disable=broad-except
|
| 245 |
+
import traceback # pylint: disable=import-outside-toplevel
|
| 246 |
+
traceback.print_exc()
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
if __name__ == '__main__':
    # Script entry point. Parses the command line, seeds the RNGs, merges the
    # experiment's config.json into `args`, sets up logging and tensorboard
    # (optionally mirrored to wandb), imports the model module named by
    # `args.model` as the module-level `network`, and runs `train(args)`.
    import importlib  # local import: only needed when run as a script

    parser = argparse.ArgumentParser()
    # Data Params
    # NOTE(review): argparse ignores `default` on a required positional
    # argument (it only applies with nargs='?' or '*'), so `exp_dir` must
    # always be given on the command line. Kept as-is to preserve the CLI.
    parser.add_argument('exp_dir', type=str,
                        default='./experiments/fsd_mask_label_mult',
                        help="Path to save checkpoints and logs.")

    parser.add_argument('--n_train_items', type=int, default=None,
                        help="Number of items to train on in each epoch")
    parser.add_argument('--n_test_items', type=int, default=None,
                        help="Number of items to test.")
    parser.add_argument('--start_epoch', type=int, default=0,
                        help="Start epoch")
    parser.add_argument('--pretrain_path', type=str,
                        help="Path to pretrained weights")
    parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
                        help="Whether to use cuda")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for training. "
                        "Eg., --gpu_ids 2 4. All GPUs are used by default.")
    # Fixed help text: this flag is forwarded to
    # torch.autograd.set_detect_anomaly() inside train(), not to CUDA setup.
    parser.add_argument('--detect_anomaly', dest='detect_anomaly',
                        action='store_true',
                        help="Whether to enable autograd anomaly detection")
    parser.add_argument('--wandb', dest='wandb', action='store_true',
                        help="Whether to sync tensorboard to wandb")

    args = parser.parse_args()

    # Set the random seed for reproducible experiments
    torch.manual_seed(230)
    random.seed(230)
    np.random.seed(230)
    if args.use_cuda:
        torch.cuda.manual_seed(230)

    # Set up checkpoints. exist_ok avoids the check-then-create race.
    os.makedirs(args.exp_dir, exist_ok=True)

    utils.set_logger(os.path.join(args.exp_dir, 'train.log'))

    # Load model and training params. Every key of config.json is copied
    # onto `args`, overriding a command-line value of the same name.
    params = utils.Params(os.path.join(args.exp_dir, 'config.json'))
    for k, v in params.__dict__.items():
        vars(args)[k] = v

    # Initialize tensorboard writer
    tensorboard_dir = os.path.join(args.exp_dir, 'tensorboard')
    args.writer = SummaryWriter(tensorboard_dir, purge_step=args.start_epoch)
    if args.wandb:
        import wandb
        wandb.init(
            project='Semaudio', sync_tensorboard=True,
            dir=tensorboard_dir, name=os.path.basename(args.exp_dir))

    try:
        # Bind the model module to the module-level name `network`, which
        # train() reads as a global. import_module replaces the previous
        # exec("import ... as network") so a config value can no longer
        # inject arbitrary statements.
        network = importlib.import_module(args.model)
        logging.info("Imported the model from '%s'." % args.model)

        train(args)
    finally:
        # Always release the writer (and the wandb run), even if train()
        # raises before reaching its own exception handling.
        args.writer.close()
        if args.wandb:
            wandb.finish()
|