| | |
| | |
| | |
| | from typing import Dict, List, Any |
| | from datetime import datetime |
| | from itertools import chain |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.distributed as dist |
| | import numpy as np |
| |
|
| | |
| |
|
# Per-channel ImageNet normalization statistics (conventionally RGB channel
# order — confirm against the data pipeline).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Statistics used by default in this module (aliases of the ImageNet values).
DEFAULT_MEAN, DEFAULT_STD = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
| |
|
| | |
| | |
| | |
| |
|
| |
|
def unnorm(x, mean=None, std=None):
    """Undo per-channel normalization, returning ``x * std + mean``.

    Args:
        x: normalized tensor. The statistics are reshaped to ``(1, C, 1, 1)``,
            so x is presumably ``(B, C, H, W)`` — TODO confirm at call sites.
        mean: per-channel means; defaults to ``DEFAULT_MEAN`` when None.
        std: per-channel standard deviations; defaults to ``DEFAULT_STD``
            when None.

    Returns:
        A new tensor (``x`` is not modified in place).
    """
    # Resolve defaults at call time so callers may override the statistics
    # while existing ``unnorm(x)`` calls keep using the module defaults.
    if mean is None:
        mean = DEFAULT_MEAN
    if std is None:
        std = DEFAULT_STD
    # (C,) -> (1, C, 1, 1) so the stats broadcast over dimension 1 of x.
    mean = torch.as_tensor(mean, device=x.device)[None, ..., None, None]
    std = torch.as_tensor(std, device=x.device)[None, ..., None, None]
    return x.mul(std).add(mean)
| |
|
| |
|
| | |
def check_nonfinite(x, name=""):
    """Report whether tensor ``x`` contains any NaN or Inf values.

    Prints a one-line status message tagged with the distributed rank.

    Args:
        x: tensor to inspect.
        name: label included in the printed message.

    Returns:
        True if ``x`` has at least one NaN or Inf, otherwise False.
    """
    # dist.get_rank() raises when no default process group has been
    # initialized, which made this debug helper crash in single-process runs.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0
    n_nan = int(x.isnan().sum().item())
    n_inf = int(x.isinf().sum().item())
    if n_nan or n_inf:
        print(f"[RANK {rank}] {name} is not finite: #nan={n_nan}, #inf={n_inf}")
        return True

    print(f"[RANK {rank}] {name} is OK ...")
    return False
| |
|
| |
|
def normalize(t, dim, eps=1e-6):
    """L2-normalize ``t`` along ``dim``.

    The default ``eps`` is much larger than ``F.normalize``'s own default
    (1e-12, which is below fp16's smallest representable magnitude), so the
    denominator clamp stays effective for fp16 tensors.
    """
    return F.normalize(input=t, p=2.0, dim=dim, eps=eps)
| |
|
| |
|
def timestamp(fmt="%y%m%d-%H%M%S"):
    """Format the current local (naive) time with the ``strftime`` pattern ``fmt``.

    With the default pattern the result looks like ``YYMMDD-HHMMSS``.
    """
    return datetime.strftime(datetime.now(), fmt)
| |
|
| |
|
def merge_dicts_by_key(dics: List[Dict]) -> Dict[Any, List]:
    """Merge dictionaries by key, collecting each key's values into a list.

    All dicts must have the same keys; the output key order follows the first
    dict. A key missing from the first dict raises ``KeyError``.

    Args:
        dics: sequence of dicts sharing the same keys.

    Returns:
        ``{key: [value_from_dict_0, value_from_dict_1, ...]}``; ``{}`` when
        ``dics`` is empty (previously this raised ``IndexError``).
    """
    if not dics:
        return {}
    ret = {key: [] for key in dics[0]}
    for dic in dics:
        for key, value in dic.items():
            ret[key].append(value)
    return ret
| |
|
| |
|
def flatten_2d_list(list2d):
    """Concatenate the inner iterables of ``list2d`` into one flat list."""
    return [item for row in list2d for item in row]
| |
|
| |
|
def num_params(module):
    """Return the total element count over ``module.parameters()``."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
| |
|
| |
|
def param_trace(name, module, depth=0, max_depth=999, threshold=0, printf=print):
    """Print parameter counts for ``module`` and, recursively, its children.

    Each line is ``<indented name>\\t<count in millions>M``. Recursion stops
    once ``depth`` exceeds ``max_depth``. A line is printed only when the
    count is strictly greater than ``threshold``, but children are still
    visited either way.

    Args:
        name: label for ``module``.
        module: ``nn.Module`` to inspect.
        depth: current recursion depth (used for indentation).
        max_depth: deepest level to print.
        threshold: minimum parameter count (exclusive) for a line to print.
        printf: callable that receives each formatted line.
    """
    if depth > max_depth:
        return
    prefix = " " * depth
    n_params = num_params(module)
    if n_params > threshold:
        # The "M" suffix means millions of parameters, so divide by 1e6;
        # dividing by 1024**2 under-reported counts by about 4.6%.
        printf("{:60s}\t{:10.3f}M".format(prefix + name, n_params / 1e6))
    for child, sub in module.named_children():
        # Direct children of the root are shown without the root's name.
        child_name = child if depth == 0 else "{}.{}".format(name, child)
        param_trace(child_name, sub, depth + 1, max_depth, threshold, printf)
| |
|
| |
|
@torch.no_grad()
def hash_bn(module):
    """Summarize every batch-norm layer in ``module`` as two scalars.

    Returns:
        ``(p, s)``. ``p`` is the mean of (mean affine weight, mean affine
        bias) across BN layers. ``s`` is the mean of (mean running_mean,
        mean running_var) across BN layers. Each is ``0.0`` when no BN layer
        provides the corresponding tensors.
    """
    # SyncBatchNorm is not a subclass of BatchNorm{1,2,3}d and was previously
    # skipped silently.
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    weights, biases, run_means, run_vars = [], [], [], []
    for m in module.modules():
        if not isinstance(m, bn_types):
            continue
        # affine=False leaves weight/bias as None, and
        # track_running_stats=False leaves running_mean/running_var as None.
        # Skip the missing tensors instead of raising AttributeError.
        if m.weight is not None:
            weights.append(m.weight.detach().mean().item())
        if m.bias is not None:
            biases.append(m.bias.detach().mean().item())
        if m.running_mean is not None:
            run_means.append(m.running_mean.detach().mean().item())
        if m.running_var is not None:
            run_vars.append(m.running_var.detach().mean().item())

    # When every BN layer has all four tensors, this equals the original
    # column-wise mean followed by the pairwise means.
    p_parts = [np.mean(vals) for vals in (weights, biases) if vals]
    s_parts = [np.mean(vals) for vals in (run_means, run_vars) if vals]
    p = np.mean(p_parts) if p_parts else 0.0
    s = np.mean(s_parts) if s_parts else 0.0
    return p, s
| |
|
| |
|
@torch.no_grad()
def hash_params(module):
    """Return the mean of the per-parameter means of ``module`` as a float.

    Returns ``0.0`` for a module with no parameters. Previously this was
    ``nan`` (the mean of an empty tensor), which never compares equal.
    """
    means = [p.mean() for p in module.parameters()]
    if not means:
        return 0.0
    return torch.as_tensor(means).mean().item()
| |
|
| |
|
@torch.no_grad()
def hashm(module):
    """Return a coarse ``(param_summary, bn_stats_summary)`` pair for ``module``.

    The first element is ``hash_params(module)``. The second is the
    running-statistics component of ``hash_bn(module)``; its affine
    component is discarded.
    """
    param_summary = hash_params(module)
    bn_stats_summary = hash_bn(module)[1]
    return param_summary, bn_stats_summary
| |
|