import os, sys, time
import shutil
import datetime
import socket
import contextlib
from collections import defaultdict, deque

import torch
import torch.nn.functional as torch_F
import torch.distributed as dist

class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
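
# Usage sketch (hedged; the values below are illustrative, not from this module):
#     meter = SmoothedValue(window_size=5, fmt="{avg:.2f}")
#     for v in [1.0, 2.0, 3.0]:
#         meter.update(v)
#     print(meter)       # windowed average -> "2.00"
#     meter.global_avg   # average over everything seen so far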

class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # pad the iteration counter to the width of the total count
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)  # time spent waiting on data loading
            yield obj
            iter_time.update(time.time() - end)  # time for the full iteration
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
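
# Usage sketch (hedged; `train_loader` and `step` are hypothetical):
#     logger = MetricLogger(delimiter="  ")
#     for batch in logger.log_every(train_loader, print_freq=50, header="Epoch [0]"):
#         loss = step(batch)                 # hypothetical training step
#         logger.update(loss=loss, lr=1e-3)  # each kwarg becomes a SmoothedValue meter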

def print_eval(opt, loss=None, chamfer=None, depth_metrics=None):
    message = "[eval] "
    if loss is not None: message += "loss:{:.3e}".format(loss.all)
    if chamfer is not None:
        message += " chamfer:{:.4f}|{:.4f}|{:.4f}".format(
            chamfer[0], chamfer[1], (chamfer[0] + chamfer[1]) / 2)
    if depth_metrics is not None:
        for k, v in depth_metrics.items():
            message += "{}:{:.4f}, ".format(k, v)
        message = message[:-2]  # drop the trailing ", "
    print(message)

def update_timer(opt, timer, ep, it_per_ep):
    momentum = 0.99
    timer.elapsed = time.time() - timer.start
    timer.it = timer.it_end - timer.it_start
    # compute speed with moving average
    timer.it_mean = timer.it_mean * momentum + timer.it * (1 - momentum) \
        if timer.it_mean is not None else timer.it
    timer.arrival = timer.it_mean * it_per_ep * (opt.max_epoch - ep)

# move tensors to device in-place
def move_to_device(X, device):
    if isinstance(X, dict):
        for k, v in X.items():
            X[k] = move_to_device(v, device)
    elif isinstance(X, list):
        for i, e in enumerate(X):
            X[i] = move_to_device(e, device)
    elif isinstance(X, tuple) and hasattr(X, "_fields"):  # collections.namedtuple
        dd = X._asdict()
        dd = move_to_device(dd, device)
        return type(X)(**dd)
    elif isinstance(X, torch.Tensor):
        return X.to(device=device, non_blocking=True)
    return X
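
# Usage sketch (hedged; the nested batch below is illustrative):
#     batch = {"image": torch.zeros(4, 3, 64, 64), "meta": {"idx": torch.arange(4)}}
#     batch = move_to_device(batch, "cuda:0")  # recurses into dicts/lists/namedtuples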

# detach tensors
def detach_tensors(X):
    if isinstance(X, dict):
        for k, v in X.items():
            X[k] = detach_tensors(v)
    elif isinstance(X, list):
        for i, e in enumerate(X):
            X[i] = detach_tensors(e)
    elif isinstance(X, tuple) and hasattr(X, "_fields"):  # collections.namedtuple
        dd = X._asdict()
        dd = detach_tensors(dd)
        return type(X)(**dd)
    elif isinstance(X, torch.Tensor):
        return X.detach()
    return X

# this recursion seems to only work for the outer loop when dict_type is not dict
def to_dict(D, dict_type=dict):
    D = dict_type(D)
    for k, v in D.items():
        if isinstance(v, dict):
            D[k] = to_dict(v, dict_type)
    return D
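
# Usage sketch (hedged): converting a nested plain dict into EasyDict (defined
# at the bottom of this module), so keys become attribute-accessible:
#     cfg = to_dict({"optim": {"lr": 1e-4}}, dict_type=EasyDict)
#     cfg.optim.lr  # -> 1e-4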

def get_child_state_dict(state_dict, key):
    # extract the sub-dictionary belonging to child module `key`,
    # stripping any "module." prefix left by (Distributed)DataParallel
    out_dict = {}
    for k, v in state_dict.items():
        if k.startswith("module."):
            param_name = k[7:]
        else:
            param_name = k
        if param_name.startswith("{}.".format(key)):
            out_dict[".".join(param_name.split(".")[1:])] = v
    return out_dict
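
# Usage sketch (hedged; key names are illustrative):
#     full = {"module.encoder.fc.weight": torch.zeros(2, 2),
#             "module.decoder.fc.bias": torch.zeros(2)}
#     get_child_state_dict(full, "encoder")  # -> {"fc.weight": tensor of shape (2, 2)}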

def resume_checkpoint(opt, model, best):
    load_name = "{0}/best.ckpt".format(opt.output_path) if best else "{0}/latest.ckpt".format(opt.output_path)
    checkpoint = torch.load(load_name, map_location=torch.device(opt.device))
    model.graph.module.load_state_dict(checkpoint["graph"], strict=True)
    # load the training stats
    for key in model.__dict__:
        if key.split("_")[0] in ["optim", "sched", "scaler"] and key in checkpoint:
            if opt.device == 0: print("restoring {}...".format(key))
            getattr(model, key).load_state_dict(checkpoint[key])
    # also need to record ep, it, best_val if we are returning
    ep, it = checkpoint["epoch"], checkpoint["iter"]
    best_val = checkpoint["best_val"]
    best_ep = checkpoint.get("best_ep", 0)  # older checkpoints may not store best_ep
    print("resuming from epoch {0} (iteration {1})".format(ep, it))
    return ep, it, best_val, best_ep

def load_checkpoint(opt, model, load_name):
    # load_name has to be given
    checkpoint = torch.load(load_name, map_location=torch.device(opt.device))
    # load individual (possibly partial) children modules
    for name, child in model.graph.module.named_children():
        child_state_dict = get_child_state_dict(checkpoint["graph"], name)
        if child_state_dict:
            if opt.device == 0: print("restoring {}...".format(name))
            child.load_state_dict(child_state_dict, strict=True)
        else:
            if opt.device == 0: print("skipping {}...".format(name))
    return None, None, None, None

def restore_checkpoint(opt, model, load_name=None, resume=False, best=False, evaluate=False):
    # we cannot load and resume at the same time
    assert not (load_name is not None and resume)
    # when resuming we want everything to be the same
    if resume:
        ep, it, best_val, best_ep = resume_checkpoint(opt, model, best)
    # loading is more flexible, as we can load only parts of the model
    else:
        ep, it, best_val, best_ep = load_checkpoint(opt, model, load_name)
    return ep, it, best_val, best_ep

def save_checkpoint(opt, model, ep, it, best_val, best_ep, latest=False, best=False, children=None):
    os.makedirs("{0}/checkpoint".format(opt.output_path), exist_ok=True)
    if isinstance(model.graph, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        graph = model.graph.module
    else:
        graph = model.graph
    if children is not None:
        graph_state_dict = {k: v for k, v in graph.state_dict().items() if k.startswith(children)}
    else:
        graph_state_dict = graph.state_dict()
    checkpoint = dict(
        epoch=ep,
        iter=it,
        best_val=best_val,
        best_ep=best_ep,
        graph=graph_state_dict,
    )
    for key in model.__dict__:
        if key.split("_")[0] in ["optim", "sched", "scaler"]:
            checkpoint.update({key: getattr(model, key).state_dict()})
    torch.save(checkpoint, "{0}/latest.ckpt".format(opt.output_path))
    if best:
        shutil.copy("{0}/latest.ckpt".format(opt.output_path),
                    "{0}/best.ckpt".format(opt.output_path))
    if not latest:
        shutil.copy("{0}/latest.ckpt".format(opt.output_path),
                    "{0}/checkpoint/ep{1}.ckpt".format(opt.output_path, ep))

def check_socket_open(hostname, port):
    # returns True if the port is already in use (i.e. binding to it fails)
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    is_open = False
    try:
        s.bind((hostname, port))
    except socket.error:
        is_open = True
    finally:
        s.close()
    return is_open

def get_layer_dims(layers):
    # return a list of tuples (k_in, k_out)
    return list(zip(layers[:-1], layers[1:]))
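
# Usage sketch (hedged; sizes are illustrative):
#     get_layer_dims([3, 64, 128, 1])  # -> [(3, 64), (64, 128), (128, 1)]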

@contextlib.contextmanager
def suppress(stdout=False, stderr=False):
    with open(os.devnull, "w") as devnull:
        if stdout: old_stdout, sys.stdout = sys.stdout, devnull
        if stderr: old_stderr, sys.stderr = sys.stderr, devnull
        try: yield
        finally:
            if stdout: sys.stdout = old_stdout
            if stderr: sys.stderr = old_stderr
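
# Usage sketch (hedged; `noisy_third_party_call` is hypothetical):
#     with suppress(stdout=True):
#         noisy_third_party_call()  # anything it prints goes to os.devnull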

def toggle_grad(model, requires_grad):
    for p in model.parameters():
        p.requires_grad_(requires_grad)

def compute_grad2(d_outs, x_in):
    # squared gradient norm of the outputs w.r.t. the inputs, per batch element,
    # as used in R1-style gradient penalties for GAN discriminators
    d_outs = [d_outs] if not isinstance(d_outs, list) else d_outs
    reg = 0
    for d_out in d_outs:
        batch_size = x_in.size(0)
        grad_dout = torch.autograd.grad(
            outputs=d_out.sum(), inputs=x_in,
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]
        grad_dout2 = grad_dout.pow(2)
        assert grad_dout2.size() == x_in.size()
        reg += grad_dout2.view(batch_size, -1).sum(1)
    return reg / len(d_outs)
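
# Usage sketch (hedged; `discriminator` is hypothetical, and x_real must require grad):
#     x_real.requires_grad_(True)
#     d_real = discriminator(x_real)
#     r1_penalty = compute_grad2(d_real, x_real).mean()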

def interpolate_depth(depth_input, mask_input, size, bg_depth=20):
    assert len(depth_input.shape) == len(mask_input.shape) == 4
    mask = (mask_input > 0.5).float()
    # interpolate the masked depth and the mask itself, then renormalize by the
    # interpolated mask so that invalid (masked-out) pixels do not bleed into valid ones
    depth_valid = depth_input * mask
    depth_valid = torch_F.interpolate(depth_valid, size, mode='bilinear', align_corners=False)
    mask = torch_F.interpolate(mask, size, mode='bilinear', align_corners=False)
    depth_out = depth_valid / (mask + 1.e-6)
    mask_binary = (mask > 0.5).float()
    # fill the background with a constant far depth
    depth_out = depth_out * mask_binary + bg_depth * (1 - mask_binary)
    return depth_out, mask_binary
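
# Usage sketch (hedged; shapes are illustrative):
#     depth = torch.rand(1, 1, 128, 128)
#     mask = (depth > 0.2).float()
#     depth_64, mask_64 = interpolate_depth(depth, mask, size=(64, 64))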

def interpolate_coordmap(coord_map, mask_input, size, bg_coord=0):
    # same mask-aware interpolation as interpolate_depth, applied to coordinate maps
    assert len(coord_map.shape) == len(mask_input.shape) == 4
    mask = (mask_input > 0.5).float()
    coord_valid = coord_map * mask
    coord_valid = torch_F.interpolate(coord_valid, size, mode='bilinear', align_corners=False)
    mask = torch_F.interpolate(mask, size, mode='bilinear', align_corners=False)
    coord_out = coord_valid / (mask + 1.e-6)
    mask_binary = (mask > 0.5).float()
    coord_out = coord_out * mask_binary + bg_coord * (1 - mask_binary)
    return coord_out, mask_binary

def cleanup():
    dist.destroy_process_group()

def is_port_in_use(port):
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(('localhost', port)) == 0

def setup(rank, world_size, port_no):
    full_address = 'tcp://127.0.0.1:' + str(port_no)
    dist.init_process_group("nccl", init_method=full_address, rank=rank, world_size=world_size)
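
# Usage sketch (hedged; `main_worker` is a hypothetical per-rank entry point):
#     import torch.multiprocessing as mp
#     def main_worker(rank, world_size, port_no):
#         setup(rank, world_size, port_no)
#         ...  # build model, wrap in DistributedDataParallel, train
#         cleanup()
#     mp.spawn(main_worker, args=(4, 23456), nprocs=4)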

def print_grad(grad, prefix=''):
    print("{} --- Grad Abs Mean, Grad Max, Grad Min: {:.5f} | {:.5f} | {:.5f}".format(
        prefix, grad.abs().mean().item(), grad.max().item(), grad.min().item()))

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
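
# Usage sketch (hedged; values are illustrative):
#     meter = AverageMeter()
#     meter.update(0.5, n=32)  # e.g. a per-batch loss, weighted by batch size
#     meter.update(0.3, n=32)
#     meter.avg  # -> 0.4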

class EasyDict(dict):
    def __init__(self, d=None, **kwargs):
        if d is None:
            d = {}
        else:
            d = dict(d)
        if kwargs:
            d.update(**kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # class attributes
        for k in self.__class__.__dict__.keys():
            if not (k.startswith('__') and k.endswith('__')) and k not in ('update', 'pop'):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        # keep attribute and item access in sync, converting nested dicts recursively
        if isinstance(value, (list, tuple)):
            value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        d = dict(e) if e else dict()  # copy so the caller's dict is not mutated
        d.update(f)
        for k in d:
            setattr(self, k, d[k])

    def pop(self, k, d=None):
        delattr(self, k)
        return super(EasyDict, self).pop(k, d)
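
# Usage sketch (hedged; keys are illustrative):
#     opt = EasyDict({"device": 0, "optim": {"lr": 1e-4}})
#     opt.optim.lr        # nested dicts become EasyDicts -> 1e-4
#     opt.max_epoch = 50  # attribute assignment also sets the dict item
#     opt["max_epoch"]    # -> 50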