# ClearSep/helpers/utils.py
"""A collection of useful helper functions"""
import os
import logging
import json
import torch
from torch.profiler import profile, record_function, ProfilerActivity
import pandas as pd
from torchmetrics.functional import (
    scale_invariant_signal_noise_ratio as si_snr,
    signal_noise_ratio as snr,
    signal_distortion_ratio as sdr,
    scale_invariant_signal_distortion_ratio as si_sdr)
import matplotlib.pyplot as plt
class Params:
"""Class that loads hyperparameters from a json file.
Example:
```
params = Params(json_path)
print(params.learning_rate)
params.learning_rate = 0.5 # change the value of learning_rate in params
```
"""
def __init__(self, json_path):
with open(json_path) as f:
params = json.load(f)
self.__dict__.update(params)
def save(self, json_path):
with open(json_path, 'w') as f:
json.dump(self.__dict__, f, indent=4)
def update(self, json_path):
"""Loads parameters from json file"""
with open(json_path) as f:
params = json.load(f)
self.__dict__.update(params)
@property
def dict(self):
"""Gives dict-like access to Params instance by `params.dict['learning_rate']"""
return self.__dict__
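# A quick round-trip sketch for Params.save/update (the json paths here are
# hypothetical stand-ins, not files this module ships with):
#
#     params = Params('params.json')
#     params.learning_rate = 0.5
#     params.save('params.json')        # persist the modified values
#     params.update('override.json')    # merge in keys from another json file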
def save_graph(train_metrics, test_metrics, save_dir):
    """Saves per-epoch loss and metric curves to `save_dir` as results.csv
    and results.png.

    Args:
        train_metrics: (dict) per-epoch histories keyed by 'loss' and metric names
        test_metrics: (dict) same keys as `train_metrics`
        save_dir: (string) directory to write the csv and png to
    """
metrics = [snr, si_snr]
results = {'train_loss': train_metrics['loss'],
'test_loss' : test_metrics['loss']}
for m_fn in metrics:
results["train_"+m_fn.__name__] = train_metrics[m_fn.__name__]
results["test_"+m_fn.__name__] = test_metrics[m_fn.__name__]
results_pd = pd.DataFrame(results)
results_pd.to_csv(os.path.join(save_dir, 'results.csv'))
    fig, temp_ax = plt.subplots(2, 3, figsize=(15, 10))
    axs = temp_ax.flatten()
x = range(len(train_metrics['loss']))
axs[0].plot(x, train_metrics['loss'], label='train')
axs[0].plot(x, test_metrics['loss'], label='test')
    axs[0].set(xlabel='Epoch', ylabel='Loss')
    axs[0].set_title('loss', fontweight='bold')
axs[0].legend()
    for i, m_fn in enumerate(metrics):
        axs[i + 1].plot(x, train_metrics[m_fn.__name__], label='train')
        axs[i + 1].plot(x, test_metrics[m_fn.__name__], label='test')
        axs[i + 1].set(xlabel='Epoch')
        axs[i + 1].set_title(m_fn.__name__, fontweight='bold')
        axs[i + 1].legend()
plt.tight_layout()
plt.savefig(os.path.join(save_dir, 'results.png'))
plt.close(fig)
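# Example call (a sketch; the histories are per-epoch lists keyed by 'loss'
# and by each metric function's __name__, which the import aliases above do
# not change):
#
#     train_metrics = {'loss': [...], 'signal_noise_ratio': [...],
#                      'scale_invariant_signal_noise_ratio': [...]}
#     test_metrics = {'loss': [...], 'signal_noise_ratio': [...],
#                     'scale_invariant_signal_noise_ratio': [...]}
#     save_graph(train_metrics, test_metrics, 'experiments/run1')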
def set_logger(log_path):
"""Set the logger to log info in terminal and file `log_path`.
In general, it is useful to have a logger so that every output to the terminal is saved
in a permanent file. Here we save it to `model_dir/train.log`.
Example:
```
logging.info("Starting training...")
```
Args:
log_path: (string) where to log
"""
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.handlers.clear()
# Logging to a file
file_handler = logging.FileHandler(log_path)
file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
logger.addHandler(file_handler)
# Logging to console
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter('%(message)s'))
logger.addHandler(stream_handler)
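# Typical setup at the start of a run (a sketch; `model_dir` is a hypothetical
# experiment directory):
#
#     set_logger(os.path.join(model_dir, 'train.log'))
#     logging.info("Starting training...")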
def load_checkpoint(checkpoint, model, optim=None, lr_sched=None, data_parallel=False):
"""Loads model parameters (state_dict) from file_path.
Args:
checkpoint: (string) filename which needs to be loaded
model: (torch.nn.Module) model for which the parameters are loaded
data_parallel: (bool) if the model is a data parallel model
"""
    if not os.path.exists(checkpoint):
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))
state_dict = torch.load(checkpoint)
    if data_parallel:
        # nn.DataParallel expects every key to carry the 'module.' prefix
        state_dict['model_state_dict'] = {
            'module.' + k: v
            for k, v in state_dict['model_state_dict'].items()}
model.load_state_dict(state_dict['model_state_dict'])
if optim is not None:
optim.load_state_dict(state_dict['optim_state_dict'])
if lr_sched is not None:
lr_sched.load_state_dict(state_dict['lr_sched_state_dict'])
return state_dict['epoch'], state_dict['train_metrics'], \
state_dict['val_metrics']
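# Resuming a run (a sketch; `model`, `opt`, `sched`, and the path are
# hypothetical stand-ins created elsewhere):
#
#     start_epoch, train_metrics, val_metrics = load_checkpoint(
#         'ckpt/epoch_10.pt', model, optim=opt, lr_sched=sched)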
def save_checkpoint(checkpoint, epoch, model, optim=None, lr_sched=None,
train_metrics=None, val_metrics=None, data_parallel=False):
"""Saves model parameters (state_dict) to file_path.
Args:
checkpoint: (string) filename which needs to be loaded
model: (torch.nn.Module) model for which the parameters are loaded
data_parallel: (bool) if the model is a data parallel model
"""
if os.path.exists(checkpoint):
raise("File already exists {}".format(checkpoint))
model_state_dict = model.state_dict()
    if data_parallel:
        # Strip the 'module.' prefix nn.DataParallel adds to every key
        model_state_dict = {
            k.partition('module.')[2]: v for k, v in model_state_dict.items()}
    optim_state_dict = None if optim is None else optim.state_dict()
    lr_sched_state_dict = None if lr_sched is None else lr_sched.state_dict()
state_dict = {
'epoch': epoch,
'model_state_dict': model_state_dict,
'optim_state_dict': optim_state_dict,
'lr_sched_state_dict': lr_sched_state_dict,
'train_metrics': train_metrics,
'val_metrics': val_metrics
}
torch.save(state_dict, checkpoint)
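# Saving at the end of an epoch (a sketch with hypothetical names; note the
# function refuses to overwrite an existing file):
#
#     save_checkpoint('ckpt/epoch_%d.pt' % epoch, epoch, model, optim=opt,
#                     lr_sched=sched, train_metrics=train_metrics,
#                     val_metrics=val_metrics)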
def model_size(model):
"""
Returns size of the `model` in millions of parameters.
"""
num_train_params = sum(
p.numel() for p in model.parameters() if p.requires_grad)
return num_train_params / 1e6
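# Example (hypothetical `model`):
#
#     logging.info("Model size: %.2fM params", model_size(model))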
def run_time(model, inputs, profiling=False):
    """
    Returns the CPU runtime of `model` on `inputs` (a tuple of tensors) in ms.
    """
    # Warmup so one-time costs (allocations, lazy initialization) don't skew
    # the measurement; no_grad avoids building autograd graphs we never use
    with torch.no_grad():
        for _ in range(100):
            model(*inputs)
    with profile(activities=[ProfilerActivity.CPU],
                 record_shapes=True) as prof:
        with record_function("model_inference"):
            model(*inputs)
# Print profiling results
if profiling:
print(prof.key_averages().table(sort_by="self_cpu_time_total",
row_limit=20))
    # self_cpu_time_total is reported in microseconds; convert to ms
    return prof.profiler.self_cpu_time_total / 1000
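# Timing sketch on a dummy model (shapes are illustrative only; `inputs` must
# be a tuple so it can be unpacked with *inputs):
#
#     model = torch.nn.Linear(256, 256).eval()
#     print("%.2f ms" % run_time(model, (torch.randn(1, 256),)))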
def format_lr_info(optimizer):
    """Formats parameter count and learning rate for each param group."""
    lr_info = ""
    for i, pg in enumerate(optimizer.param_groups):
        # Count parameters in millions (1e6) to match model_size above
        lr_info += " {group %d: params=%.5fM lr=%.1E}" % (
            i, sum(p.numel() for p in pg['params']) / 1e6, pg['lr'])
    return lr_info
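# Example output (illustrative): " {group 0: params=1.23456M lr=1.0E-03}"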