File size: 6,748 Bytes
dbbd709 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
"""A collection of useful helper functions"""
import os
import logging
import json
import torch
from torch.profiler import profile, record_function, ProfilerActivity
import pandas as pd
from torchmetrics.functional import(
scale_invariant_signal_noise_ratio as si_snr,
signal_noise_ratio as snr,
signal_distortion_ratio as sdr,
scale_invariant_signal_distortion_ratio as si_sdr)
import matplotlib.pyplot as plt
class Params():
    """Hyperparameter container populated from a json file.

    Every top-level key of the json object becomes an instance attribute.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        # Initial construction is simply the first `update` from disk.
        self.update(json_path)

    def save(self, json_path):
        """Writes all current attributes to `json_path` as indented json"""
        with open(json_path, 'w') as out_file:
            json.dump(vars(self), out_file, indent=4)

    def update(self, json_path):
        """Loads parameters from json file, overwriting attributes of the same name"""
        with open(json_path) as in_file:
            loaded = json.load(in_file)
        vars(self).update(loaded)

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
        return vars(self)
def save_graph(train_metrics, test_metrics, save_dir):
    """Dump train/test curves to `results.csv` and `results.png` in `save_dir`.

    Args:
        train_metrics: (dict) must contain 'loss' plus one entry per plotted
            metric, keyed by the metric function's `__name__`
            (presumably one value per epoch -- the x axis is labelled 'Epoch')
        test_metrics: (dict) same keys as `train_metrics`
        save_dir: (string) directory receiving both output files
    """
    metric_names = [fn.__name__ for fn in (snr, si_snr)]

    # Tabular copy of everything that gets plotted.
    table = {'train_loss': train_metrics['loss'],
             'test_loss': test_metrics['loss']}
    for name in metric_names:
        table["train_" + name] = train_metrics[name]
        table["test_" + name] = test_metrics[name]
    pd.DataFrame(table).to_csv(os.path.join(save_dir, 'results.csv'))

    # 2x3 grid flattened row by row; only the first 1 + len(metric_names)
    # panels are drawn into.
    fig, grid = plt.subplots(2, 3, figsize=(15, 10))
    panels = [panel for row in grid for panel in row]
    epochs = range(len(train_metrics['loss']))

    loss_panel = panels[0]
    loss_panel.plot(epochs, train_metrics['loss'], label='train')
    loss_panel.plot(epochs, test_metrics['loss'], label='test')
    loss_panel.set(ylabel='Loss')
    loss_panel.set(xlabel='Epoch')
    loss_panel.set_title('loss', fontweight='bold')
    loss_panel.legend()

    for panel, name in zip(panels[1:], metric_names):
        panel.plot(epochs, train_metrics[name], label='train')
        panel.plot(epochs, test_metrics[name], label='test')
        panel.set(xlabel='Epoch')
        panel.set_title(name, fontweight='bold')
        panel.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'results.png'))
    plt.close(fig)
def set_logger(log_path):
    """Set the root logger to log info in terminal and file `log_path`.

    In general, it is useful to have a logger so that every output to the terminal is saved
    in a permanent file.

    Any handler already attached to the root logger is removed and closed
    first, so calling this function repeatedly (e.g. once per experiment)
    neither duplicates output nor leaves earlier log files open.

    Example:
    ```
    set_logger(log_path)
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) where to log
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Detach AND close old handlers: merely clearing the list would keep the
    # file descriptor of a previous FileHandler open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Logging to a file
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
    logger.addHandler(file_handler)

    # Logging to console
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(stream_handler)
def load_checkpoint(checkpoint, model, optim=None, lr_sched=None, data_parallel=False):
    """Loads model parameters (state_dict) and training state from a checkpoint file.

    The file is expected to hold a dict with the keys 'model_state_dict',
    'epoch', 'train_metrics' and 'val_metrics', plus 'optim_state_dict' /
    'lr_sched_state_dict' when `optim` / `lr_sched` are given (the layout
    written by `save_checkpoint`).

    Args:
        checkpoint: (string) filename which needs to be loaded
        model: (torch.nn.Module) model for which the parameters are loaded
        optim: optimizer to restore from 'optim_state_dict'; skipped when None
        lr_sched: scheduler to restore from 'lr_sched_state_dict'; skipped when None
        data_parallel: (bool) if the model is a data parallel model; every
            stored parameter name is then prefixed with 'module.' before loading

    Returns:
        Tuple `(epoch, train_metrics, val_metrics)` read from the file.

    Raises:
        FileNotFoundError: if `checkpoint` does not exist.
    """
    if not os.path.exists(checkpoint):
        # Must be a real exception class: raising a plain string is itself a
        # TypeError and would hide the actual problem.
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))

    # NOTE(review): torch.load relies on pickle-based deserialization; only
    # load checkpoints that come from a trusted source.
    state_dict = torch.load(checkpoint)

    model_state_dict = state_dict['model_state_dict']
    if data_parallel:
        model_state_dict = {
            'module.' + name: tensor
            for name, tensor in model_state_dict.items()}
    model.load_state_dict(model_state_dict)

    if optim is not None:
        optim.load_state_dict(state_dict['optim_state_dict'])
    if lr_sched is not None:
        lr_sched.load_state_dict(state_dict['lr_sched_state_dict'])

    return state_dict['epoch'], state_dict['train_metrics'], \
        state_dict['val_metrics']
def save_checkpoint(checkpoint, epoch, model, optim=None, lr_sched=None,
                    train_metrics=None, val_metrics=None, data_parallel=False):
    """Saves model parameters (state_dict) and training state to a new file.

    The stored object is a dict with the keys 'epoch', 'model_state_dict',
    'optim_state_dict', 'lr_sched_state_dict', 'train_metrics' and
    'val_metrics'.

    Args:
        checkpoint: (string) filename to write; must not exist yet
        epoch: value stored under 'epoch'
        model: (torch.nn.Module) model whose parameters are saved
        optim: optimizer whose state_dict is saved; None is stored when omitted
        lr_sched: scheduler whose state_dict is saved; None is stored when omitted
        train_metrics: value stored under 'train_metrics'
        val_metrics: value stored under 'val_metrics'
        data_parallel: (bool) if the model is a data parallel model; a leading
            'module.' is then stripped from every parameter name

    Raises:
        FileExistsError: if `checkpoint` already exists (it is never overwritten).
    """
    if os.path.exists(checkpoint):
        # Must be a real exception class: raising a plain string is itself a
        # TypeError and would hide the actual problem.
        raise FileExistsError("File already exists {}".format(checkpoint))

    model_state_dict = model.state_dict()
    if data_parallel:
        # Strip only a *leading* 'module.'. Names without the prefix are kept
        # as they are (instead of all collapsing onto the empty key), and a
        # 'module.' occurring in the middle of a name is left alone.
        prefix = 'module.'
        model_state_dict = {
            (name[len(prefix):] if name.startswith(prefix) else name): tensor
            for name, tensor in model_state_dict.items()}

    optim_state_dict = optim.state_dict() if optim else None
    lr_sched_state_dict = lr_sched.state_dict() if lr_sched else None

    state_dict = {
        'epoch': epoch,
        'model_state_dict': model_state_dict,
        'optim_state_dict': optim_state_dict,
        'lr_sched_state_dict': lr_sched_state_dict,
        'train_metrics': train_metrics,
        'val_metrics': val_metrics
    }
    torch.save(state_dict, checkpoint)
def model_size(model):
    """
    Returns the number of trainable parameters of `model`, in millions.

    Only parameters with `requires_grad` set are counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable / 1e6
def run_time(model, inputs, profiling=False, warmup=100):
    """
    Returns runtime of a model in ms.

    The model is first called `warmup` times without measurement, then one
    further call is recorded with the torch CPU profiler.

    Args:
        model: callable invoked as `model(*inputs)`
        inputs: sequence of positional arguments unpacked into each call
        profiling: (bool) if True, print the profiler table (20 rows, sorted
            by self CPU time) for the measured call
        warmup: (int) number of unmeasured calls made before profiling;
            defaults to 100

    Returns:
        The profiler's total self CPU time for the measured call, divided by 1000.
    """
    # Warmup
    for _ in range(warmup):
        model(*inputs)

    with profile(activities=[ProfilerActivity.CPU],
                 record_shapes=True) as prof:
        with record_function("model_inference"):
            model(*inputs)

    # Print profiling results
    if profiling:
        print(prof.key_averages().table(sort_by="self_cpu_time_total",
                                        row_limit=20))

    # Return runtime in ms
    return prof.profiler.self_cpu_time_total / 1000
def format_lr_info(optimizer):
    """Describes every parameter group of `optimizer` in a single string.

    Each group contributes one ` {group i: params=...M lr=...}` fragment,
    where the parameter count is the summed `numel()` of the group's params
    divided by 1024 ** 2. An optimizer without groups yields "".

    Args:
        optimizer: object exposing `param_groups`, a sequence of dicts with
            the keys 'params' and 'lr'
    """
    fragments = []
    for idx, group in enumerate(optimizer.param_groups):
        n_params = sum(p.numel() for p in group['params'])
        fragments.append(" {group %d: params=%.5fM lr=%.1E}" % (
            idx, n_params / (1024 ** 2), group['lr']))
    return "".join(fragments)
|