Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| from easydict import EasyDict | |
| from .misc import BlackHole | |
def get_optimizer(cfg, model):
    """Construct an optimizer for `model` from a config object.

    Args:
        cfg: Config with attributes `type` (currently only 'adam'),
            `lr`, `weight_decay`, `beta1`, `beta2`.
        model: Module whose parameters will be optimized.

    Returns:
        A torch.optim.Optimizer instance.

    Raises:
        NotImplementedError: If `cfg.type` is not a supported optimizer.
    """
    # Guard clause: reject unsupported optimizer types up front.
    if cfg.type != 'adam':
        raise NotImplementedError('Optimizer not supported: %s' % cfg.type)
    return torch.optim.Adam(
        model.parameters(),
        lr=cfg.lr,
        betas=(cfg.beta1, cfg.beta2),
        weight_decay=cfg.weight_decay,
    )
def get_scheduler(cfg, optimizer):
    """Construct an LR scheduler for `optimizer` from a config object.

    Args:
        cfg: Config with attribute `type` in {None, 'plateau', 'multistep',
            'exp'} plus the per-scheduler hyperparameters read below.
        optimizer: The torch optimizer to schedule.

    Returns:
        A torch.optim.lr_scheduler instance, or a no-op BlackHole when
        `cfg.type` is None.

    Raises:
        NotImplementedError: If `cfg.type` is not a supported scheduler.
    """
    if cfg.type is None:
        # No scheduling requested: BlackHole absorbs .step() etc. silently.
        return BlackHole()
    elif cfg.type == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=cfg.factor,
            patience=cfg.patience,
            min_lr=cfg.min_lr,
        )
    elif cfg.type == 'multistep':
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=cfg.milestones,
            gamma=cfg.gamma,
        )
    elif cfg.type == 'exp':
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer,
            gamma=cfg.gamma,
        )
    # BUGFIX: removed a duplicated, unreachable `elif cfg.type is None`
    # branch — the None case is already handled first.
    else:
        raise NotImplementedError('Scheduler not supported: %s' % cfg.type)
def get_warmup_sched(cfg, optimizer):
    """Build a linear LR warm-up scheduler.

    The learning rate ramps linearly from 0 to the base LR over
    `cfg.max_iters` iterations, then stays at the base LR.

    Args:
        cfg: Config with attribute `max_iters`, or None to disable warm-up.
        optimizer: The torch optimizer to wrap.

    Returns:
        A LambdaLR scheduler, or a no-op BlackHole when cfg is None.
    """
    if cfg is None:
        return BlackHole()

    def _ramp(it):
        # Linear ramp until max_iters, constant 1x afterwards.
        if it <= cfg.max_iters:
            return it / cfg.max_iters
        return 1

    factors = [_ramp] * len(optimizer.param_groups)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, factors)
def log_losses(out, it, tag, logger=BlackHole(), writer=BlackHole(), others=None):
    """Log a dict of loss tensors to a text logger and a scalar writer.

    Args:
        out: Dict of scalar tensors; must contain key 'overall'. Every
            other key is logged as an individual loss term.
        it: Iteration number.
        tag: Prefix for log lines and writer scalar names (e.g. 'train').
        logger: Object with .info(str); defaults to a no-op BlackHole.
        writer: Object with .add_scalar(name, value, step) and .flush();
            defaults to a no-op BlackHole.
        others: Optional dict of extra scalar metrics to log alongside
            the losses (None means no extras).
    """
    # BUGFIX: `others` used to default to a shared mutable {}; use a None
    # sentinel instead (backward compatible for all callers).
    if others is None:
        others = {}
    logstr = '[%s] Iter %05d' % (tag, it)
    logstr += ' | loss %.4f' % out['overall'].item()
    for k, v in out.items():
        if k == 'overall':
            continue
        logstr += ' | loss(%s) %.4f' % (k, v.item())
    for k, v in others.items():
        logstr += ' | %s %2.4f' % (k, v)
    logger.info(logstr)
    # Mirror the same values into the scalar writer.
    for k, v in out.items():
        if k == 'overall':
            writer.add_scalar('%s/loss' % tag, v, it)
        else:
            writer.add_scalar('%s/loss_%s' % (tag, k), v, it)
    for k, v in others.items():
        writer.add_scalar('%s/%s' % (tag, k), v, it)
    writer.flush()
class ValidationLossTape(object):
    """Accumulates validation losses over batches and logs sample averages.

    Tensors passed to update() are stored as detached clones, so the tape
    never keeps a computation graph alive across batches.
    """

    def __init__(self):
        super().__init__()
        self.accumulate = {}  # loss name -> summed detached tensor
        self.others = {}      # extra metric name -> summed detached tensor
        self.total = 0        # number of samples accumulated so far

    def update(self, out, n, others=None):
        """Accumulate one batch of losses.

        Args:
            out: Dict of scalar loss tensors for this batch.
            n: Number of samples in this batch (used for averaging in log()).
            others: Optional dict of extra scalar tensors to accumulate
                (None means no extras).
        """
        # BUGFIX: `others` used to default to a shared mutable {}; use a
        # None sentinel instead (backward compatible for all callers).
        if others is None:
            others = {}
        self.total += n
        for k, v in out.items():
            if k not in self.accumulate:
                self.accumulate[k] = v.clone().detach()
            else:
                self.accumulate[k] += v.clone().detach()
        for k, v in others.items():
            if k not in self.others:
                self.others[k] = v.clone().detach()
            else:
                self.others[k] += v.clone().detach()

    def log(self, it, logger=BlackHole(), writer=BlackHole(), tag='val'):
        """Log per-sample averages via log_losses and return the average
        'overall' loss (useful as a validation metric for schedulers)."""
        avg = EasyDict({k: v / self.total for k, v in self.accumulate.items()})
        avg_others = EasyDict({k: v / self.total for k, v in self.others.items()})
        log_losses(avg, it, tag, logger, writer, others=avg_others)
        return avg['overall']
def recursive_to(obj, device):
    """Recursively move all tensors inside `obj` to `device`.

    Containers (list/tuple/dict) are rebuilt with their tensor leaves
    moved; non-tensor leaves are returned unchanged.

    Args:
        obj: A tensor, or an arbitrarily nested list/tuple/dict of them.
        device: Target device — the string 'cpu', a torch.device, or a
            CUDA device spec accepted by Tensor.cuda().

    Returns:
        A structure of the same shape with tensors on `device`.
    """
    if isinstance(obj, torch.Tensor):
        # BUGFIX: recognize torch.device('cpu') as a CPU target too — it
        # previously fell into the .cuda() path and only worked because
        # the RuntimeError fallback caught it.
        if device == 'cpu' or (isinstance(device, torch.device) and device.type == 'cpu'):
            return obj.cpu()
        try:
            # Prefer the async copy when the target is a CUDA device.
            return obj.cuda(device=device, non_blocking=True)
        except RuntimeError:
            # Fall back to the generic transfer (e.g. unusual device specs).
            return obj.to(device)
    elif isinstance(obj, list):
        return [recursive_to(o, device=device) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(recursive_to(o, device=device) for o in obj)
    elif isinstance(obj, dict):
        return {k: recursive_to(v, device=device) for k, v in obj.items()}
    else:
        return obj
def reweight_loss_by_sequence_length(length, max_length, mode='sqrt'):
    """Compute a loss weight from a sequence length.

    Args:
        length: Sequence length of the current sample.
        max_length: Reference (maximum) sequence length.
        mode: 'sqrt' (square root of the length ratio), 'linear'
            (the ratio itself), or None (constant weight 1.0).

    Returns:
        The scalar weight.

    Raises:
        ValueError: If `mode` is not one of the supported options.
    """
    if mode == 'sqrt':
        return np.sqrt(length / max_length)
    if mode == 'linear':
        return length / max_length
    if mode is None:
        return 1.0
    raise ValueError('Unknown reweighting mode: %s' % mode)
def sum_weighted_losses(losses, weights):
    """Sum a dict of losses, optionally weighting each term.

    Args:
        losses: Dict of scalar tensors (or numbers).
        weights: Dict mapping the same keys to scalar weights, or None
            for an unweighted sum.

    Returns:
        The (weighted) sum, starting from 0 so an empty dict yields 0.
    """
    # Hoist the loop-invariant None check out of the per-key loop and let
    # sum() do the accumulation.
    if weights is None:
        return sum(losses.values(), 0)
    return sum(weights[k] * losses[k] for k in losses)
def count_parameters(model):
    """Return the total number of scalar parameters in `model`
    (counts every parameter, trainable or not)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total