| import sys | |
| import wandb | |
| from time import sleep | |
| import os | |
def init_wandb(project_name, model_name, config, **wandb_kwargs):
    """Start a wandb run, retrying until ``wandb.init`` succeeds.

    Args:
        project_name: Passed to ``wandb.init`` as ``project``.
        model_name: Passed to ``wandb.init`` as ``name``.
        config: Passed to ``wandb.init`` as ``config``.
        **wandb_kwargs: Extra keyword arguments forwarded unchanged to
            ``wandb.init``.

    Returns:
        The value returned by the first ``wandb.init`` call that does not
        raise.

    Notes:
        * ``save_code=True`` is always passed to ``wandb.init``.
        * The ``WANDB__SERVICE_WAIT`` environment variable is overwritten
          with ``'300'`` for the whole process.
        * There is no retry limit and *every* ``Exception`` is retried, so a
          non-transient failure (for example a ``TypeError`` caused by a
          duplicated keyword in ``wandb_kwargs``) loops forever.
    """
    # NOTE(review): presumably the wandb service wait timeout in seconds;
    # confirm against the wandb environment-variable documentation.
    os.environ['WANDB__SERVICE_WAIT'] = '300'

    while True:
        try:
            # Returning straight from the try ends the retry loop on success.
            return wandb.init(
                project=project_name,
                name=model_name,
                save_code=True,
                config=config,
                **wandb_kwargs,
            )
        except Exception as err:
            # Report on stderr, pause briefly, then go around again.
            print('wandb connection error', file=sys.stderr)
            print(f'error: {err}', file=sys.stderr)
            sleep(1)
            print('retrying..', file=sys.stderr)
def str2bool(v):
    """Convert a textual boolean flag to a ``bool``.

    Args:
        v: Either a ``bool`` (returned unchanged) or a string-like object
            providing ``.lower()``. Matching is case-insensitive.

    Returns:
        ``True`` for ``'yes'``, ``'true'``, ``'t'``, ``'y'``, ``'1'``;
        ``False`` for ``'no'``, ``'false'``, ``'f'``, ``'n'``, ``'0'``.

    Raises:
        ValueError: If ``v`` is not one of the recognised spellings, or is
            neither a ``bool`` nor an object with a ``.lower()`` method.
    """
    if isinstance(v, bool):
        return v

    try:
        lowered = v.lower()
    except AttributeError:
        # Non-string input (None, int, ...) is reported the same way as an
        # unrecognised string instead of leaking an AttributeError.
        raise ValueError(
            f'expected a bool or a boolean-like string, got {v!r}'
        ) from None

    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise ValueError(f'cannot interpret {v!r} as a boolean')
class AverageMeter:
    """Track the latest value and the weighted running average of a metric.

    Attributes are plain and public:
        name:  label used when the meter is rendered with ``str()``.
        fmt:   format-spec suffix (e.g. ``':.3f'``) applied to ``val`` and
               ``avg`` in ``__str__``.
        val:   value passed to the most recent ``update`` call.
        sum:   running total of ``val * n`` over all updates.
        count: running total of ``n`` over all updates.
        avg:   ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self, name, fmt=':f'):
        """Create a meter called ``name`` and zero all statistics."""
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero ``val``, ``avg``, ``sum`` and ``count``."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: The new observation (e.g. a per-batch mean).
            n: Weight of the observation (e.g. the batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. update(x, n=0) on a fresh
        # meter), which previously raised ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Render as ``'<name> <val> (<avg>)'`` using ``self.fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        # Formatting from __dict__ keeps every instance attribute available
        # to the template.
        return fmtstr.format(**self.__dict__)