import numpy as np
import torch
import torch.nn as nn
import wandb

import src.utils as utils
from src.metrics.metrics import Metrics


class FakeModel(nn.Module):
    """Thin wrapper that stores the network as ``self.model``, so its
    state_dict keys carry the ``model.`` prefix used by the checkpoints
    loaded below."""

    def __init__(self, model):
        super().__init__()
        self.model = model
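
# A minimal sketch of how FakeModel is used in PLModule below (names
# hypothetical): wrapping the bare network makes checkpoint keys of the
# form "model.<param>" resolve against it, then it is unwrapped again.
#
#   net = SomeNet()                                   # hypothetical module
#   wrapped = FakeModel(net)
#   wrapped.load_state_dict(torch.load("x.ckpt")["state_dict"])
#   net = wrapped.model                               # weights now loaded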


class PLModule(object):
    def __init__(
        self,
        model,
        model_params,
        sr,
        optimizer,
        optimizer_params,
        scheduler=None,
        scheduler_params=None,
        loss=None,
        loss_params=None,
        metrics=(),  # tuple default avoids the shared-mutable-default pitfall
        slow_model_ckpt=None,
        prev_ckpt=None,
        grad_clip=None,
        use_dp=True,
        val_log_interval=10,
        samples_per_speaker_number=3,
        freeze_model1=False,
    ):
        self.model = utils.import_attr(model)(**model_params)

        self.use_dp = use_dp
        if use_dp:
            self.model = nn.DataParallel(self.model)

        self.sr = sr
        self.samples_per_speaker_number = samples_per_speaker_number

        self.metrics = [Metrics(metric) for metric in metrics]
        self.metric_values = {}
        self.statistics = {}

        # Metric used to decide which checkpoint is "best".
        self.monitor = "val/loss"
        self.monitor_mode = "min"

        self.mode = None

        self.val_samples = {}
        self.train_samples = {}

        self.input_snr_calculated = False
        self.input_snr = []
        self.snr_metric = Metrics("snr")

        self.loss_fn = utils.import_attr(loss)(**loss_params)
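
        # Everything above is resolved from dotted paths, so the module is
        # fully config-driven; e.g. (hypothetical values):
        #
        #   PLModule(
        #       model="src.models.net.Net", model_params={...},
        #       sr=16000,
        #       optimizer="torch.optim.Adam", optimizer_params={"lr": 1e-3},
        #       loss="src.losses.snr.SNRLoss", loss_params={},
        #       metrics=["snr_i", "si_snr_i"],
        #   )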

        if prev_ckpt is not None:
            if prev_ckpt.endswith(".ckpt"):
                # Lightning-style checkpoint: weights live under "state_dict"
                # with a "model." prefix on every key.
                print("load prev model", prev_ckpt)
                ckpt = torch.load(prev_ckpt)
                print(ckpt["current_epoch"])
                state = ckpt["state_dict"]

                _model = self.model.module if self.use_dp else self.model

                # FakeModel reproduces the "model." prefix, so the checkpoint
                # loads without renaming any keys.
                mdl = FakeModel(_model)
                mdl.load_state_dict(state)
                self.model = nn.DataParallel(mdl.model) if self.use_dp else mdl.model
            else:
                print("load prev model", prev_ckpt)

                state = torch.load(prev_ckpt)
                print(state["current_epoch"])
                state = state["model"]
                if self.use_dp:
                    self.model.module.load_state_dict(state)
                else:
                    self.model.load_state_dict(state)

        elif slow_model_ckpt is not None:
            print(f"Loading model 1 weights from checkpoint: {slow_model_ckpt}")
            model1_ckpt = torch.load(slow_model_ckpt)
            print("current epoch is {}".format(model1_ckpt["current_epoch"]))

            # Keep only the "tce_model." entries and strip the prefix, so the
            # keys match model1's own state_dict.
            model1_state_dict = {
                key.replace("tce_model.", ""): value
                for key, value in model1_ckpt["model"].items()
                if key.startswith("tce_model.")
            }

            if self.use_dp:
                self.model.module.model1.load_state_dict(model1_state_dict, strict=False)
            else:
                self.model.model1.load_state_dict(model1_state_dict, strict=False)

        else:
            print("Loading model from scratch, no slow model init ckpt or joint model init ckpt")

        self.freeze = freeze_model1
        if freeze_model1:
            self.freeze_model1()

        # Passing only parameters with requires_grad=True keeps frozen model1
        # weights out of the optimizer; when nothing is frozen, this is simply
        # all parameters.
        params_to_optimize = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optimizer = utils.import_attr(optimizer)(params_to_optimize, **optimizer_params)
        self.optim_name = optimizer
        self.opt_params = optimizer_params

        self.grad_clip = grad_clip
        if self.grad_clip is not None:
            print(f"USING GRAD CLIP: {self.grad_clip}")
        else:
            print("WARNING: NOT USING GRAD CLIP")

        self.scheduler = self.init_scheduler(scheduler, scheduler_params)
        self.scheduler_name = scheduler
        self.scheduler_params = scheduler_params

        self.epoch = 0

    def freeze_model1(self):
        """Freezes the weights of model1."""
        print("Freezing model1 weights")
        model1 = self.model.module.model1 if self.use_dp else self.model.model1
        for param in model1.parameters():
            param.requires_grad = False
        print("Model1 weights frozen.")

    def load_state(self, path, map_location=None):
        state = torch.load(path, map_location=map_location)

        if self.use_dp:
            self.model.module.load_state_dict(state["model"])
        else:
            self.model.load_state_dict(state["model"])

        # Rebuild the optimizer before restoring its state, keeping frozen
        # parameters excluded.
        params_to_optimize = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optimizer = utils.import_attr(self.optim_name)(params_to_optimize, **self.opt_params)

        if self.scheduler is not None:
            self.scheduler = self.init_scheduler(self.scheduler_name, self.scheduler_params)

        self.optimizer.load_state_dict(state["optimizer"])

        if self.scheduler is not None:
            self.scheduler.load_state_dict(state["scheduler"])

        self.epoch = state["current_epoch"]
        print("Load model from epoch", self.epoch)
        self.metric_values = state["metric_values"]

        if "statistics" in state:
            self.statistics = state["statistics"]

    def dump_state(self, path):
        _model = self.model.module if self.use_dp else self.model

        state = dict(
            model=_model.state_dict(),
            optimizer=self.optimizer.state_dict(),
            current_epoch=self.epoch,
            metric_values=self.metric_values,
            statistics=self.statistics,
        )

        if self.scheduler is not None:
            state["scheduler"] = self.scheduler.state_dict()
        print("save to " + path)
        torch.save(state, path)
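
    # The saved checkpoint is a plain dict with the keys written above:
    # "model" (unwrapped module state_dict), "optimizer", "current_epoch",
    # "metric_values", "statistics", and "scheduler" when one is configured.
    # load_state() expects exactly this layout.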

    def get_current_lr(self):
        # All parameter groups follow the same schedule here, so report the
        # first group's learning rate.
        for param_group in self.optimizer.param_groups:
            return param_group["lr"]

    def on_epoch_start(self):
        print()
        print("=" * 25, "STARTING EPOCH", self.epoch, "=" * 25)
        print()

    def get_avg_metric_at_epoch(self, metric, epoch=None):
        if epoch is None:
            epoch = self.epoch

        # "epoch" holds a batch-size-weighted sum; dividing by the element
        # count recovers the mean over the epoch.
        return self.metric_values[epoch][metric]["epoch"] / self.metric_values[epoch][metric]["num_elements"]
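
    # E.g. two batches of sizes 4 and 2 with mean losses 1.0 and 4.0
    # accumulate epoch = 1.0 * 4 + 4.0 * 2 = 12.0 and num_elements = 6,
    # so the weighted mean is 2.0 rather than the naive (1.0 + 4.0) / 2.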

    def on_epoch_end(self, best_path, wandb_run):
        assert self.epoch + 1 == len(
            self.metric_values
        ), "Current epoch must be equal to length of metrics (0-indexed)"

        monitor_metric_last = self.get_avg_metric_at_epoch(self.monitor)

        # Save only if the current epoch beats every previous epoch on the
        # monitored metric.
        save = True
        for epoch in range(len(self.metric_values) - 1):
            monitor_metric_at_epoch = self.get_avg_metric_at_epoch(self.monitor, epoch)

            if self.monitor_mode == "max":
                if monitor_metric_last < monitor_metric_at_epoch:
                    save = False
                    break

            if self.monitor_mode == "min":
                if monitor_metric_last > monitor_metric_at_epoch:
                    save = False
                    break

        if save:
            print("Current checkpoint is the best! Saving it...")
            self.dump_state(best_path)

        val_loss = self.get_avg_metric_at_epoch("val/loss")
        val_snr_i = self.get_avg_metric_at_epoch("val/snr_i")
        val_si_snr_i = self.get_avg_metric_at_epoch("val/si_snr_i")

        print(f"Val loss: {val_loss:.02f}")
        print(f"Val SNRi: {val_snr_i:.02f}dB")
        print(f"Val SI-SDRi: {val_si_snr_i:.02f}dB")

        wandb_run.log({"lr-Adam": self.get_current_lr()}, commit=False, step=self.epoch + 1)

        for metric in self.metric_values[self.epoch]:
            wandb_run.log({metric: self.get_avg_metric_at_epoch(metric)}, commit=False, step=self.epoch + 1)

        for statistic in self.statistics:
            if not self.statistics[statistic]["logged"]:
                data = self.statistics[statistic]["data"]
                reduction = self.statistics[statistic]["reduction"]
                if reduction == "mean":
                    val = np.mean(data)
                elif reduction == "sum":
                    val = sum(data)
                elif reduction == "histogram":
                    data = [[d] for d in data]
                    table = wandb.Table(data=data, columns=[statistic])
                    val = wandb.plot.histogram(table, statistic, title=statistic)
                else:
                    raise ValueError(f"Unknown reduction {reduction}.")
                wandb_run.log({statistic: val}, commit=False)
                self.statistics[statistic]["logged"] = True

        wandb_run.log({"epoch": self.epoch}, commit=True, step=self.epoch + 1)

        if self.scheduler is not None:
            if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # ReduceLROnPlateau steps on the monitored metric.
                self.scheduler.step(monitor_metric_last)
            else:
                self.scheduler.step()

        self.epoch += 1

    def log_statistic(self, name, value, reduction="mean"):
        if name not in self.statistics:
            self.statistics[name] = dict(logged=False, data=[], reduction=reduction)

        self.statistics[name]["data"].append(value)

    def log_metric(self, name, value, batch_size=1, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True):
        """
        Logs a metric.

        `value` must be the AVERAGE value across the batch, and `batch_size`
        must be provided for an accurate weighted average over the epoch.
        `prog_bar` and `sync_dist` are accepted for Lightning-style call
        compatibility but are unused here.
        """
        epoch_str = self.epoch
        if epoch_str not in self.metric_values:
            self.metric_values[epoch_str] = {}

        if name not in self.metric_values[epoch_str]:
            self.metric_values[epoch_str][name] = dict(step=None, epoch=None)

        if isinstance(value, torch.Tensor):
            value = value.item()

        if on_step:
            if self.metric_values[epoch_str][name]["step"] is None:
                self.metric_values[epoch_str][name]["step"] = []

            self.metric_values[epoch_str][name]["step"].append(value)

        if on_epoch:
            if self.metric_values[epoch_str][name]["epoch"] is None:
                self.metric_values[epoch_str][name]["epoch"] = 0
                self.metric_values[epoch_str][name]["num_elements"] = 0

            self.metric_values[epoch_str][name]["epoch"] += value * batch_size
            self.metric_values[epoch_str][name]["num_elements"] += batch_size

    def val_naive(self, batch, batch_idx):
        inputs, targets = batch
        a = torch.cuda.memory_allocated(inputs["mixture"].device)
        outputs = self.model(inputs)
        b = torch.cuda.memory_allocated(inputs["mixture"].device)
        # Memory allocated by the forward pass, in MB.
        print("Infer consume M", (b - a) / 1e6)

        return outputs

    def train_naive(self, batch, batch_idx):
        self.reset_grad()
        inputs, targets = batch
        a = torch.cuda.memory_allocated(inputs["mixture"].device)

        outputs = self.model(inputs)

        est = outputs["output"]
        gt = targets["target"]

        loss = self.loss_fn(est=est, gt=gt).mean()
        b = torch.cuda.memory_allocated(inputs["mixture"].device)

        loss.backward(retain_graph=True)
        c = torch.cuda.memory_allocated(inputs["mixture"].device)

        self.backprop()
        d = torch.cuda.memory_allocated(inputs["mixture"].device)

        # Memory growth in GB after forward, backward, and the optimizer step.
        print("Training consume G", (b - a) / 1e9, (c - a) / 1e9, (d - c) / 1e9, a / 1e9)
        return outputs

    def silence_audio(self, audio, timestamps):
        # Zero out each [start, end) sample range.
        output_audio = audio.clone()
        for start, end in timestamps:
            output_audio[start:end] = 0.0

        return output_audio
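
    # E.g. silence_audio(audio, [(0, 16000)]) zeroes the first second of a
    # 16 kHz signal.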

    def _step(self, batch, batch_idx, step="train"):
        inputs, targets = batch
        batch_size = inputs["mixture"].shape[0]

        start_idx = inputs["start_idx_list"][0].item()
        end_idx = inputs["end_idx_list"][0].item()
        inputs["start_idx"] = start_idx
        inputs["end_idx"] = end_idx

        outputs = self.model(inputs)
        est = outputs["output"].clone()

        if "audio_range" in outputs:
            # The model reports per-sample [start, end) ranges; slice the
            # target, mixture, and self-speech to the same regions before
            # computing the loss and metrics.
            audio_range = outputs["audio_range"]
            start_indices = audio_range[:, 0]
            end_indices = audio_range[:, 1]
            sliced_gt = []
            sliced_mix = []
            sliced_self = []

            gt_clone = targets["target"].clone()
            mix_clone = inputs["mixture"][:, 0:1].clone()
            full_self_speech_clone = inputs["self_speech"].clone()

            for index in range(est.size(0)):
                start = start_indices[index].item()
                end = end_indices[index].item()

                sliced_gt.append(gt_clone[index, :, start:end])
                sliced_mix.append(mix_clone[index, :, start:end])
                sliced_self.append(full_self_speech_clone[index, :, start:end])

            gt = torch.stack(sliced_gt, dim=0)
            mix = torch.stack(sliced_mix, dim=0)
            self_speech_final = torch.stack(sliced_self, dim=0)
        else:
            mix = inputs["mixture"][:, 0:1].clone()
            gt = targets["target"].clone()
            self_speech_final = targets["self_speech"].clone()

        loss = self.loss_fn(est=est, gt=gt).mean()

        est_detached = est.detach().clone()

        with torch.no_grad():
            self.log_metric(
                f"{step}/loss",
                loss.item(),
                batch_size=batch_size,
                on_step=(step == "train"),
                on_epoch=True,
                prog_bar=True,
                sync_dist=True,
            )

            for metric in self.metrics:
                # PESQ and STOI are expensive, so they are skipped in training.
                if step == "train" and (metric.name == "PESQ" or metric.name == "STOI"):
                    continue
                metric_val = metric(est=est_detached, gt=gt, mix=mix, self_speech=self_speech_final)
                for i in range(batch_size):
                    # Metrics are undefined for all-silent targets.
                    if torch.all(gt[i] == 0):
                        continue
                    val = metric_val[i].item()
                    self.log_metric(
                        f"{step}/{metric.name}",
                        val,
                        batch_size=1,
                        on_step=False,
                        on_epoch=True,
                        prog_bar=True,
                        sync_dist=True,
                    )

        sample = {
            "mixture": mix,
            "output": est_detached,
            "target": gt,
        }

        return loss, sample

    def train(self):
        self.model.train()
        self.mode = "train"

    def eval(self):
        self.model.eval()
        self.mode = "val"

    def training_step(self, batch, batch_idx):
        loss, sample = self._step(batch, batch_idx, step="train")

        target = sample["target"]

        return loss, target.shape[0]

    def validation_step(self, batch, batch_idx):
        loss, sample = self._step(batch, batch_idx, step="val")

        target = sample["target"]

        return loss, target.shape[0]

    def reset_grad(self):
        self.optimizer.zero_grad()

    def backprop(self):
        # Clip gradients (when configured) just before the optimizer step.
        if self.grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)

        self.optimizer.step()
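
    # A minimal sketch of the outer loop this module expects (hypothetical
    # driver code; `module` is a PLModule instance):
    #
    #   module.on_epoch_start()
    #   module.train()
    #   for batch_idx, batch in enumerate(train_loader):
    #       module.reset_grad()
    #       loss, _ = module.training_step(batch, batch_idx)
    #       loss.backward()
    #       module.backprop()
    #   module.eval()
    #   with torch.no_grad():
    #       for batch_idx, batch in enumerate(val_loader):
    #           module.validation_step(batch, batch_idx)
    #   module.on_epoch_end(best_path, wandb_run)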

    def configure_optimizers(self):
        if self.scheduler is not None:
            if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler_cfg = {
                    "scheduler": self.scheduler,
                    "interval": "epoch",
                    "frequency": 1,
                    "monitor": self.monitor,
                    "strict": False,
                }
            else:
                scheduler_cfg = self.scheduler
            return [self.optimizer], [scheduler_cfg]
        else:
            return self.optimizer

    def init_scheduler(self, scheduler, scheduler_params):
        if scheduler is not None:
            if scheduler == "sequential":
                schedulers = []
                milestones = []
                for scheduler_param in scheduler_params:
                    sched = utils.import_attr(scheduler_param["name"])(self.optimizer, **scheduler_param["params"])
                    schedulers.append(sched)
                    milestones.append(scheduler_param["epochs"])

                # Convert per-phase durations into cumulative switch points.
                for i in range(1, len(milestones)):
                    milestones[i] = milestones[i - 1] + milestones[i]

                # SequentialLR expects len(schedulers) - 1 milestones.
                milestones.pop()

                scheduler = torch.optim.lr_scheduler.SequentialLR(self.optimizer, schedulers, milestones)
            else:
                scheduler = utils.import_attr(scheduler)(self.optimizer, **scheduler_params)

        return scheduler
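
    # Example "sequential" scheduler_params (hypothetical values): phases of
    # 10 and 20 epochs followed by a final phase; the cumulative milestones
    # passed to SequentialLR become [10, 30].
    #
    #   scheduler_params = [
    #       {"name": "torch.optim.lr_scheduler.LinearLR", "params": {...}, "epochs": 10},
    #       {"name": "torch.optim.lr_scheduler.ConstantLR", "params": {...}, "epochs": 20},
    #       {"name": "torch.optim.lr_scheduler.ExponentialLR", "params": {...}, "epochs": 70},
    #   ]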