# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
# Modified by Chaodong Xiao.
"""
import io
import os
import time
from collections import defaultdict, deque
import datetime

import torch
import torch.distributed as dist
from typing import Any, NewType
from torch.autograd import Function


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    The last ``window_size`` values are kept in a deque (used by ``median``,
    ``avg``, ``max`` and ``value``); ``total`` and ``count`` accumulate over
    the whole series (used by ``global_avg``).
    """

    def __init__(self, window_size=20, fmt=None):
        """
        Args:
            window_size: maximum number of recent values kept in the window.
            fmt: format string used by ``__str__``; it may reference
                ``median``, ``avg``, ``global_avg``, ``max`` and ``value``.
        """
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value`` once in the window and ``n`` times in the totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        All-reduce ``count`` and ``total`` across processes.

        Does nothing when torch.distributed is unavailable or uninitialized.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        # Pack both scalars into a single tensor so one all_reduce suffices.
        t = torch.tensor([self.count, self.total],
                         dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the values currently in the window."""
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Mean over the whole series (``total / count``)."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with periodic printing."""

    def __init__(self, delimiter="\t"):
        """
        Args:
            delimiter: string placed between fields of a printed log line.
        """
        # Unknown meter names get a default SmoothedValue on first update.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Update the meter named by each keyword with its value.

        Tensor values are converted with ``.item()``; every value must then
        be a ``float`` or an ``int``.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose meters as attributes, e.g. ``logger.loss``.

        Raises:
            AttributeError: if ``attr`` is neither a meter name nor an
                instance attribute.
        """
        # __getattr__ only runs after normal lookup has failed. Read
        # ``meters`` through __dict__: touching ``self.meters`` here would
        # recurse without bound on an instance whose __init__ has not run
        # (e.g. while copy/pickle is reconstructing the object).
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = [
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        ]
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress.

        A line is printed every ``print_freq`` items and on the last item,
        and a total-time summary is printed once the iterable is exhausted.

        Args:
            iterable: sized iterable to walk (``len()`` must be supported).
            print_freq: number of items between two progress lines.
            header: optional prefix for every printed line.

        Yields:
            The items of ``iterable``, unchanged and in order.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        num_items = len(iterable)
        # Pad the running index to the width of the total item count.
        space_fmt = ':' + str(len(str(num_items))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the next item.
            data_time.update(time.time() - end)
            yield obj
            # Full iteration time: the wait above plus the consumer's work.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_items - 1:
                eta_seconds = iter_time.global_avg * (num_items - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fields = dict(
                    eta=eta_string,
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time))
                if cuda_available:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, num_items, **fields))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Guard the per-item average so an empty iterable does not raise
        # ZeroDivisionError.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_items, 1)))


def _load_checkpoint_for_ema(model_ema, checkpoint):
    """
    Workaround for ModelEma._load_checkpoint to accept an already-loaded object

    The checkpoint is re-serialized into an in-memory buffer under the
    ``'state_dict_ema'`` key and that buffer is handed to
    ``model_ema._load_checkpoint``.
    """
    mem_file = io.BytesIO()
    torch.save({'state_dict_ema': checkpoint}, mem_file)
    mem_file.seek(0)
    model_ema._load_checkpoint(mem_file)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process

    The builtin ``print`` is replaced process-wide by a wrapper that only
    prints when ``is_master`` is true or the call passes ``force=True``.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        # ``force`` is consumed here and never reaches the real print.
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    """Return the distributed world size, or 1 when not distributed."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """Return this process's distributed rank, or 0 when not distributed."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    """Return True when this process has rank 0."""
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    """Call ``torch.save`` with the given arguments on the main process only."""
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Initialize torch.distributed from environment variables.

    With ``RANK`` and ``WORLD_SIZE`` set, rank, world size and GPU index are
    read from ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``. Otherwise, with
    ``SLURM_PROCID`` set, the rank comes from it and the GPU index is the
    rank modulo the visible device count. With neither, ``args.distributed``
    is set to False and the function returns.

    Args:
        args: namespace-like object; ``args.dist_url`` is read, and
            ``rank``, ``gpu``, ``distributed`` and ``dist_backend`` are
            written (plus ``world_size`` in the ``RANK`` branch).
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # NOTE(review): this branch does not set args.world_size, so it is
        # presumably expected to be provided by the caller - confirm.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend,
                                         init_method=args.dist_url,
                                         world_size=args.world_size,
                                         rank=args.rank)
    torch.distributed.barrier()
    # From here on only rank 0 prints, unless print(..., force=True) is used.
    setup_for_distributed(args.rank == 0)


"""
Quantization
"""

# A type where each element is in {-1, 1}
BinaryTensor = NewType('BinaryTensor', torch.Tensor)


def binary_sign(x: torch.Tensor) -> BinaryTensor:
    """Return -1 if x < 0, 1 if x >= 0."""
    # sign() maps 0 to 0, so add 1 exactly where x == 0.
    return x.sign() + (x == 0).type(torch.float)


class STESign(Function):
    """
    Binarize tensor using sign function.
    Straight-Through Estimator (STE) is used to approximate the gradient of
    sign function.
    """

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor) -> BinaryTensor:
        """
        Return a Sign tensor.

        Args:
            ctx: context
            x: input tensor

        Returns:
            Sign(x) = (x>=0) - (x<0)
            Output type is float tensor where each element is either -1 or 1.
        """
        ctx.save_for_backward(x)
        sign_x = binary_sign(x)
        return sign_x

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
        """
        Compute gradient using STE.

        The incoming gradient is passed through unchanged where
        -1 <= x <= 1 and zeroed elsewhere.

        Args:
            ctx: context
            grad_output: gradient w.r.t. output of Sign

        Returns:
            Gradient w.r.t. input of the Sign function
        """
        x, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[x.gt(1)] = 0
        grad_input[x.lt(-1)] = 0
        return grad_input


binarize = STESign.apply


class SymQuantizer(Function):
    """
    uniform quantization
    """

    @staticmethod
    def forward(ctx, input, clip_val, num_bits, layerwise=False):
        """
        Quantize ``input`` with a scale derived from its absolute maximum.

        :param ctx: autograd context; ``input`` and ``clip_val`` are saved
            on it for the backward pass
        :param input: tensor to be quantized
        :param clip_val: clip val; only saved here and used in ``backward``
        :param num_bits: number of bits
        :param layerwise: if True, one scale is taken from the maximum over
            the whole tensor; otherwise ``input`` must be 4-D and the
            maximum is taken along dim -2
        :return: quantized tensor
        """
        ctx.save_for_backward(input, clip_val)
        if layerwise:
            max_input = torch.max(torch.abs(input)).expand_as(input)
        else:
            assert input.ndimension() == 4
            max_input = (
                torch.max(torch.abs(input), dim=-2, keepdim=True)[0]
                .expand_as(input)
                .detach()
            )
        # Map the largest magnitude to 2**(num_bits - 1) - 1; the 1e-6 terms
        # keep both divisions away from zero.
        s = (2 ** (num_bits - 1) - 1) / (max_input + 1e-6)
        output = torch.round(input * s).div(s + 1e-6)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Pass the gradient through, zeroing it where ``input`` is at or
        beyond the clip bounds ``clip_val[0]`` and ``clip_val[1]``.

        :param ctx: saved non-clipped full-precision tensor and clip_val
        :param grad_output: gradient wrt the quantized tensor
        :return: estimated gradient wrt the full-precision tensor, followed
            by None for ``clip_val``, ``num_bits`` and ``layerwise``
        """
        input, clip_val = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input.ge(clip_val[1])] = 0
        grad_input[input.le(clip_val[0])] = 0
        return grad_input, None, None, None


symquantize = SymQuantizer.apply