Spaces:
Running
on
T4
Running
on
T4
| import glob | |
| import json | |
| import os | |
| import pathlib | |
| import random | |
| import re | |
| import sys | |
| import time | |
| import matplotlib.pylab as plt | |
| import numpy as np | |
| import torch | |
| import yaml | |
| from torch import distributed as dist | |
| from torch.nn.utils import weight_norm | |
| def seed_everything(seed, cudnn_deterministic=False): | |
| """ | |
| Function that sets seed for pseudo-random number generators in: | |
| pytorch, numpy, python.random | |
| Args: | |
| seed: the integer value seed for global random state | |
| """ | |
| if seed is not None: | |
| # print(f"Global seed set to {seed}") | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| torch.cuda.manual_seed_all(seed) | |
| # if cudnn_deterministic: | |
| # torch.backends.cudnn.deterministic = True | |
| # warnings.warn('You have chosen to seed training. ' | |
| # 'This will turn on the CUDNN deterministic setting, ' | |
| # 'which can slow down your training considerably! ' | |
| # 'You may see unexpected behavior when restarting ' | |
| # 'from checkpoints.') | |
| def is_primary(): | |
| return get_rank() == 0 | |
| def get_rank(): | |
| if not dist.is_available(): | |
| return 0 | |
| if not dist.is_initialized(): | |
| return 0 | |
| return dist.get_rank() | |
| def load_yaml_config(path): | |
| with open(path) as f: | |
| config = yaml.full_load(f) | |
| return config | |
| def save_config_to_yaml(config, path): | |
| assert path.endswith('.yaml') | |
| with open(path, 'w') as f: | |
| f.write(yaml.dump(config)) | |
| f.close() | |
| def save_dict_to_json(d, path, indent=None): | |
| json.dump(d, open(path, 'w'), indent=indent) | |
| def load_dict_from_json(path): | |
| return json.load(open(path, 'r')) | |
| def write_args(args, path): | |
| args_dict = dict((name, getattr(args, name)) for name in dir(args) | |
| if not name.startswith('_')) | |
| with open(path, 'a') as args_file: | |
| args_file.write('==> torch version: {}\n'.format(torch.__version__)) | |
| args_file.write( | |
| '==> cudnn version: {}\n'.format(torch.backends.cudnn.version())) | |
| args_file.write('==> Cmd:\n') | |
| args_file.write(str(sys.argv)) | |
| args_file.write('\n==> args:\n') | |
| for k, v in sorted(args_dict.items()): | |
| args_file.write(' %s: %s\n' % (str(k), str(v))) | |
| args_file.close() | |
| class Logger(object): | |
| def __init__(self, args): | |
| self.args = args | |
| self.save_dir = args.save_dir | |
| self.is_primary = is_primary() | |
| if self.is_primary: | |
| os.makedirs(self.save_dir, exist_ok=True) | |
| # save the args and config | |
| self.config_dir = os.path.join(self.save_dir, 'configs') | |
| os.makedirs(self.config_dir, exist_ok=True) | |
| file_name = os.path.join(self.config_dir, 'args.txt') | |
| write_args(args, file_name) | |
| log_dir = os.path.join(self.save_dir, 'logs') | |
| if not os.path.exists(log_dir): | |
| os.makedirs(log_dir, exist_ok=True) | |
| self.text_writer = open(os.path.join(log_dir, 'log.txt'), | |
| 'a') # 'w') | |
| if args.tensorboard: | |
| self.log_info('using tensorboard') | |
| self.tb_writer = torch.utils.tensorboard.SummaryWriter( | |
| log_dir=log_dir | |
| ) # tensorboard.SummaryWriter(log_dir=log_dir) | |
| else: | |
| self.tb_writer = None | |
| def save_config(self, config): | |
| if self.is_primary: | |
| save_config_to_yaml(config, | |
| os.path.join(self.config_dir, 'config.yaml')) | |
| def log_info(self, info, check_primary=True): | |
| if self.is_primary or (not check_primary): | |
| print(info) | |
| if self.is_primary: | |
| info = str(info) | |
| time_str = time.strftime('%Y-%m-%d-%H-%M') | |
| info = '{}: {}'.format(time_str, info) | |
| if not info.endswith('\n'): | |
| info += '\n' | |
| self.text_writer.write(info) | |
| self.text_writer.flush() | |
| def add_scalar(self, **kargs): | |
| """Log a scalar variable.""" | |
| if self.is_primary: | |
| if self.tb_writer is not None: | |
| self.tb_writer.add_scalar(**kargs) | |
| def add_scalars(self, **kargs): | |
| """Log a scalar variable.""" | |
| if self.is_primary: | |
| if self.tb_writer is not None: | |
| self.tb_writer.add_scalars(**kargs) | |
| def add_image(self, **kargs): | |
| """Log a scalar variable.""" | |
| if self.is_primary: | |
| if self.tb_writer is not None: | |
| self.tb_writer.add_image(**kargs) | |
| def add_images(self, **kargs): | |
| """Log a scalar variable.""" | |
| if self.is_primary: | |
| if self.tb_writer is not None: | |
| self.tb_writer.add_images(**kargs) | |
| def close(self): | |
| if self.is_primary: | |
| self.text_writer.close() | |
| self.tb_writer.close() | |
| def plot_spectrogram(spectrogram): | |
| fig, ax = plt.subplots(figsize=(10, 2)) | |
| im = ax.imshow( | |
| spectrogram, aspect="auto", origin="lower", interpolation='none') | |
| plt.colorbar(im, ax=ax) | |
| fig.canvas.draw() | |
| plt.close() | |
| return fig | |
| def init_weights(m, mean=0.0, std=0.01): | |
| classname = m.__class__.__name__ | |
| if classname.find("Conv") != -1: | |
| m.weight.data.normal_(mean, std) | |
| def apply_weight_norm(m): | |
| classname = m.__class__.__name__ | |
| if classname.find("Conv") != -1: | |
| weight_norm(m) | |
| def get_padding(kernel_size, dilation=1): | |
| return int((kernel_size * dilation - dilation) / 2) | |
| def load_checkpoint(filepath, device): | |
| assert os.path.isfile(filepath) | |
| print("Loading '{}'".format(filepath)) | |
| checkpoint_dict = torch.load(filepath, map_location=device) | |
| print("Complete.") | |
| return checkpoint_dict | |
| def save_checkpoint(filepath, obj, num_ckpt_keep=5): | |
| name = re.match(r'(do|g)_\d+', pathlib.Path(filepath).name).group(1) | |
| ckpts = sorted(pathlib.Path(filepath).parent.glob(f'{name}_*')) | |
| if len(ckpts) > num_ckpt_keep: | |
| [os.remove(c) for c in ckpts[:-num_ckpt_keep]] | |
| print("Saving checkpoint to {}".format(filepath)) | |
| torch.save(obj, filepath) | |
| print("Complete.") | |
| def scan_checkpoint(cp_dir, prefix): | |
| pattern = os.path.join(cp_dir, prefix + '????????') | |
| cp_list = glob.glob(pattern) | |
| if len(cp_list) == 0: | |
| return None | |
| return sorted(cp_list)[-1] | |