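"""Training script: the model takes a noised spectrum and predicts a denoised
spectrum together with an estimate of the response function.

Usage: pass the run directory (containing config.yaml) as the only
command-line argument. Checkpoints, diagnostic plots and the train.txt loss
log are written back into that run directory.
"""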
import sys
from pathlib import Path

import numpy as np
import torch
import torchaudio
from omegaconf import OmegaConf
from hydra.utils import instantiate

import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are only written to disk
import matplotlib.pyplot as plt

# silence deprecation warnings raised by the torchdata datapipes
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')

from src import models
from src.generators import get_datapipe

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if len(sys.argv) < 2:
    print("Please provide the run directory as an argument.")
    sys.exit(1)
run_dir = Path(sys.argv[1])

config = OmegaConf.load(run_dir / "config.yaml")
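# config.yaml is the run configuration; the keys used below are model.{name, kwargs},
# data, training (a list of stages), losses_weights.{response, clean, noised}
# and logging.{num_plots, step}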

# resume the "best checkpoint" criterion: column 2 of train.txt holds the
# running-average clean loss logged during training
if (run_dir / "train.txt").is_file():
    minimum = np.min(np.atleast_2d(np.loadtxt(run_dir / "train.txt"))[:, 2])
else:
    minimum = float("inf")

# model / optimizer initialization; weights are (re)loaded at the start of each training stage
model = instantiate({"_target_": f"__main__.models.{config.model.name}", **config.model.kwargs}).to(device)
model_weights_file = run_dir / 'model.pt'
optimizer = torch.optim.Adam(model.parameters())
optimizer_weights_file = run_dir / 'optimizer.pt'

def evaluate_model(stage=0, epoch=0):
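    """Checkpoint the current weights and save diagnostic plots for one fresh batch."""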
    plot_dir = run_dir / "plots" / f"{stage}_{epoch}"
    plot_dir.mkdir(exist_ok=True, parents=True)

    torch.save(model.state_dict(), plot_dir / "model.pt")
    torch.save(optimizer.state_dict(), plot_dir / "optimizer.pt")
    
    num_plots = config.logging.num_plots
    pipe = get_datapipe(
            **config.data,
            include_response_function=True,
            batch_size=num_plots
        )
    batch = next(iter(pipe))

    with torch.no_grad():
        out = model(batch['noised_spectrum'].to(device))
        # re-convolve the denoised estimate with the (flipped) estimated response
        # to reconstruct the noised spectrum for visual comparison
        noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same").cpu()

        for i in range(num_plots):
            # theoretical (clean) spectrum vs. the model's denoised estimate
            plt.figure(figsize=(30, 6))
            plt.plot(batch['theoretical_spectrum'].cpu().numpy()[i, 0])
            plt.plot(out['denoised'].cpu().numpy()[i, 0])
            plt.savefig(plot_dir / f"{i:03d}_spectrum_clean.png")

            # noised input vs. the re-convolved estimate
            plt.figure(figsize=(30, 6))
            plt.plot(batch['noised_spectrum'].cpu().numpy()[i, 0])
            plt.plot(noised_est.cpu().numpy()[i, 0])
            plt.savefig(plot_dir / f"{i:03d}_spectrum_noise.png")

            # true response function vs. the estimated one
            plt.figure(figsize=(10, 6))
            plt.plot(batch['response_function'].cpu().numpy()[i, 0, 0])
            plt.plot(out['response'].cpu().numpy()[i])
            plt.savefig(plot_dir / f"{i:03d}_response.png")

            # attention weights, if the model exposes them
            if "attention" in out:
                plt.figure(figsize=(10, 6))
                plt.plot(out['attention'].cpu().numpy()[i])
                plt.savefig(plot_dir / f"{i:03d}_attention.png")

            plt.close("all")
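
# training loop: config.training defines a sequence of stages; each stage
# resumes from the best checkpoint saved so far and uses its own learning
# rate, batch size and max_iters budget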

for i_stage, training_stage in enumerate(config.training):
    if model_weights_file.is_file():
        model.load_state_dict(torch.load(model_weights_file, weights_only=True))

    if optimizer_weights_file.is_file():
        optimizer.load_state_dict(torch.load(optimizer_weights_file, weights_only=True))
    # apply the stage-specific learning rate to every parameter group
    for group in optimizer.param_groups:
        group['lr'] = training_stage.learning_rate
    
    pipe = get_datapipe(
        **config.data,
        include_response_function=True,
        batch_size=training_stage.batch_size
    )

    losses_history = []
    # running-average window covers roughly the last 64*100 = 6400 samples, independent of batch size
    losses_history_limit = 64 * 100 // training_stage.batch_size
    
    last_evaluation = 0
    for epoch, batch in pipe.enumerate():
        
        # periodic evaluation: iters_done counts samples seen so far in this stage
        iters_done = epoch * training_stage.batch_size
        if (iters_done - last_evaluation) > config.logging.step:
            evaluate_model(i_stage, epoch)
            last_evaluation = iters_done

        # stop the stage once its iteration budget is exhausted
        if iters_done > training_stage.max_iters:
            evaluate_model(i_stage, epoch)
            break
        
        # run model
        out = model(batch['noised_spectrum'].to(device))

        # calculate losses: supervised response and clean-spectrum terms, plus a
        # self-consistency term (the denoised estimate re-convolved with the
        # estimated response should reproduce the noised input)
        loss_response = torch.nn.functional.mse_loss(out['response'], batch['response_function'].squeeze(dim=(1, 2)).to(device))
        loss_clean = torch.nn.functional.mse_loss(out['denoised'], batch['theoretical_spectrum'].to(device))
        noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same")
        loss_noised = torch.nn.functional.mse_loss(noised_est, batch['noised_spectrum'].to(device))
        loss = (config.losses_weights.response * loss_response
                + config.losses_weights.clean * loss_clean
                + config.losses_weights.noised * loss_noised)
        
        # logging: columns are epoch, total loss, running-average clean loss,
        # clean loss, response loss, noised loss
        losses_history.append(loss_clean.item())
        losses_history = losses_history[-losses_history_limit:]
        loss_avg = sum(losses_history) / len(losses_history)
        message = f"{epoch:7d} {loss:0.3e} {loss_avg:0.3e} {loss_clean:0.3e} {loss_response:0.3e} {loss_noised:0.3e}"
        with open(run_dir / 'train.txt', 'a') as f:
            f.write(message + '\n')
        print(message, flush=True)
        
        # save best checkpoint (by running-average clean loss)
        if loss_avg < minimum:
            minimum = loss_avg
            torch.save(model.state_dict(), model_weights_file)
            torch.save(optimizer.state_dict(), optimizer_weights_file)
        
        # update weights
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        optimizer.step()
        optimizer.zero_grad()