"""Small training helpers: metric tracking, checkpoint I/O, seeding, device moves."""

import collections
import collections.abc
import os
import random

import numpy as np
import torch


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new observation.

        Args:
            val: The latest value (stored as ``self.val``).
            n: Weight of the observation; ``val * n`` is added to the running
                sum and ``n`` to the running count.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. first call with n=0), which
        # would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0


def simple_accuracy(preds, labels):
    """Return the mean of the elementwise comparison ``preds == labels``.

    Both arguments are expected to support elementwise ``==`` producing an
    object with a ``.mean()`` method (e.g. NumPy arrays).
    """
    return (preds == labels).mean()


def _unwrap(model):
    """Return ``model.module`` when the model is wrapped, else ``model``."""
    return model.module if hasattr(model, 'module') else model


def _checkpoint_path(args):
    """Build the checkpoint path ``<args.output_dir>/<args.name>_checkpoint.bin``."""
    return os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)


def save_model(args, model):
    """Save the state dict of ``model`` (unwrapped if it has ``.module``).

    Args:
        args: Object providing ``output_dir`` and ``name`` attributes.
        model: The model, optionally wrapped in a container exposing the real
            model as ``.module``.
    """
    model_to_save = _unwrap(model)
    torch.save(model_to_save.state_dict(), _checkpoint_path(args))


def load_model(args, model):
    """Load the checkpoint written by ``save_model`` into ``model``.

    The state dict is loaded into the unwrapped module so that its keys match
    those written by ``save_model`` (which never carry a ``module.`` prefix).

    Args:
        args: Object providing ``output_dir`` and ``name`` attributes.
        model: The model to populate, optionally wrapped with ``.module``.
    """
    model_to_load = _unwrap(model)
    state_dict = torch.load(_checkpoint_path(args), map_location='cpu')
    model_to_load.load_state_dict(state_dict)


def count_parameters(model):
    """Return the number of trainable parameters of ``model``, in millions."""
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params / 1000000


def set_seed(args):
    """Seed Python, NumPy and PyTorch RNGs from ``args.seed``.

    CUDA generators are seeded as well when ``args.gpus`` is positive.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.gpus > 0:
        torch.cuda.manual_seed_all(args.seed)


def to_device(input, device):
    """Recursively move every tensor inside ``input`` to ``device``.

    Args:
        input: A tensor, a string, a mapping, or a sequence nesting those.
        device: Target device passed to ``Tensor.to``.

    Returns:
        Tensors moved to ``device``; strings unchanged; mappings rebuilt as
        ``dict``; sequences (including tuples) rebuilt as ``list``.

    Raises:
        TypeError: If an element is none of the supported types.
    """
    if torch.is_tensor(input):
        return input.to(device=device, non_blocking=True)
    # Strings are sequences too, so they must be handled before the Sequence
    # branch to avoid recursing character by character.
    if isinstance(input, str):
        return input
    # The ABCs live in collections.abc; the collections.* aliases were removed
    # in Python 3.10.
    if isinstance(input, collections.abc.Mapping):
        return {k: to_device(sample, device=device) for k, sample in input.items()}
    if isinstance(input, collections.abc.Sequence):
        return [to_device(sample, device=device) for sample in input]
    raise TypeError(f"Input must contain tensor, dict or list, found {type(input)}")