File size: 5,888 Bytes
import argparse
import os
import sys
from abc import ABC
from typing import Type
class DefaultConfigs(ABC):
    """Default experiment configuration, exposed as plain class attributes.

    Assigning ``exp_name`` derives ``exp_dir``, ``ckpt_dir`` and
    ``logs_path`` from ``exp_root`` and, while ``isTrain`` is true, creates
    the experiment and checkpoint directories on disk.
    """

    # ---------------- base setting ----------------
    gpus = [0]
    seed = 3407
    arch = "resnet50"
    datasets = [""]
    datasets_test = [""]
    class_bal = False
    batch_size = 256
    val_every = 1
    loadSize = 256
    cropSize = 224
    epoch = "latest"
    num_workers = 2
    isTrain = True

    # ---------------- train setting ----------------
    warmup = False
    warmup_epoch = 3
    earlystop = True
    earlystop_epoch = 5
    optim = "adam"
    new_optim = False
    loss_freq = 400
    save_latest_freq = 2000
    save_epoch_freq = 20
    continue_train = False
    epoch_count = 1
    last_epoch = -1
    nepoch = 20
    beta1 = 0.9
    lr = 1e-5
    init_type = "normal"
    init_gain = 0.02
    pretrained = True

    # ---------------- paths information ----------------
    # root_dir is the parent of the directory that contains this file.
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    dataset_root = os.path.join(root_dir, "datasets")
    dataset_test_root = os.path.join(root_dir, "datasets")
    exp_root = os.path.join(root_dir, "experiments")
    _exp_name = ""
    exp_dir = ""
    ckpt_dir = ""
    logs_path = ""
    ckpt_path = ""
    pretrained_weights = ""

    kd = True
    kd_weight = 1.0
    reproduce_dire = False
    only_eps = False
    only_img = False

    @property
    def exp_name(self):
        """Name of the current experiment (empty string until assigned)."""
        return self._exp_name

    @exp_name.setter
    def exp_name(self, value: str):
        """Store the name and refresh every path derived from it."""
        self._exp_name = value
        experiment_dir = os.path.join(self.exp_root, value)
        self.exp_dir = experiment_dir
        self.ckpt_dir = os.path.join(experiment_dir, "ckpt")
        self.logs_path = os.path.join(experiment_dir, "logs.txt")
        if not self.isTrain:
            # Outside training nothing is created on disk.
            return
        for directory in (self.exp_dir, self.ckpt_dir):
            os.makedirs(directory, exist_ok=True)

    def to_dict(self):
        """Return every public, non-callable attribute as a ``{name: value}`` dict."""
        public_items = (
            (name, getattr(self, name))
            for name in dir(self)
            if not name.startswith("_")
        )
        return {name: value for name, value in public_items if not callable(value)}
def args_list2dict(arg_list: list):
    """Turn a flat ``[key1, value1, key2, value2, ...]`` list into a dict.

    Args:
        arg_list: alternating keys and values.

    Returns:
        A dict pairing each even-indexed item with the item that follows it.
        When a key is repeated, the last value wins.

    Raises:
        ValueError: if ``arg_list`` has an odd number of elements.
    """
    if len(arg_list) % 2 != 0:
        # Raised explicitly: an ``assert`` is stripped under ``python -O`` and
        # ``zip`` would then silently drop the dangling key.
        raise ValueError(f"Override list has odd length: {arg_list}; it must be a list of pairs")
    return dict(zip(arg_list[::2], arg_list[1::2]))
def str2bool(v: str) -> bool:
    """Interpret a command-line style value as a boolean.

    Args:
        v: the value to interpret. Booleans are returned unchanged and other
            non-string values are converted with ``bool``.

    Returns:
        True for "true", "yes", "on", "y", "t", "1" and False for "false",
        "no", "off", "n", "f", "0". Matching ignores case and surrounding
        whitespace. Any other string falls back to its truthiness, so only
        the empty string is False.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Non-string values (e.g. ints) have no ``.lower()``; use truthiness.
        return bool(v)
    # Strip first so padded input such as " false " is still recognised.
    normalized = v.strip().lower()
    if normalized in ("true", "yes", "on", "y", "t", "1"):
        return True
    if normalized in ("false", "no", "off", "n", "f", "0"):
        return False
    return bool(v)
def str2list(v: str, element_type=None) -> list:
    """Parse a comma-separated string such as ``"[a, b, c]"`` into a list.

    Args:
        v: the text to parse, or an existing list/tuple/set. A container is
            not re-parsed.
        element_type: optional callable applied to every element.

    Returns:
        A list of whitespace-stripped items, converted with ``element_type``
        when one is given. A container passed without ``element_type`` is
        returned unchanged. An empty string yields ``[""]``.
    """
    if not isinstance(v, (list, tuple, set)):
        # Trim outer whitespace first; otherwise " [1, 2] " keeps its brackets
        # because lstrip/rstrip stop at the leading/trailing blanks.
        text = v.strip().lstrip("[").rstrip("]")
        v = [item.strip() for item in text.split(",")]
    if element_type is not None:
        v = [element_type(item) for item in v]
    return v
# Type alias: the DefaultConfigs class object or one of its subclasses.
CONFIGCLASS = Type[DefaultConfigs]

# Command-line options as (flag, default, type) rows, registered in this order.
# The string default of --epoch is kept as written; argparse runs string
# defaults through ``type``, so the parsed value is still an int.
_CLI_OPTIONS = (
    ("--gpus", "0", str),
    ("--batch", 256, int),
    ("--epoch", "100", int),
    ("--exp_name", "", str),
    ("--datasets", "", str),
    ("--dataset_root", "", str),
    ("--datasets_test", "", str),
    ("--dataset_test_root", "", str),
    ("--pretrained_weights", "", str),
    ("--lr", 0.00001, float),
    ("--kd_weight", 1., float),
    ("--test", False, str2bool),
    ("--reproduce_dire", False, str2bool),
    ("--only_eps", False, str2bool),
    ("--only_img", False, str2bool),
    ("--kd", True, str2bool),
)

parser = argparse.ArgumentParser()
for _flag, _default, _converter in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, type=_converter)
# Everything left on the command line is collected into ``opts``.
parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)

# NOTE: the command line is parsed when this module is executed.
args = parser.parse_args()
# Use the ``cfg`` object from a ``config.py`` inside the experiment directory
# when that file exists; otherwise fall back to the defaults.
_exp_config_dir = os.path.join(DefaultConfigs.exp_root, args.exp_name)
if os.path.exists(os.path.join(_exp_config_dir, "config.py")):
    # Put the experiment directory first on sys.path so that the import below
    # can find its config.py.
    sys.path.insert(0, _exp_config_dir)
    from config import cfg
    cfg: CONFIGCLASS
else:
    cfg = DefaultConfigs()
def _coerce_override(current, raw):
    """Convert the raw command-line string ``raw`` to the type of ``current``.

    Args:
        current: the value the option currently holds on the config.
        raw: the replacement text taken from the command line.

    Returns:
        ``raw`` converted as follows: booleans go through ``str2bool``;
        list/tuple/set values are parsed with ``str2list``, cast element-wise
        to the type of their first element, and rebuilt with the original
        container type; a ``None`` value keeps the raw string; anything else
        is passed to the constructor of its current type.
    """
    # bool is tested first: it is a subclass of int and bool("False") is True.
    if isinstance(current, bool):
        return str2bool(raw)
    if isinstance(current, (list, tuple, set)):
        # next(iter(...)) works for sets and for empty containers, where
        # indexing with [0] would raise.
        sample = next(iter(current), None)
        if sample is None:
            caster = None  # no element to infer a type from; keep strings
        elif isinstance(sample, bool):
            caster = str2bool
        else:
            caster = type(sample)
        return type(current)(str2list(raw, caster))
    if current is None:
        # NoneType cannot be called with an argument; keep the raw text.
        return raw
    return type(current)(raw)


# Apply the trailing "key value" overrides collected in ``args.opts``.
if args.opts:
    opts = args_list2dict(args.opts)
    for key, raw_value in opts.items():
        if not hasattr(cfg, key):
            raise ValueError(f"Unrecognized option: {key}")
        setattr(cfg, key, _coerce_override(getattr(cfg, key), raw_value))
# Device selection goes through the environment only; cfg.gpus is not updated.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

# isTrain is settled before exp_name is assigned: the DefaultConfigs exp_name
# setter creates the experiment directories only while isTrain is true.
if args.test:
    cfg.isTrain = False
cfg.exp_name = args.exp_name

cfg.batch_size = args.batch
# NOTE: datasets / datasets_test are stored as the raw strings from the
# command line; they are not split on commas here.
cfg.datasets = args.datasets
cfg.datasets_test = args.datasets_test or args.datasets
cfg.pretrained_weights = args.pretrained_weights
cfg.lr = args.lr
cfg.nepoch = args.epoch

# Dataset roots default to <root_dir>/datasets/<name>; an explicit
# --dataset_root / --dataset_test_root takes precedence.
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
_datasets_dir = os.path.join(root_dir, "datasets")
cfg.dataset_root = (
    args.dataset_root
    if args.dataset_root != ""
    else os.path.join(_datasets_dir, cfg.datasets)
)
cfg.dataset_test_root = (
    args.dataset_test_root
    if args.dataset_test_root != ""
    else os.path.join(_datasets_dir, cfg.datasets_test)
)

# Options copied from the parsed arguments under the same name.
for _name in ("kd", "reproduce_dire", "only_eps", "only_img", "kd_weight"):
    setattr(cfg, _name, getattr(args, _name))
|