File size: 2,378 Bytes
02ba886 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import random
import numpy as np
import errno
import os
import os.path as osp
import torch
def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Args:
        seed (int): seed value. Values ``<= 0`` are treated as "do not seed"
            and leave every generator untouched.
    """
    if seed <= 0:
        return
    # Same order as before: stdlib, NumPy, torch CPU, torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def mkdir(path):
    """Create directory ``path`` together with any missing parent directories.

    Calling this on an already-existing directory is a no-op. If ``path``
    exists but is *not* a directory (e.g. a regular file), ``FileExistsError``
    is raised rather than silently pretending the directory was created.

    Args:
        path (str): directory to create.

    Raises:
        OSError: if the directory cannot be created (``FileExistsError`` when a
            non-directory already occupies ``path``).
    """
    # exist_ok=True replaces the manual errno.EEXIST check; makedirs itself
    # re-raises when the existing entry is not a directory.
    os.makedirs(path, exist_ok=True)
def check_isfile(fpath):
    """Report whether ``fpath`` points at an existing regular file.

    A notice is printed to stdout when no file is found.

    Args:
        fpath (str): file path.

    Returns:
        bool: ``True`` if ``fpath`` is a file, ``False`` otherwise.
    """
    if osp.isfile(fpath):
        return True
    print('No file found at "{}"'.format(fpath))
    return False
def get_named_submodule(model, sub_name: str):
    """Resolve a dotted attribute path (e.g. ``"encoder.layer1"``) on ``model``.

    Args:
        model: root object whose attributes are walked.
        sub_name: ``"."``-separated attribute names.

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: if any component of the path is missing.
    """
    target = model
    remaining = sub_name
    while True:
        head, sep, remaining = remaining.partition(".")
        target = getattr(target, head)
        if not sep:
            return target
def set_named_submodule(model, sub_name, value):
    """Assign ``value`` to the attribute at dotted path ``sub_name`` on ``model``.

    Every component but the last is resolved with ``getattr``; the final one is
    assigned with ``setattr``. For example ``"encoder.fc"`` replaces
    ``model.encoder.fc``.

    Args:
        model: root object whose attributes are walked.
        sub_name (str): ``"."``-separated attribute path.
        value: object to store at the end of the path.

    Raises:
        AttributeError: if an intermediate component is missing.
    """
    # Star-unpacking separates the parent path from the leaf name, replacing
    # the index loop with a last-iteration special case.
    *parents, leaf = sub_name.split(".")
    module = model
    for name in parents:
        module = getattr(module, name)
    setattr(module, leaf, value)
class AverageMeterMultiTargets(object):
    """Running per-target average of a per-sample metric (e.g. accuracy).

    Every call to :meth:`update` receives per-sample values together with a
    parallel tensor of integer target indices; values are summed and counted
    separately for each index seen.

    Args:
        target_names: maps a target index to its display name. Either a
            sequence (indexed by position) or a dict keyed by index.
    """

    def __init__(self, target_names):
        self.target_names = target_names
        self.reset()

    def reset(self):
        """Discard all accumulated sums and counts."""
        self.count = {}
        self.sum = {}

    def update(self, acc, dom, n=1):
        """Accumulate the entries of ``acc`` grouped by the indices in ``dom``.

        Args:
            acc: tensor of per-sample values; assumed to have the same shape
                as ``dom`` (it is indexed with the mask ``dom == d``).
            dom: integer tensor holding the target index of each sample.
            n: unused; kept so the signature stays backward-compatible.
        """
        # .tolist() works for CPU and CUDA tensors alike (``.numpy()`` fails on
        # CUDA) and yields plain ints, so keys are JSON-serialisable.
        for d in torch.unique(dom).tolist():
            arr = acc[dom == d]
            if d not in self.count:
                self.count[d] = 0.0
                self.sum[d] = 0.0
            self.count[d] += len(arr)
            # Store a Python float so no device tensor / autograd graph is kept.
            self.sum[d] += float(arr.sum())

    def average(self, key="index"):
        """Return the mean value per target plus the unweighted mean of those.

        Args:
            key: ``"index"`` to key results by target index, ``"name"`` to key
                them by ``target_names[index]``.

        Returns:
            dict: per-target means, plus ``"avg"`` holding the mean of the
            per-target means (NaN when nothing has been accumulated).
        """
        assert key in ["index", "name"]
        avg = {}
        for i in self.count:
            label = i if key == "index" else self.target_names[i]
            avg[label] = float(self.sum[i]) / self.count[i]
        values = list(avg.values())
        # Explicit NaN avoids NumPy's "mean of empty slice" RuntimeWarning.
        avg["avg"] = float(np.mean(values)) if values else float("nan")
        return avg

    def _target_indices(self):
        """Return target indices in display order, for list or dict names."""
        if isinstance(self.target_names, dict):
            return list(self.target_names)
        # A sequence is indexed by position; iterating it directly would yield
        # names, which never match the integer keys of ``average()``.
        return range(len(self.target_names))

    def __repr__(self):
        acc = self.average()
        s = ""
        for i in self._target_indices():
            if i not in acc:
                continue
            s += "%20s: %.04f" % (
                self.target_names[i],
                acc[i]
            )
            s += "\n"
        s += "-" * 30 + "\n"
        s += '>>> TARGET average: %.04f' % acc["avg"]
        return s

    def to_dict(self):
        """Return :meth:`average` keyed by target name."""
        return self.average(key="name")