CausalStyleAdv / utils /__init__.py
YuqianFu's picture
Upload folder using huggingface_hub
197d4ca verified
import collections
import collections.abc
import os
import random

import numpy as np
import torch
class AverageMeter:
    """Keep the latest observed value together with a running weighted mean.

    Attributes (all start at ``0``):
        val:   the value passed to the most recent ``update`` call.
        sum:   accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg:   ``sum / count`` as of the most recent update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Return every statistic to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the value of ``n`` samples and refresh ``avg``.

        Raises ZeroDivisionError if the accumulated count is still zero.
        """
        self.val = val
        # In-place accumulation is kept deliberately: with tensor values,
        # ``+=`` and ``x = x + y`` are not equivalent.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def simple_accuracy(preds, labels):
    """Return the fraction of positions at which ``preds`` equals ``labels``.

    ``preds == labels`` must produce an array-like object with a ``.mean()``
    method (e.g. NumPy arrays); plain Python lists are not supported.
    """
    hits = preds == labels
    return hits.mean()
def save_model(args, model):
    """Write ``model``'s state dict to ``<args.output_dir>/<args.name>_checkpoint.bin``.

    When ``model`` exposes a ``.module`` attribute (a wrapper such as
    ``torch.nn.DataParallel``), the inner module's state dict is saved instead
    of the wrapper's.
    """
    target = getattr(model, 'module', model)
    filename = "%s_checkpoint.bin" % args.name
    torch.save(target.state_dict(), os.path.join(args.output_dir, filename))
def load_model(args, model):
    """Load ``<args.output_dir>/<args.name>_checkpoint.bin`` into ``model`` in place.

    If ``model`` exposes a ``.module`` attribute (a wrapper such as
    ``torch.nn.DataParallel``), the weights are loaded into that inner module.
    This matches checkpoints written from the unwrapped module's state dict.
    Tensors are deserialized onto the CPU (``map_location='cpu'``).

    Raises:
        RuntimeError: from ``load_state_dict`` when the checkpoint keys do not
            match the target module's keys.
    """
    # Load into the unwrapped module. A wrapper's own state dict expects
    # 'module.'-prefixed keys, so loading an inner-module checkpoint into the
    # wrapper would fail the strict key check.
    model_to_load = model.module if hasattr(model, 'module') else model
    model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
    model_to_load.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
def count_parameters(model):
    """Return the number of trainable parameters of ``model``, in millions.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total / 1000000
def set_seed(args):
    """Seed Python's ``random``, NumPy and PyTorch with ``args.seed``.

    ``torch.cuda.manual_seed_all`` is also called when ``args.gpus > 0``.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.gpus > 0:
        torch.cuda.manual_seed_all(seed)
def to_device(input, device):
    """Recursively move every tensor contained in ``input`` to ``device``.

    - Tensors are moved with ``.to(device=device, non_blocking=True)``.
    - Strings are returned unchanged.
    - Mappings are rebuilt as plain ``dict``s, and any other sequence
      (list, tuple, ...) as a ``list``; each element is converted recursively.

    Raises:
        TypeError: if a leaf is neither a tensor nor a string.
    """
    # NOTE: ``input`` shadows the builtin; the name is kept because callers
    # may pass it by keyword.
    if torch.is_tensor(input):
        return input.to(device=device, non_blocking=True)
    if isinstance(input, str):
        # str is itself a Sequence; without this early return a one-character
        # string would recurse into itself forever.
        return input
    # The ``collections.Mapping`` / ``collections.Sequence`` aliases were
    # removed in Python 3.10; the ABCs live in ``collections.abc``.
    if isinstance(input, collections.abc.Mapping):
        return {k: to_device(sample, device=device) for k, sample in input.items()}
    if isinstance(input, collections.abc.Sequence):
        return [to_device(sample, device=device) for sample in input]
    raise TypeError(f"Input must contain tensor, dict or list, found {type(input)}")