# NOTE(review): the lines that were here ("Spaces:", "Sleeping", "File size: ...",
# a rendered line-number gutter, and a trailing "|") were web-page scraping
# residue from a hosted file viewer, not part of the program; commented out so
# the module is valid Python.
import torch, torchaudio
import numpy as np
from pathlib import Path
from omegaconf import OmegaConf
from hydra.utils import instantiate
import datetime
import sys
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
# silent deprecation_warning() from datapipes
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
from src import models
from src.generators import get_datapipe
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if len(sys.argv) < 2:
print("Please provide the run directory as an argument.")
sys.exit(1)
run_dir = Path(sys.argv[1])
config = OmegaConf.load(run_dir / "config.yaml")
if (run_dir / "train.txt").is_file():
minimum = np.min(np.loadtxt(run_dir / "train.txt")[:,2])
else:
minimum = float("inf")
# initialization
model = instantiate({"_target_": f"__main__.models.{config.model.name}", **config.model.kwargs}).to(device)
model_weights_file = run_dir / f'model.pt'
optimizer = torch.optim.Adam(model.parameters())
optimizer_weights_file = run_dir / f'optimizer.pt'
def evaluate_model(stage=0, epoch=0):
plot_dir = run_dir / "plots" / f"{stage}_{epoch}"
plot_dir.mkdir(exist_ok=True, parents=True)
torch.save(model.state_dict(), plot_dir / "model.pt")
torch.save(optimizer.state_dict(), plot_dir / "optimizer.pt")
num_plots = config.logging.num_plots
pipe = get_datapipe(
**config.data,
include_response_function=True,
batch_size=num_plots
)
batch = next(iter(pipe))
with torch.no_grad():
out = model(batch['noised_spectrum'].to(device))
noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same").cpu()
for i in range(num_plots):
plt.figure(figsize=(30,6))
plt.plot(batch['theoretical_spectrum'].cpu().numpy()[i,0])
plt.plot(out['denoised'].cpu().numpy()[i,0])
plt.savefig(plot_dir / f"{i:03d}_spectrum_clean.png")
plt.figure(figsize=(30,6))
plt.plot(batch['noised_spectrum'].cpu().numpy()[i,0])
plt.plot(noised_est.cpu().numpy()[i,0])
plt.savefig(plot_dir / f"{i:03d}_spectrum_noise.png")
plt.figure(figsize=(10,6))
plt.plot(batch['response_function'].cpu().numpy()[i,0,0])
plt.plot(out['response'].cpu().numpy()[i])
plt.savefig(plot_dir / f"{i:03d}_response.png")
if "attention" in out:
plt.figure(figsize=(10, 6))
plt.plot(out['attention'].cpu().numpy()[i])
plt.savefig(plot_dir / f"{i:03d}_attention.png")
plt.close("all")
for i_stage, training_stage in enumerate(config.training):
if model_weights_file.is_file():
model.load_state_dict(torch.load(model_weights_file, weights_only=True))
if optimizer_weights_file.is_file():
optimizer.load_state_dict(torch.load(optimizer_weights_file, weights_only=True))
optimizer.param_groups[0]['lr'] = training_stage.learning_rate
pipe = get_datapipe(
**config.data,
include_response_function=True,
batch_size=training_stage.batch_size
)
losses_history = []
losses_history_limit = 64*100 // training_stage.batch_size
last_evaluation = 0
for epoch, batch in pipe.enumerate():
# logging
iters_done = epoch*training_stage.batch_size
if (iters_done - last_evaluation) > config.logging.step:
evaluate_model(i_stage, epoch)
last_evaluation = iters_done
if iters_done > training_stage.max_iters:
evaluate_model(i_stage, epoch)
break
# run model
out = model(batch['noised_spectrum'].to(device))
# calculate losses
loss_response = torch.nn.functional.mse_loss(out['response'], batch['response_function'].squeeze(dim=(1,2)).to(device))
loss_clean = torch.nn.functional.mse_loss(out['denoised'], batch['theoretical_spectrum'].to(device))
noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same")
loss_noised = torch.nn.functional.mse_loss(noised_est, batch['noised_spectrum'].to(device))
loss = config.losses_weights.response*loss_response + config.losses_weights.clean*loss_clean + config.losses_weights.noised*loss_noised
# logging
losses_history.append(loss_clean.item())
losses_history = losses_history[-losses_history_limit:]
loss_avg = sum(losses_history)/len(losses_history)
message = f"{epoch:7d} {loss:0.3e} {loss_avg:0.3e} {loss_clean:0.3e} {loss_response:0.3e} {loss_noised:0.3e}"
# message = '%7i %.3e %.3e %.3e' % (epoch, loss, regress, classify)
with open(run_dir / f'train.txt', 'a') as f:
f.write(message + '\n')
print(message, flush = True)
# save best
if loss_avg < minimum:
minimum = loss_avg
torch.save(model.state_dict(), model_weights_file)
torch.save(optimizer.state_dict(),optimizer_weights_file)
# update weights
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
optimizer.step()
optimizer.zero_grad()
# NOTE(review): trailing "|" was file-viewer scraping residue; removed.