| |
| |
| |
| from typing import Dict, List, Any |
| from datetime import datetime |
| from itertools import chain |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.distributed as dist |
| import numpy as np |
|
|
| |
|
|
| IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) |
| IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) |
|
|
| DEFAULT_MEAN = IMAGENET_DEFAULT_MEAN |
| DEFAULT_STD = IMAGENET_DEFAULT_STD |
|
|
| |
| |
| |
|
|
|
|
| def unnorm(x): |
| mean = torch.as_tensor(DEFAULT_MEAN, device=x.device)[None, ..., None, None] |
| std = torch.as_tensor(DEFAULT_STD, device=x.device)[None, ..., None, None] |
| return x.mul(std).add(mean) |
|
|
|
|
| |
| def check_nonfinite(x, name=""): |
| rank = dist.get_rank() |
| n_nan = x.isnan().sum() |
| n_inf = x.isinf().sum() |
| if n_nan or n_inf: |
| print(f"[RANK {rank}] {name} is not finite: #nan={n_nan}, #inf={n_inf}") |
| return True |
|
|
| print(f"[RANK {rank}] {name} is OK ...") |
| return False |
|
|
|
|
| def normalize(t, dim, eps=1e-6): |
| """Large default eps for fp16""" |
| return F.normalize(t, dim=dim, eps=eps) |
|
|
|
|
| def timestamp(fmt="%y%m%d-%H%M%S"): |
| return datetime.now().strftime(fmt) |
|
|
|
|
| def merge_dicts_by_key(dics: List[Dict]) -> Dict[Any, List]: |
| """Merge dictionaries by key. All of dicts must have same keys.""" |
| ret = {key: [] for key in dics[0].keys()} |
| for dic in dics: |
| for key, value in dic.items(): |
| ret[key].append(value) |
|
|
| return ret |
|
|
|
|
| def flatten_2d_list(list2d): |
| return list(chain.from_iterable(list2d)) |
|
|
|
|
| def num_params(module): |
| return sum(p.numel() for p in module.parameters()) |
|
|
|
|
| def param_trace(name, module, depth=0, max_depth=999, threshold=0, printf=print): |
| if depth > max_depth: |
| return |
| prefix = " " * depth |
| n_params = num_params(module) |
| if n_params > threshold: |
| printf("{:60s}\t{:10.3f}M".format(prefix + name, n_params / 1024 / 1024)) |
| for n, m in module.named_children(): |
| if depth == 0: |
| child_name = n |
| else: |
| child_name = "{}.{}".format(name, n) |
| param_trace(child_name, m, depth + 1, max_depth, threshold, printf) |
|
|
|
|
| @torch.no_grad() |
| def hash_bn(module): |
| summary = [] |
| for m in module.modules(): |
| if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): |
| w = m.weight.detach().mean().item() |
| b = m.bias.detach().mean().item() |
| rm = m.running_mean.detach().mean().item() |
| rv = m.running_var.detach().mean().item() |
| summary.append((w, b, rm, rv)) |
|
|
| if not summary: |
| return 0.0, 0.0 |
|
|
| w, b, rm, rv = [np.mean(col) for col in zip(*summary)] |
| p = np.mean([w, b]) |
| s = np.mean([rm, rv]) |
|
|
| return p, s |
|
|
|
|
| @torch.no_grad() |
| def hash_params(module): |
| return torch.as_tensor([p.mean() for p in module.parameters()]).mean().item() |
|
|
|
|
| @torch.no_grad() |
| def hashm(module): |
| p = hash_params(module) |
| _, s = hash_bn(module) |
|
|
| return p, s |
|
|