"""Train a spectrum-denoising model from a run directory.

Usage: python <script> <run_dir>

``run_dir`` must contain ``config.yaml``.  The script writes back into the
same directory: ``model.pt`` / ``optimizer.pt`` (best checkpoint so far),
``train.txt`` (loss log, one row per step) and ``plots/<stage>_<epoch>/``
(periodic snapshots + diagnostic figures), so training resumes by simply
re-running the script on the same directory.
"""
import datetime
import sys
import warnings
from pathlib import Path

import matplotlib
matplotlib.use('Agg')  # headless backend: figures are only saved, never shown
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
from hydra.utils import instantiate
from omegaconf import OmegaConf

# silence deprecation warnings raised by the torchdata datapipes
warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')

from src import models
from src.generators import get_datapipe

device = 'cuda' if torch.cuda.is_available() else 'cpu'

if len(sys.argv) < 2:
    print("Please provide the run directory as an argument.")
    sys.exit(1)

run_dir = Path(sys.argv[1])
config = OmegaConf.load(run_dir / "config.yaml")

# Resume the best running-average loss from a previous run's log.
# Column 2 of train.txt is the running average written below (``loss_avg``).
if (run_dir / "train.txt").is_file():
    minimum = np.min(np.loadtxt(run_dir / "train.txt")[:, 2])
else:
    minimum = float("inf")

# initialization: the model class is looked up by name inside ``src.models``
model = instantiate({"_target_": f"__main__.models.{config.model.name}",
                     **config.model.kwargs}).to(device)
model_weights_file = run_dir / 'model.pt'
optimizer = torch.optim.Adam(model.parameters())
optimizer_weights_file = run_dir / 'optimizer.pt'


def evaluate_model(stage=0, epoch=0):
    """Snapshot the current state and save diagnostic plots.

    Writes model/optimizer state dicts plus example clean-spectrum,
    re-noised-spectrum, response-function and (optional) attention plots
    into ``run_dir/plots/<stage>_<epoch>/``.
    """
    plot_dir = run_dir / "plots" / f"{stage}_{epoch}"
    plot_dir.mkdir(exist_ok=True, parents=True)
    torch.save(model.state_dict(), plot_dir / "model.pt")
    torch.save(optimizer.state_dict(), plot_dir / "optimizer.pt")

    num_plots = config.logging.num_plots
    pipe = get_datapipe(**config.data, include_response_function=True,
                        batch_size=num_plots)
    batch = next(iter(pipe))

    # NOTE(review): the model is never switched to eval() here, so
    # dropout/batch-norm layers behave as in training — confirm intended.
    with torch.no_grad():
        out = model(batch['noised_spectrum'].to(device))
        # Re-noise the denoised estimate by convolving it with the (flipped)
        # estimated response function, for comparison with the noisy input.
        noised_est = torchaudio.functional.convolve(
            out['denoised'],
            out['response'].flip(dims=(-1,)).unsqueeze(1),
            mode="same").cpu()

    for i in range(num_plots):
        plt.figure(figsize=(30, 6))
        plt.plot(batch['theoretical_spectrum'].cpu().numpy()[i, 0])
        plt.plot(out['denoised'].cpu().numpy()[i, 0])
        plt.savefig(plot_dir / f"{i:03d}_spectrum_clean.png")

        plt.figure(figsize=(30, 6))
        plt.plot(batch['noised_spectrum'].cpu().numpy()[i, 0])
        plt.plot(noised_est.cpu().numpy()[i, 0])
        plt.savefig(plot_dir / f"{i:03d}_spectrum_noise.png")

        plt.figure(figsize=(10, 6))
        plt.plot(batch['response_function'].cpu().numpy()[i, 0, 0])
        plt.plot(out['response'].cpu().numpy()[i])
        plt.savefig(plot_dir / f"{i:03d}_response.png")

        if "attention" in out:
            plt.figure(figsize=(10, 6))
            plt.plot(out['attention'].cpu().numpy()[i])
            plt.savefig(plot_dir / f"{i:03d}_attention.png")

        plt.close("all")  # release figure memory before the next sample


for i_stage, training_stage in enumerate(config.training):
    # Resume from the best checkpoint so far (if any).  map_location makes
    # CUDA-saved checkpoints loadable on a CPU-only host (device selection
    # above explicitly supports CPU).
    if model_weights_file.is_file():
        model.load_state_dict(torch.load(
            model_weights_file, map_location=device, weights_only=True))
    if optimizer_weights_file.is_file():
        optimizer.load_state_dict(torch.load(
            optimizer_weights_file, map_location=device, weights_only=True))
    optimizer.param_groups[0]['lr'] = training_stage.learning_rate

    pipe = get_datapipe(**config.data, include_response_function=True,
                        batch_size=training_stage.batch_size)

    losses_history = []
    # Average the clean loss over roughly the last 6400 samples.  max(1, ...)
    # guards against a zero-length window when batch_size > 6400, which would
    # otherwise make the [-0:] slice keep the *entire* history.
    losses_history_limit = max(1, 64 * 100 // training_stage.batch_size)
    last_evaluation = 0

    for epoch, batch in pipe.enumerate():
        # periodic evaluation, plus one final evaluation before leaving stage
        iters_done = epoch * training_stage.batch_size
        if (iters_done - last_evaluation) > config.logging.step:
            evaluate_model(i_stage, epoch)
            last_evaluation = iters_done
        if iters_done > training_stage.max_iters:
            evaluate_model(i_stage, epoch)
            break

        # forward pass
        out = model(batch['noised_spectrum'].to(device))

        # Losses: match of the estimated response function, match of the
        # denoised spectrum to theory, and self-consistency of
        # denoised (*) response against the measured noisy spectrum.
        loss_response = torch.nn.functional.mse_loss(
            out['response'],
            batch['response_function'].squeeze(dim=(1, 2)).to(device))
        loss_clean = torch.nn.functional.mse_loss(
            out['denoised'], batch['theoretical_spectrum'].to(device))
        noised_est = torchaudio.functional.convolve(
            out['denoised'],
            out['response'].flip(dims=(-1,)).unsqueeze(1),
            mode="same")
        loss_noised = torch.nn.functional.mse_loss(
            noised_est, batch['noised_spectrum'].to(device))
        loss = (config.losses_weights.response * loss_response
                + config.losses_weights.clean * loss_clean
                + config.losses_weights.noised * loss_noised)

        # bounded running average of the clean loss
        losses_history.append(loss_clean.item())
        losses_history = losses_history[-losses_history_limit:]
        loss_avg = sum(losses_history) / len(losses_history)

        message = (f"{epoch:7d} {loss:0.3e} {loss_avg:0.3e} "
                   f"{loss_clean:0.3e} {loss_response:0.3e} {loss_noised:0.3e}")
        with open(run_dir / 'train.txt', 'a') as f:
            f.write(message + '\n')
        print(message, flush=True)

        # keep the checkpoint with the best running-average clean loss
        if loss_avg < minimum:
            minimum = loss_avg
            torch.save(model.state_dict(), model_weights_file)
            torch.save(optimizer.state_dict(), optimizer_weights_file)

        # backward pass and parameter update
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        optimizer.step()
        optimizer.zero_grad()