Spaces:
Runtime error
Runtime error
Refactored source
Browse files- app.py +1 -2
- src/training/dcc_tf.py → dcc_tf.py +0 -0
- src/__init__.py +0 -0
- src/helpers/__init__.py +0 -0
- src/helpers/utils.py +0 -205
- src/training/__init__.py +0 -0
- src/training/eval.py +0 -214
- src/training/synthetic_dataset.py +0 -168
- src/training/train.py +0 -311
app.py
CHANGED
|
@@ -6,7 +6,6 @@ import torch
|
|
| 6 |
import torchaudio
|
| 7 |
import gradio as gr
|
| 8 |
|
| 9 |
-
from src.helpers import utils
|
| 10 |
from src.training.dcc_tf import Net as Waveformer
|
| 11 |
|
| 12 |
TARGETS = [
|
|
@@ -34,7 +33,7 @@ if not os.path.exists('default_ckpt.pt'):
|
|
| 34 |
# Instantiate model
|
| 35 |
params = utils.Params('default_config.json')
|
| 36 |
model = Waveformer(**params.model_params)
|
| 37 |
-
|
| 38 |
model.eval()
|
| 39 |
|
| 40 |
def waveformer(audio, label_choices):
|
|
|
|
| 6 |
import torchaudio
|
| 7 |
import gradio as gr
|
| 8 |
|
|
|
|
| 9 |
from src.training.dcc_tf import Net as Waveformer
|
| 10 |
|
| 11 |
TARGETS = [
|
|
|
|
| 33 |
# Instantiate model
|
| 34 |
params = utils.Params('default_config.json')
|
| 35 |
model = Waveformer(**params.model_params)
|
| 36 |
+
model.load_state_dict(torch.load('default_ckpt.pt', map_location=torch.device('cpu')))
|
| 37 |
model.eval()
|
| 38 |
|
| 39 |
def waveformer(audio, label_choices):
|
src/training/dcc_tf.py → dcc_tf.py
RENAMED
|
File without changes
|
src/__init__.py
DELETED
|
File without changes
|
src/helpers/__init__.py
DELETED
|
File without changes
|
src/helpers/utils.py
DELETED
|
@@ -1,205 +0,0 @@
|
|
| 1 |
-
"""A collection of useful helper functions"""
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
import logging
|
| 5 |
-
import json
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
from torch.profiler import profile, record_function, ProfilerActivity
|
| 9 |
-
import pandas as pd
|
| 10 |
-
from torchmetrics.functional import(
|
| 11 |
-
scale_invariant_signal_noise_ratio as si_snr,
|
| 12 |
-
signal_noise_ratio as snr,
|
| 13 |
-
signal_distortion_ratio as sdr,
|
| 14 |
-
scale_invariant_signal_distortion_ratio as si_sdr)
|
| 15 |
-
import matplotlib.pyplot as plt
|
| 16 |
-
|
| 17 |
-
class Params():
|
| 18 |
-
"""Class that loads hyperparameters from a json file.
|
| 19 |
-
Example:
|
| 20 |
-
```
|
| 21 |
-
params = Params(json_path)
|
| 22 |
-
print(params.learning_rate)
|
| 23 |
-
params.learning_rate = 0.5 # change the value of learning_rate in params
|
| 24 |
-
```
|
| 25 |
-
"""
|
| 26 |
-
|
| 27 |
-
def __init__(self, json_path):
|
| 28 |
-
with open(json_path) as f:
|
| 29 |
-
params = json.load(f)
|
| 30 |
-
self.__dict__.update(params)
|
| 31 |
-
|
| 32 |
-
def save(self, json_path):
|
| 33 |
-
with open(json_path, 'w') as f:
|
| 34 |
-
json.dump(self.__dict__, f, indent=4)
|
| 35 |
-
|
| 36 |
-
def update(self, json_path):
|
| 37 |
-
"""Loads parameters from json file"""
|
| 38 |
-
with open(json_path) as f:
|
| 39 |
-
params = json.load(f)
|
| 40 |
-
self.__dict__.update(params)
|
| 41 |
-
|
| 42 |
-
@property
|
| 43 |
-
def dict(self):
|
| 44 |
-
"""Gives dict-like access to Params instance by `params.dict['learning_rate']"""
|
| 45 |
-
return self.__dict__
|
| 46 |
-
|
| 47 |
-
def save_graph(train_metrics, test_metrics, save_dir):
|
| 48 |
-
metrics = [snr, si_snr]
|
| 49 |
-
results = {'train_loss': train_metrics['loss'],
|
| 50 |
-
'test_loss' : test_metrics['loss']}
|
| 51 |
-
|
| 52 |
-
for m_fn in metrics:
|
| 53 |
-
results["train_"+m_fn.__name__] = train_metrics[m_fn.__name__]
|
| 54 |
-
results["test_"+m_fn.__name__] = test_metrics[m_fn.__name__]
|
| 55 |
-
|
| 56 |
-
results_pd = pd.DataFrame(results)
|
| 57 |
-
|
| 58 |
-
results_pd.to_csv(os.path.join(save_dir, 'results.csv'))
|
| 59 |
-
|
| 60 |
-
fig, temp_ax = plt.subplots(2, 3, figsize=(15,10))
|
| 61 |
-
axs=[]
|
| 62 |
-
for i in temp_ax:
|
| 63 |
-
for j in i:
|
| 64 |
-
axs.append(j)
|
| 65 |
-
|
| 66 |
-
x = range(len(train_metrics['loss']))
|
| 67 |
-
axs[0].plot(x, train_metrics['loss'], label='train')
|
| 68 |
-
axs[0].plot(x, test_metrics['loss'], label='test')
|
| 69 |
-
axs[0].set(ylabel='Loss')
|
| 70 |
-
axs[0].set(xlabel='Epoch')
|
| 71 |
-
axs[0].set_title('loss',fontweight='bold')
|
| 72 |
-
axs[0].legend()
|
| 73 |
-
|
| 74 |
-
for i in range(len(metrics)):
|
| 75 |
-
axs[i+1].plot(x, train_metrics[metrics[i].__name__], label='train')
|
| 76 |
-
axs[i+1].plot(x, test_metrics[metrics[i].__name__], label='test')
|
| 77 |
-
axs[i+1].set(xlabel='Epoch')
|
| 78 |
-
axs[i+1].set_title(metrics[i].__name__,fontweight='bold')
|
| 79 |
-
axs[i+1].legend()
|
| 80 |
-
|
| 81 |
-
plt.tight_layout()
|
| 82 |
-
plt.savefig(os.path.join(save_dir, 'results.png'))
|
| 83 |
-
plt.close(fig)
|
| 84 |
-
|
| 85 |
-
def set_logger(log_path):
|
| 86 |
-
"""Set the logger to log info in terminal and file `log_path`.
|
| 87 |
-
In general, it is useful to have a logger so that every output to the terminal is saved
|
| 88 |
-
in a permanent file. Here we save it to `model_dir/train.log`.
|
| 89 |
-
Example:
|
| 90 |
-
```
|
| 91 |
-
logging.info("Starting training...")
|
| 92 |
-
```
|
| 93 |
-
Args:
|
| 94 |
-
log_path: (string) where to log
|
| 95 |
-
"""
|
| 96 |
-
logger = logging.getLogger()
|
| 97 |
-
logger.setLevel(logging.INFO)
|
| 98 |
-
logger.handlers.clear()
|
| 99 |
-
|
| 100 |
-
# Logging to a file
|
| 101 |
-
file_handler = logging.FileHandler(log_path)
|
| 102 |
-
file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
|
| 103 |
-
logger.addHandler(file_handler)
|
| 104 |
-
|
| 105 |
-
# Logging to console
|
| 106 |
-
stream_handler = logging.StreamHandler()
|
| 107 |
-
stream_handler.setFormatter(logging.Formatter('%(message)s'))
|
| 108 |
-
logger.addHandler(stream_handler)
|
| 109 |
-
|
| 110 |
-
def load_checkpoint(checkpoint, model, optim=None, lr_sched=None, data_parallel=False):
|
| 111 |
-
"""Loads model parameters (state_dict) from file_path.
|
| 112 |
-
|
| 113 |
-
Args:
|
| 114 |
-
checkpoint: (string) filename which needs to be loaded
|
| 115 |
-
model: (torch.nn.Module) model for which the parameters are loaded
|
| 116 |
-
data_parallel: (bool) if the model is a data parallel model
|
| 117 |
-
"""
|
| 118 |
-
if not os.path.exists(checkpoint):
|
| 119 |
-
raise("File doesn't exist {}".format(checkpoint))
|
| 120 |
-
|
| 121 |
-
state_dict = torch.load(checkpoint)
|
| 122 |
-
|
| 123 |
-
if data_parallel:
|
| 124 |
-
state_dict['model_state_dict'] = {
|
| 125 |
-
'module.' + k: state_dict['model_state_dict'][k]
|
| 126 |
-
for k in state_dict['model_state_dict'].keys()}
|
| 127 |
-
model.load_state_dict(state_dict['model_state_dict'])
|
| 128 |
-
|
| 129 |
-
if optim is not None:
|
| 130 |
-
optim.load_state_dict(state_dict['optim_state_dict'])
|
| 131 |
-
|
| 132 |
-
if lr_sched is not None:
|
| 133 |
-
lr_sched.load_state_dict(state_dict['lr_sched_state_dict'])
|
| 134 |
-
|
| 135 |
-
return state_dict['epoch'], state_dict['train_metrics'], \
|
| 136 |
-
state_dict['val_metrics']
|
| 137 |
-
|
| 138 |
-
def save_checkpoint(checkpoint, epoch, model, optim=None, lr_sched=None,
|
| 139 |
-
train_metrics=None, val_metrics=None, data_parallel=False):
|
| 140 |
-
"""Saves model parameters (state_dict) to file_path.
|
| 141 |
-
|
| 142 |
-
Args:
|
| 143 |
-
checkpoint: (string) filename which needs to be loaded
|
| 144 |
-
model: (torch.nn.Module) model for which the parameters are loaded
|
| 145 |
-
data_parallel: (bool) if the model is a data parallel model
|
| 146 |
-
"""
|
| 147 |
-
if os.path.exists(checkpoint):
|
| 148 |
-
raise("File already exists {}".format(checkpoint))
|
| 149 |
-
|
| 150 |
-
model_state_dict = model.state_dict()
|
| 151 |
-
if data_parallel:
|
| 152 |
-
model_state_dict = {
|
| 153 |
-
k.partition('module.')[2]:
|
| 154 |
-
model_state_dict[k] for k in model_state_dict.keys()}
|
| 155 |
-
|
| 156 |
-
optim_state_dict = None if not optim else optim.state_dict()
|
| 157 |
-
lr_sched_state_dict = None if not lr_sched else lr_sched.state_dict()
|
| 158 |
-
|
| 159 |
-
state_dict = {
|
| 160 |
-
'epoch': epoch,
|
| 161 |
-
'model_state_dict': model_state_dict,
|
| 162 |
-
'optim_state_dict': optim_state_dict,
|
| 163 |
-
'lr_sched_state_dict': lr_sched_state_dict,
|
| 164 |
-
'train_metrics': train_metrics,
|
| 165 |
-
'val_metrics': val_metrics
|
| 166 |
-
}
|
| 167 |
-
|
| 168 |
-
torch.save(state_dict, checkpoint)
|
| 169 |
-
|
| 170 |
-
def model_size(model):
|
| 171 |
-
"""
|
| 172 |
-
Returns size of the `model` in millions of parameters.
|
| 173 |
-
"""
|
| 174 |
-
num_train_params = sum(
|
| 175 |
-
p.numel() for p in model.parameters() if p.requires_grad)
|
| 176 |
-
return num_train_params / 1e6
|
| 177 |
-
|
| 178 |
-
def run_time(model, inputs, profiling=False):
|
| 179 |
-
"""
|
| 180 |
-
Returns runtime of a model in ms.
|
| 181 |
-
"""
|
| 182 |
-
# Warmup
|
| 183 |
-
for _ in range(100):
|
| 184 |
-
output = model(*inputs)
|
| 185 |
-
|
| 186 |
-
with profile(activities=[ProfilerActivity.CPU],
|
| 187 |
-
record_shapes=True) as prof:
|
| 188 |
-
with record_function("model_inference"):
|
| 189 |
-
output = model(*inputs)
|
| 190 |
-
|
| 191 |
-
# Print profiling results
|
| 192 |
-
if profiling:
|
| 193 |
-
print(prof.key_averages().table(sort_by="self_cpu_time_total",
|
| 194 |
-
row_limit=20))
|
| 195 |
-
|
| 196 |
-
# Return runtime in ms
|
| 197 |
-
return prof.profiler.self_cpu_time_total / 1000
|
| 198 |
-
|
| 199 |
-
def format_lr_info(optimizer):
|
| 200 |
-
lr_info = ""
|
| 201 |
-
for i, pg in enumerate(optimizer.param_groups):
|
| 202 |
-
lr_info += " {group %d: params=%.5fM lr=%.1E}" % (
|
| 203 |
-
i, sum([p.numel() for p in pg['params']]) / (1024 ** 2), pg['lr'])
|
| 204 |
-
return lr_info
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/training/__init__.py
DELETED
|
File without changes
|
src/training/eval.py
DELETED
|
@@ -1,214 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Test script to evaluate the model.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import argparse
|
| 6 |
-
import importlib
|
| 7 |
-
import multiprocessing
|
| 8 |
-
import os, glob
|
| 9 |
-
import logging
|
| 10 |
-
|
| 11 |
-
import numpy as np
|
| 12 |
-
import torch
|
| 13 |
-
import pandas as pd
|
| 14 |
-
import torch.nn as nn
|
| 15 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 16 |
-
from torch.profiler import profile, record_function, ProfilerActivity
|
| 17 |
-
from tqdm import tqdm # pylint: disable=unused-import
|
| 18 |
-
from torchmetrics.functional import(
|
| 19 |
-
scale_invariant_signal_noise_ratio as si_snr,
|
| 20 |
-
signal_noise_ratio as snr,
|
| 21 |
-
signal_distortion_ratio as sdr,
|
| 22 |
-
scale_invariant_signal_distortion_ratio as si_sdr)
|
| 23 |
-
|
| 24 |
-
from src.helpers import utils
|
| 25 |
-
from src.training.synthetic_dataset import FSDSoundScapesDataset, tensorboard_add_metrics
|
| 26 |
-
from src.training.synthetic_dataset import tensorboard_add_sample
|
| 27 |
-
|
| 28 |
-
def test_epoch(model: nn.Module, device: torch.device,
|
| 29 |
-
test_loader: torch.utils.data.dataloader.DataLoader,
|
| 30 |
-
n_items: int, loss_fn, metrics_fn,
|
| 31 |
-
profiling: bool = False, epoch: int = 0,
|
| 32 |
-
writer: SummaryWriter = None, data_params = None) -> float:
|
| 33 |
-
"""
|
| 34 |
-
Evaluate the network.
|
| 35 |
-
"""
|
| 36 |
-
model.eval()
|
| 37 |
-
metrics = {}
|
| 38 |
-
|
| 39 |
-
with torch.no_grad():
|
| 40 |
-
for batch_idx, (mixed, label, gt) in \
|
| 41 |
-
enumerate(tqdm(test_loader, desc='Test', ncols=100)):
|
| 42 |
-
mixed = mixed.to(device)
|
| 43 |
-
label = label.to(device)
|
| 44 |
-
gt = gt.to(device)
|
| 45 |
-
|
| 46 |
-
# Run through the model
|
| 47 |
-
with profile(activities=[ProfilerActivity.CPU],
|
| 48 |
-
record_shapes=True) as prof:
|
| 49 |
-
with record_function("model_inference"):
|
| 50 |
-
output = model(mixed, label)
|
| 51 |
-
if profiling:
|
| 52 |
-
logging.info(
|
| 53 |
-
prof.key_averages().table(sort_by="self_cpu_time_total",
|
| 54 |
-
row_limit=20))
|
| 55 |
-
|
| 56 |
-
# Compute loss
|
| 57 |
-
loss = loss_fn(output, gt)
|
| 58 |
-
|
| 59 |
-
# Compute metrics
|
| 60 |
-
metrics_batch = metrics_fn(mixed, output, gt)
|
| 61 |
-
metrics_batch['loss'] = [loss.item()]
|
| 62 |
-
metrics_batch['runtime'] = [prof.profiler.self_cpu_time_total/1000]
|
| 63 |
-
for k in metrics_batch.keys():
|
| 64 |
-
if not k in metrics:
|
| 65 |
-
metrics[k] = metrics_batch[k]
|
| 66 |
-
else:
|
| 67 |
-
metrics[k] += metrics_batch[k]
|
| 68 |
-
|
| 69 |
-
if writer is not None:
|
| 70 |
-
if batch_idx == 0:
|
| 71 |
-
tensorboard_add_sample(
|
| 72 |
-
writer, tag='Test',
|
| 73 |
-
sample=(mixed[:8], label[:8], gt[:8], output[:8]),
|
| 74 |
-
step=epoch, params=data_params)
|
| 75 |
-
tensorboard_add_metrics(
|
| 76 |
-
writer, tag='Test', metrics=metrics_batch, label=label,
|
| 77 |
-
step=epoch)
|
| 78 |
-
|
| 79 |
-
if n_items is not None and batch_idx == (n_items - 1):
|
| 80 |
-
break
|
| 81 |
-
|
| 82 |
-
avg_metrics = {k: np.mean(metrics[k]) for k in metrics.keys()}
|
| 83 |
-
avg_metrics_str = "Test:"
|
| 84 |
-
for m in avg_metrics.keys():
|
| 85 |
-
avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
|
| 86 |
-
logging.info(avg_metrics_str)
|
| 87 |
-
|
| 88 |
-
return avg_metrics
|
| 89 |
-
|
| 90 |
-
def evaluate(network, args: argparse.Namespace):
|
| 91 |
-
"""
|
| 92 |
-
Evaluate the model on a given dataset.
|
| 93 |
-
"""
|
| 94 |
-
|
| 95 |
-
# Load dataset
|
| 96 |
-
data_test = FSDSoundScapesDataset(**args.test_data)
|
| 97 |
-
logging.info("Loaded test dataset at %s containing %d elements" %
|
| 98 |
-
(args.test_data['input_dir'], len(data_test)))
|
| 99 |
-
|
| 100 |
-
# Set up the device and workers.
|
| 101 |
-
use_cuda = args.use_cuda and torch.cuda.is_available()
|
| 102 |
-
if use_cuda:
|
| 103 |
-
gpu_ids = args.gpu_ids if args.gpu_ids is not None\
|
| 104 |
-
else range(torch.cuda.device_count())
|
| 105 |
-
device_ids = [_ for _ in gpu_ids]
|
| 106 |
-
data_parallel = len(device_ids) > 1
|
| 107 |
-
device = 'cuda:%d' % device_ids[0]
|
| 108 |
-
torch.cuda.set_device(device_ids[0])
|
| 109 |
-
logging.info("Using CUDA devices: %s" % str(device_ids))
|
| 110 |
-
else:
|
| 111 |
-
data_parallel = False
|
| 112 |
-
device = torch.device('cpu')
|
| 113 |
-
logging.info("Using device: CPU")
|
| 114 |
-
|
| 115 |
-
# Set multiprocessing params
|
| 116 |
-
num_workers = min(multiprocessing.cpu_count(), args.n_workers)
|
| 117 |
-
kwargs = {
|
| 118 |
-
'num_workers': num_workers,
|
| 119 |
-
'pin_memory': True
|
| 120 |
-
} if use_cuda else {}
|
| 121 |
-
|
| 122 |
-
# Set up data loader
|
| 123 |
-
test_loader = torch.utils.data.DataLoader(data_test,
|
| 124 |
-
batch_size=args.eval_batch_size,
|
| 125 |
-
**kwargs)
|
| 126 |
-
|
| 127 |
-
# Set up model
|
| 128 |
-
model = network.Net(**args.model_params)
|
| 129 |
-
if use_cuda and data_parallel:
|
| 130 |
-
model = nn.DataParallel(model, device_ids=device_ids)
|
| 131 |
-
logging.info("Using data parallel model")
|
| 132 |
-
model.to(device)
|
| 133 |
-
|
| 134 |
-
# Load weights
|
| 135 |
-
if args.pretrain_path == "best":
|
| 136 |
-
ckpts = glob.glob(os.path.join(args.exp_dir, '*.pt'))
|
| 137 |
-
ckpts.sort(
|
| 138 |
-
key=lambda _: int(os.path.splitext(os.path.basename(_))[0]))
|
| 139 |
-
val_metrics = torch.load(ckpts[-1])['val_metrics'][args.base_metric]
|
| 140 |
-
best_epoch = max(range(len(val_metrics)), key=val_metrics.__getitem__)
|
| 141 |
-
args.pretrain_path = os.path.join(args.exp_dir, '%d.pt' % best_epoch)
|
| 142 |
-
logging.info(
|
| 143 |
-
"Found 'best' validation %s=%.02f at %s" %
|
| 144 |
-
(args.base_metric, val_metrics[best_epoch], args.pretrain_path))
|
| 145 |
-
if args.pretrain_path != "":
|
| 146 |
-
utils.load_checkpoint(
|
| 147 |
-
args.pretrain_path, model, data_parallel=data_parallel)
|
| 148 |
-
logging.info("Loaded pretrain weights from %s" % args.pretrain_path)
|
| 149 |
-
|
| 150 |
-
# Evaluate
|
| 151 |
-
try:
|
| 152 |
-
return test_epoch(
|
| 153 |
-
model, device, test_loader, args.n_items, network.loss,
|
| 154 |
-
network.metrics, args.profiling)
|
| 155 |
-
except KeyboardInterrupt:
|
| 156 |
-
print("Interrupted")
|
| 157 |
-
except Exception as _: # pylint: disable=broad-except
|
| 158 |
-
import traceback # pylint: disable=import-outside-toplevel
|
| 159 |
-
traceback.print_exc()
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
if __name__ == '__main__':
|
| 163 |
-
parser = argparse.ArgumentParser()
|
| 164 |
-
# Data Params
|
| 165 |
-
parser.add_argument('experiments', nargs='+', type=str,
|
| 166 |
-
default=None,
|
| 167 |
-
help="List of experiments to evaluate. "
|
| 168 |
-
"Provide only one experiment when providing "
|
| 169 |
-
"pretrained path. If pretrianed path is not "
|
| 170 |
-
"provided, epoch with best validation metric "
|
| 171 |
-
"is used for evaluation.")
|
| 172 |
-
parser.add_argument('--results', type=str, default="",
|
| 173 |
-
help="Path to the CSV file to store results.")
|
| 174 |
-
|
| 175 |
-
# System params
|
| 176 |
-
parser.add_argument('--n_items', type=int, default=None,
|
| 177 |
-
help="Number of items to test.")
|
| 178 |
-
parser.add_argument('--pretrain_path', type=str, default="best",
|
| 179 |
-
help="Path to pretrained weights")
|
| 180 |
-
parser.add_argument('--profiling', dest='profiling', action='store_true',
|
| 181 |
-
help="Enable or disable profiling.")
|
| 182 |
-
parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
|
| 183 |
-
help="Whether to use cuda")
|
| 184 |
-
parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
|
| 185 |
-
help="List of GPU ids used for training. "
|
| 186 |
-
"Eg., --gpu_ids 2 4. All GPUs are used by default.")
|
| 187 |
-
args = parser.parse_args()
|
| 188 |
-
|
| 189 |
-
results = []
|
| 190 |
-
|
| 191 |
-
for exp_dir in args.experiments:
|
| 192 |
-
eval_args = argparse.Namespace(**vars(args))
|
| 193 |
-
eval_args.exp_dir = exp_dir
|
| 194 |
-
|
| 195 |
-
utils.set_logger(os.path.join(exp_dir, 'eval.log'))
|
| 196 |
-
logging.info("Evaluating %s ..." % exp_dir)
|
| 197 |
-
|
| 198 |
-
# Load model and training params
|
| 199 |
-
params = utils.Params(os.path.join(exp_dir, 'config.json'))
|
| 200 |
-
for k, v in params.__dict__.items():
|
| 201 |
-
vars(eval_args)[k] = v
|
| 202 |
-
|
| 203 |
-
network = importlib.import_module(eval_args.model)
|
| 204 |
-
logging.info("Imported the model from '%s'." % eval_args.model)
|
| 205 |
-
|
| 206 |
-
curr_res = evaluate(network, eval_args)
|
| 207 |
-
curr_res['experiment'] = os.path.basename(exp_dir)
|
| 208 |
-
results.append(curr_res)
|
| 209 |
-
|
| 210 |
-
del eval_args
|
| 211 |
-
|
| 212 |
-
if args.results != "":
|
| 213 |
-
print("Writing results to %s" % args.results)
|
| 214 |
-
pd.DataFrame(results).to_csv(args.results, index=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/training/synthetic_dataset.py
DELETED
|
@@ -1,168 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Torch dataset object for synthetically rendered spatial data.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import os
|
| 6 |
-
import json
|
| 7 |
-
import random
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
import logging
|
| 10 |
-
|
| 11 |
-
import numpy as np
|
| 12 |
-
import pandas as pd
|
| 13 |
-
import matplotlib.pyplot as plt
|
| 14 |
-
import scaper
|
| 15 |
-
import torch
|
| 16 |
-
import torchaudio
|
| 17 |
-
import torchaudio.transforms as AT
|
| 18 |
-
from random import randrange
|
| 19 |
-
|
| 20 |
-
class FSDSoundScapesDataset(torch.utils.data.Dataset): # type: ignore
|
| 21 |
-
"""
|
| 22 |
-
Base class for FSD Sound Scapes dataset
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
_labels = [
|
| 26 |
-
"Acoustic_guitar", "Applause", "Bark", "Bass_drum",
|
| 27 |
-
"Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
|
| 28 |
-
"Computer_keyboard", "Cough", "Cowbell", "Double_bass",
|
| 29 |
-
"Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
|
| 30 |
-
"Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
|
| 31 |
-
"Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
|
| 32 |
-
"Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
|
| 33 |
-
"Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
|
| 34 |
-
"Trumpet", "Violin_or_fiddle", "Writing"]
|
| 35 |
-
|
| 36 |
-
def __init__(self, input_dir, dset='', sr=None,
|
| 37 |
-
resample_rate=None, max_num_targets=1):
|
| 38 |
-
assert dset in ['train', 'val', 'test'], \
|
| 39 |
-
"`dset` must be one of ['train', 'val', 'test']"
|
| 40 |
-
self.dset = dset
|
| 41 |
-
self.max_num_targets = max_num_targets
|
| 42 |
-
self.fg_dir = os.path.join(input_dir, 'FSDKaggle2018/%s' % dset)
|
| 43 |
-
if dset in ['train', 'val']:
|
| 44 |
-
self.bg_dir = os.path.join(
|
| 45 |
-
input_dir,
|
| 46 |
-
'TAU-acoustic-sounds/'
|
| 47 |
-
'TAU-urban-acoustic-scenes-2019-development')
|
| 48 |
-
else:
|
| 49 |
-
self.bg_dir = os.path.join(
|
| 50 |
-
input_dir,
|
| 51 |
-
'TAU-acoustic-sounds/'
|
| 52 |
-
'TAU-urban-acoustic-scenes-2019-evaluation')
|
| 53 |
-
logging.info("Loading %s dataset: fg_dir=%s bg_dir=%s" %
|
| 54 |
-
(dset, self.fg_dir, self.bg_dir))
|
| 55 |
-
|
| 56 |
-
self.samples = sorted(list(
|
| 57 |
-
Path(os.path.join(input_dir, 'jams', dset)).glob('[0-9]*')))
|
| 58 |
-
|
| 59 |
-
jamsfile = os.path.join(self.samples[0], 'mixture.jams')
|
| 60 |
-
_, jams, _, _ = scaper.generate_from_jams(
|
| 61 |
-
jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
|
| 62 |
-
_sr = jams['annotations'][0]['sandbox']['scaper']['sr']
|
| 63 |
-
assert _sr == sr, "Sampling rate provided does not match the data"
|
| 64 |
-
|
| 65 |
-
if resample_rate is not None:
|
| 66 |
-
self.resampler = AT.Resample(sr, resample_rate)
|
| 67 |
-
self.sr = resample_rate
|
| 68 |
-
else:
|
| 69 |
-
self.resampler = lambda a: a
|
| 70 |
-
self.sr = sr
|
| 71 |
-
|
| 72 |
-
def _get_label_vector(self, labels):
|
| 73 |
-
"""
|
| 74 |
-
Generates a multi-hot vector corresponding to `labels`.
|
| 75 |
-
"""
|
| 76 |
-
vector = torch.zeros(len(FSDSoundScapesDataset._labels))
|
| 77 |
-
|
| 78 |
-
for label in labels:
|
| 79 |
-
idx = FSDSoundScapesDataset._labels.index(label)
|
| 80 |
-
assert vector[idx] == 0, "Repeated labels"
|
| 81 |
-
vector[idx] = 1
|
| 82 |
-
|
| 83 |
-
return vector
|
| 84 |
-
|
| 85 |
-
def __len__(self):
|
| 86 |
-
return len(self.samples)
|
| 87 |
-
|
| 88 |
-
def __getitem__(self, idx):
|
| 89 |
-
sample_path = self.samples[idx]
|
| 90 |
-
jamsfile = os.path.join(sample_path, 'mixture.jams')
|
| 91 |
-
|
| 92 |
-
mixture, jams, ann_list, event_audio_list = scaper.generate_from_jams(
|
| 93 |
-
jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
|
| 94 |
-
isolated_events = {}
|
| 95 |
-
for e, a in zip(ann_list, event_audio_list[1:]):
|
| 96 |
-
# 0th event is background
|
| 97 |
-
isolated_events[e[2]] = a
|
| 98 |
-
gt_events = list(pd.read_csv(
|
| 99 |
-
os.path.join(sample_path, 'gt_events.csv'), sep='\t')['label'])
|
| 100 |
-
|
| 101 |
-
mixture = torch.from_numpy(mixture).permute(1, 0)
|
| 102 |
-
mixture = self.resampler(mixture.to(torch.float))
|
| 103 |
-
|
| 104 |
-
if self.dset == 'train':
|
| 105 |
-
labels = random.sample(gt_events, randrange(1,self.max_num_targets+1))
|
| 106 |
-
elif self.dset == 'val':
|
| 107 |
-
labels = gt_events[:idx%self.max_num_targets+1]
|
| 108 |
-
elif self.dset == 'test':
|
| 109 |
-
labels = gt_events[:self.max_num_targets]
|
| 110 |
-
label_vector = self._get_label_vector(labels)
|
| 111 |
-
|
| 112 |
-
gt = torch.zeros_like(
|
| 113 |
-
torch.from_numpy(event_audio_list[1]).permute(1, 0))
|
| 114 |
-
for l in labels:
|
| 115 |
-
gt = gt + torch.from_numpy(isolated_events[l]).permute(1, 0)
|
| 116 |
-
gt = self.resampler(gt.to(torch.float))
|
| 117 |
-
|
| 118 |
-
return mixture, label_vector, gt #, jams
|
| 119 |
-
|
| 120 |
-
def tensorboard_add_sample(writer, tag, sample, step, params):
|
| 121 |
-
"""
|
| 122 |
-
Adds a sample of FSDSynthDataset to tensorboard.
|
| 123 |
-
"""
|
| 124 |
-
if params['resample_rate'] is not None:
|
| 125 |
-
sr = params['resample_rate']
|
| 126 |
-
else:
|
| 127 |
-
sr = params['sr']
|
| 128 |
-
resample_rate = 16000 if sr > 16000 else sr
|
| 129 |
-
|
| 130 |
-
m, l, gt, o = sample
|
| 131 |
-
m, gt, o = (
|
| 132 |
-
torchaudio.functional.resample(_, sr, resample_rate).cpu()
|
| 133 |
-
for _ in (m, gt, o))
|
| 134 |
-
|
| 135 |
-
def _add_audio(a, audio_tag, axis, plt_title):
|
| 136 |
-
for i, ch in enumerate(a):
|
| 137 |
-
axis.plot(ch, label='mic %d' % i)
|
| 138 |
-
writer.add_audio(
|
| 139 |
-
'%s/mic %d' % (audio_tag, i), ch.unsqueeze(0), step, resample_rate)
|
| 140 |
-
axis.set_title(plt_title)
|
| 141 |
-
axis.legend()
|
| 142 |
-
|
| 143 |
-
for b in range(m.shape[0]):
|
| 144 |
-
label = []
|
| 145 |
-
for i in range(len(l[b, :])):
|
| 146 |
-
if l[b, i] == 1:
|
| 147 |
-
label.append(FSDSoundScapesDataset._labels[i])
|
| 148 |
-
|
| 149 |
-
# Add waveforms
|
| 150 |
-
rows = 3 # input, output, gt
|
| 151 |
-
fig = plt.figure(figsize=(10, 2 * rows))
|
| 152 |
-
axes = fig.subplots(rows, 1, sharex=True)
|
| 153 |
-
_add_audio(m[b], '%s/sample_%d/0_input' % (tag, b), axes[0], "Mixed")
|
| 154 |
-
_add_audio(o[b], '%s/sample_%d/1_output' % (tag, b), axes[1], "Output (%s)" % label)
|
| 155 |
-
_add_audio(gt[b], '%s/sample_%d/2_gt' % (tag, b), axes[2], "GT (%s)" % label)
|
| 156 |
-
writer.add_figure('%s/sample_%d/waveform' % (tag, b), fig, step)
|
| 157 |
-
|
| 158 |
-
def tensorboard_add_metrics(writer, tag, metrics, label, step):
|
| 159 |
-
"""
|
| 160 |
-
Add metrics to tensorboard.
|
| 161 |
-
"""
|
| 162 |
-
vals = np.asarray(metrics['scale_invariant_signal_noise_ratio'])
|
| 163 |
-
|
| 164 |
-
writer.add_histogram('%s/%s' % (tag, 'SI-SNRi'), vals, step)
|
| 165 |
-
|
| 166 |
-
label_names = [FSDSoundScapesDataset._labels[torch.argmax(_)] for _ in label]
|
| 167 |
-
for l, v in zip(label_names, vals):
|
| 168 |
-
writer.add_histogram('%s/%s' % (tag, l), v, step)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/training/train.py
DELETED
|
@@ -1,311 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
The main training script for training on synthetic data
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import argparse
|
| 6 |
-
import multiprocessing
|
| 7 |
-
import os
|
| 8 |
-
import logging
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
import random
|
| 11 |
-
|
| 12 |
-
import numpy as np
|
| 13 |
-
import torch
|
| 14 |
-
import torch.nn as nn
|
| 15 |
-
import torch.nn.functional as F
|
| 16 |
-
import torch.optim as optim
|
| 17 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 18 |
-
from tqdm import tqdm # pylint: disable=unused-import
|
| 19 |
-
from torchmetrics.functional import(
|
| 20 |
-
scale_invariant_signal_noise_ratio as si_snr,
|
| 21 |
-
signal_noise_ratio as snr,
|
| 22 |
-
signal_distortion_ratio as sdr,
|
| 23 |
-
scale_invariant_signal_distortion_ratio as si_sdr)
|
| 24 |
-
|
| 25 |
-
from src.helpers import utils
|
| 26 |
-
from src.training.eval import test_epoch
|
| 27 |
-
from src.training.synthetic_dataset import FSDSoundScapesDataset as Dataset
|
| 28 |
-
from src.training.synthetic_dataset import tensorboard_add_sample
|
| 29 |
-
|
| 30 |
-
def train_epoch(model: nn.Module, device: torch.device,
|
| 31 |
-
optimizer: optim.Optimizer,
|
| 32 |
-
train_loader: torch.utils.data.dataloader.DataLoader,
|
| 33 |
-
n_items: int, epoch: int = 0,
|
| 34 |
-
writer: SummaryWriter = None, data_params = None) -> float:
|
| 35 |
-
|
| 36 |
-
"""
|
| 37 |
-
Train a single epoch.
|
| 38 |
-
"""
|
| 39 |
-
# Set the model to training.
|
| 40 |
-
model.train()
|
| 41 |
-
|
| 42 |
-
# Training loop
|
| 43 |
-
losses = []
|
| 44 |
-
metrics = {}
|
| 45 |
-
|
| 46 |
-
with tqdm(total=len(train_loader), desc='Train', ncols=100) as t:
|
| 47 |
-
for batch_idx, (mixed, label, gt) in enumerate(train_loader):
|
| 48 |
-
mixed = mixed.to(device)
|
| 49 |
-
label = label.to(device)
|
| 50 |
-
gt = gt.to(device)
|
| 51 |
-
|
| 52 |
-
# Reset grad
|
| 53 |
-
optimizer.zero_grad()
|
| 54 |
-
|
| 55 |
-
# Run through the model
|
| 56 |
-
output = model(mixed, label)
|
| 57 |
-
|
| 58 |
-
# Compute loss
|
| 59 |
-
loss = network.loss(output, gt)
|
| 60 |
-
|
| 61 |
-
losses.append(loss.item())
|
| 62 |
-
|
| 63 |
-
# Backpropagation
|
| 64 |
-
loss.backward()
|
| 65 |
-
|
| 66 |
-
# Gradient clipping
|
| 67 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
|
| 68 |
-
|
| 69 |
-
# Update the weights
|
| 70 |
-
optimizer.step()
|
| 71 |
-
|
| 72 |
-
metrics_batch = network.metrics(mixed.detach(), output.detach(),
|
| 73 |
-
gt.detach())
|
| 74 |
-
for k in metrics_batch.keys():
|
| 75 |
-
if not k in metrics:
|
| 76 |
-
metrics[k] = metrics_batch[k]
|
| 77 |
-
else:
|
| 78 |
-
metrics[k] += metrics_batch[k]
|
| 79 |
-
|
| 80 |
-
if writer is not None and batch_idx == 0:
|
| 81 |
-
tensorboard_add_sample(
|
| 82 |
-
writer, tag='Train',
|
| 83 |
-
sample=(mixed.detach()[:8], label.detach()[:8],
|
| 84 |
-
gt.detach()[:8], output.detach()[:8]),
|
| 85 |
-
step=epoch, params=data_params)
|
| 86 |
-
|
| 87 |
-
# Show current loss in the progress meter
|
| 88 |
-
t.set_postfix(loss='%.05f'%loss.item())
|
| 89 |
-
t.update()
|
| 90 |
-
|
| 91 |
-
if n_items is not None and batch_idx == n_items:
|
| 92 |
-
break
|
| 93 |
-
|
| 94 |
-
avg_metrics = {k: np.mean(metrics[k]) for k in metrics.keys()}
|
| 95 |
-
avg_metrics['loss'] = np.mean(losses)
|
| 96 |
-
avg_metrics_str = "Train:"
|
| 97 |
-
for m in avg_metrics.keys():
|
| 98 |
-
avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
|
| 99 |
-
logging.info(avg_metrics_str)
|
| 100 |
-
|
| 101 |
-
return avg_metrics
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
def train(args: argparse.Namespace):
|
| 105 |
-
"""
|
| 106 |
-
Train the network.
|
| 107 |
-
"""
|
| 108 |
-
|
| 109 |
-
# Load dataset
|
| 110 |
-
data_train = Dataset(**args.train_data)
|
| 111 |
-
logging.info("Loaded train dataset at %s containing %d elements" %
|
| 112 |
-
(args.train_data['input_dir'], len(data_train)))
|
| 113 |
-
data_val = Dataset(**args.val_data)
|
| 114 |
-
logging.info("Loaded test dataset at %s containing %d elements" %
|
| 115 |
-
(args.val_data['input_dir'], len(data_val)))
|
| 116 |
-
|
| 117 |
-
# Set up the device and workers.
|
| 118 |
-
use_cuda = args.use_cuda and torch.cuda.is_available()
|
| 119 |
-
if use_cuda:
|
| 120 |
-
gpu_ids = args.gpu_ids if args.gpu_ids is not None\
|
| 121 |
-
else range(torch.cuda.device_count())
|
| 122 |
-
device_ids = [_ for _ in gpu_ids]
|
| 123 |
-
data_parallel = len(device_ids) > 1
|
| 124 |
-
device = 'cuda:%d' % device_ids[0]
|
| 125 |
-
torch.cuda.set_device(device_ids[0])
|
| 126 |
-
logging.info("Using CUDA devices: %s" % str(device_ids))
|
| 127 |
-
else:
|
| 128 |
-
data_parallel = False
|
| 129 |
-
device = torch.device('cpu')
|
| 130 |
-
logging.info("Using device: CPU")
|
| 131 |
-
|
| 132 |
-
# Set multiprocessing params
|
| 133 |
-
num_workers = min(multiprocessing.cpu_count(), args.n_workers)
|
| 134 |
-
kwargs = {
|
| 135 |
-
'num_workers': num_workers,
|
| 136 |
-
'pin_memory': True
|
| 137 |
-
} if use_cuda else {}
|
| 138 |
-
|
| 139 |
-
# Set up data loaders
|
| 140 |
-
#print(args.batch_size, args.eval_batch_size)
|
| 141 |
-
train_loader = torch.utils.data.DataLoader(data_train,
|
| 142 |
-
batch_size=args.batch_size,
|
| 143 |
-
shuffle=True, **kwargs)
|
| 144 |
-
val_loader = torch.utils.data.DataLoader(data_val,
|
| 145 |
-
batch_size=args.eval_batch_size,
|
| 146 |
-
**kwargs)
|
| 147 |
-
|
| 148 |
-
# Set up model
|
| 149 |
-
model = network.Net(**args.model_params)
|
| 150 |
-
|
| 151 |
-
# Add graph to tensorboard with example train samples
|
| 152 |
-
# _mixed, _label, _ = next(iter(val_loader))
|
| 153 |
-
# args.writer.add_graph(model, (_mixed, _label))
|
| 154 |
-
|
| 155 |
-
if use_cuda and data_parallel:
|
| 156 |
-
model = nn.DataParallel(model, device_ids=device_ids)
|
| 157 |
-
logging.info("Using data parallel model")
|
| 158 |
-
model.to(device)
|
| 159 |
-
|
| 160 |
-
# Set up the optimizer
|
| 161 |
-
logging.info("Initializing optimizer with %s" % str(args.optim))
|
| 162 |
-
optimizer = network.optimizer(model, **args.optim, data_parallel=data_parallel)
|
| 163 |
-
logging.info('Learning rates initialized to:' + utils.format_lr_info(optimizer))
|
| 164 |
-
|
| 165 |
-
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 166 |
-
optimizer, **args.lr_sched)
|
| 167 |
-
logging.info("Initialized LR scheduler with params: fix_lr_epochs=%d %s"
|
| 168 |
-
% (args.fix_lr_epochs, str(args.lr_sched)))
|
| 169 |
-
|
| 170 |
-
base_metric = args.base_metric
|
| 171 |
-
train_metrics = {}
|
| 172 |
-
val_metrics = {}
|
| 173 |
-
|
| 174 |
-
# Load the model if `args.start_epoch` is greater than 0. This will load the
|
| 175 |
-
# model from epoch = `args.start_epoch - 1`
|
| 176 |
-
assert args.start_epoch >=0, "start_epoch must be greater than 0."
|
| 177 |
-
if args.start_epoch > 0:
|
| 178 |
-
checkpoint_path = os.path.join(args.exp_dir,
|
| 179 |
-
'%d.pt' % (args.start_epoch - 1))
|
| 180 |
-
_, train_metrics, val_metrics = utils.load_checkpoint(
|
| 181 |
-
checkpoint_path, model, optim=optimizer, lr_sched=lr_scheduler,
|
| 182 |
-
data_parallel=data_parallel)
|
| 183 |
-
logging.info("Loaded checkpoint from %s" % checkpoint_path)
|
| 184 |
-
logging.info("Learning rates restored to:" + utils.format_lr_info(optimizer))
|
| 185 |
-
|
| 186 |
-
# Training loop
|
| 187 |
-
try:
|
| 188 |
-
torch.autograd.set_detect_anomaly(args.detect_anomaly)
|
| 189 |
-
for epoch in range(args.start_epoch, args.epochs + 1):
|
| 190 |
-
logging.info("Epoch %d:" % epoch)
|
| 191 |
-
checkpoint_file = os.path.join(args.exp_dir, '%d.pt' % epoch)
|
| 192 |
-
assert not os.path.exists(checkpoint_file), \
|
| 193 |
-
"Checkpoint file %s already exists" % checkpoint_file
|
| 194 |
-
#print("---- begin trianivg")
|
| 195 |
-
curr_train_metrics = train_epoch(model, device, optimizer,
|
| 196 |
-
train_loader, args.n_train_items,
|
| 197 |
-
epoch=epoch, writer=args.writer,
|
| 198 |
-
data_params=args.train_data)
|
| 199 |
-
#raise KeyboardInterrupt
|
| 200 |
-
curr_test_metrics = test_epoch(model, device, val_loader,
|
| 201 |
-
args.n_test_items, network.loss,
|
| 202 |
-
network.metrics, epoch=epoch,
|
| 203 |
-
writer=args.writer,
|
| 204 |
-
data_params=args.val_data)
|
| 205 |
-
# LR scheduler
|
| 206 |
-
if epoch >= args.fix_lr_epochs:
|
| 207 |
-
lr_scheduler.step(curr_test_metrics[base_metric])
|
| 208 |
-
logging.info(
|
| 209 |
-
"LR after scheduling step: %s" %
|
| 210 |
-
[_['lr'] for _ in optimizer.param_groups])
|
| 211 |
-
|
| 212 |
-
# Write metrics to tensorboard
|
| 213 |
-
args.writer.add_scalars('Train', curr_train_metrics, epoch)
|
| 214 |
-
args.writer.add_scalars('Val', curr_test_metrics, epoch)
|
| 215 |
-
args.writer.flush()
|
| 216 |
-
|
| 217 |
-
for k in curr_train_metrics.keys():
|
| 218 |
-
if not k in train_metrics:
|
| 219 |
-
train_metrics[k] = [curr_train_metrics[k]]
|
| 220 |
-
else:
|
| 221 |
-
train_metrics[k].append(curr_train_metrics[k])
|
| 222 |
-
|
| 223 |
-
for k in curr_test_metrics.keys():
|
| 224 |
-
if not k in val_metrics:
|
| 225 |
-
val_metrics[k] = [curr_test_metrics[k]]
|
| 226 |
-
else:
|
| 227 |
-
val_metrics[k].append(curr_test_metrics[k])
|
| 228 |
-
|
| 229 |
-
if max(val_metrics[base_metric]) == val_metrics[base_metric][-1]:
|
| 230 |
-
logging.info("Found best validation %s!" % base_metric)
|
| 231 |
-
|
| 232 |
-
utils.save_checkpoint(
|
| 233 |
-
checkpoint_file, epoch, model, optimizer, lr_scheduler,
|
| 234 |
-
train_metrics, val_metrics, data_parallel)
|
| 235 |
-
logging.info("Saved checkpoint at %s" % checkpoint_file)
|
| 236 |
-
|
| 237 |
-
utils.save_graph(train_metrics, val_metrics, args.exp_dir)
|
| 238 |
-
|
| 239 |
-
return train_metrics, val_metrics
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
except KeyboardInterrupt:
|
| 243 |
-
print("Interrupted")
|
| 244 |
-
except Exception as _: # pylint: disable=broad-except
|
| 245 |
-
import traceback # pylint: disable=import-outside-toplevel
|
| 246 |
-
traceback.print_exc()
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
if __name__ == '__main__':
|
| 250 |
-
parser = argparse.ArgumentParser()
|
| 251 |
-
# Data Params
|
| 252 |
-
parser.add_argument('exp_dir', type=str,
|
| 253 |
-
default='./experiments/fsd_mask_label_mult',
|
| 254 |
-
help="Path to save checkpoints and logs.")
|
| 255 |
-
|
| 256 |
-
parser.add_argument('--n_train_items', type=int, default=None,
|
| 257 |
-
help="Number of items to train on in each epoch")
|
| 258 |
-
parser.add_argument('--n_test_items', type=int, default=None,
|
| 259 |
-
help="Number of items to test.")
|
| 260 |
-
parser.add_argument('--start_epoch', type=int, default=0,
|
| 261 |
-
help="Start epoch")
|
| 262 |
-
parser.add_argument('--pretrain_path', type=str,
|
| 263 |
-
help="Path to pretrained weights")
|
| 264 |
-
parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
|
| 265 |
-
help="Whether to use cuda")
|
| 266 |
-
parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
|
| 267 |
-
help="List of GPU ids used for training. "
|
| 268 |
-
"Eg., --gpu_ids 2 4. All GPUs are used by default.")
|
| 269 |
-
parser.add_argument('--detect_anomaly', dest='detect_anomaly',
|
| 270 |
-
action='store_true',
|
| 271 |
-
help="Whether to use cuda")
|
| 272 |
-
parser.add_argument('--wandb', dest='wandb', action='store_true',
|
| 273 |
-
help="Whether to sync tensorboard to wandb")
|
| 274 |
-
|
| 275 |
-
args = parser.parse_args()
|
| 276 |
-
|
| 277 |
-
# Set the random seed for reproducible experiments
|
| 278 |
-
torch.manual_seed(230)
|
| 279 |
-
random.seed(230)
|
| 280 |
-
np.random.seed(230)
|
| 281 |
-
if args.use_cuda:
|
| 282 |
-
torch.cuda.manual_seed(230)
|
| 283 |
-
|
| 284 |
-
# Set up checkpoints
|
| 285 |
-
if not os.path.exists(args.exp_dir):
|
| 286 |
-
os.makedirs(args.exp_dir)
|
| 287 |
-
|
| 288 |
-
utils.set_logger(os.path.join(args.exp_dir, 'train.log'))
|
| 289 |
-
|
| 290 |
-
# Load model and training params
|
| 291 |
-
params = utils.Params(os.path.join(args.exp_dir, 'config.json'))
|
| 292 |
-
for k, v in params.__dict__.items():
|
| 293 |
-
vars(args)[k] = v
|
| 294 |
-
|
| 295 |
-
# Initialize tensorboard writer
|
| 296 |
-
tensorboard_dir = os.path.join(args.exp_dir, 'tensorboard')
|
| 297 |
-
args.writer = SummaryWriter(tensorboard_dir, purge_step=args.start_epoch)
|
| 298 |
-
if args.wandb:
|
| 299 |
-
import wandb
|
| 300 |
-
wandb.init(
|
| 301 |
-
project='Semaudio', sync_tensorboard=True,
|
| 302 |
-
dir=tensorboard_dir, name=os.path.basename(args.exp_dir))
|
| 303 |
-
|
| 304 |
-
exec("import %s as network" % args.model)
|
| 305 |
-
logging.info("Imported the model from '%s'." % args.model)
|
| 306 |
-
|
| 307 |
-
train(args)
|
| 308 |
-
|
| 309 |
-
args.writer.close()
|
| 310 |
-
if args.wandb:
|
| 311 |
-
wandb.finish()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|