# distildire/utils/config.py
# Author: Yewon Lim
# Commit: 424919d ("first")
import argparse
import os
import sys
from abc import ABC
from typing import Type
class DefaultConfigs(ABC):
    """Default experiment configuration shared by training and testing.

    Every setting is a plain class attribute, so an experiment-specific config
    or command-line overrides can replace any of them with ``setattr``.
    Assigning ``exp_name`` also derives the experiment, checkpoint and log
    paths (see the ``exp_name`` setter).

    NOTE: list-valued defaults (``gpus``, ``datasets``, ``datasets_test``) are
    class-level objects shared by every instance; replace them rather than
    mutating them in place.
    """

    ####### base setting ######
    gpus = [0]
    seed = 3407
    arch = "resnet50"
    datasets = [""]
    datasets_test = [""]
    class_bal = False
    batch_size = 256
    val_every = 1
    loadSize = 256
    cropSize = 224
    # NOTE(review): presumably selects which checkpoint to load ("latest") —
    # confirm; the number of training epochs is ``nepoch``.
    epoch = "latest"
    num_workers = 2
    # When False, setting ``exp_name`` does not create any directories.
    isTrain = True

    ####### train setting ######
    warmup = False
    warmup_epoch = 3
    earlystop = True
    earlystop_epoch = 5
    optim = "adam"
    new_optim = False
    loss_freq = 400
    save_latest_freq = 2000
    save_epoch_freq = 20
    continue_train = False
    epoch_count = 1
    last_epoch = -1
    nepoch = 20
    beta1 = 0.9
    lr = 0.00001
    init_type = "normal"
    init_gain = 0.02
    pretrained = True

    # paths information
    # Parent directory of the directory that contains this file.
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    dataset_root = os.path.join(root_dir, "datasets")
    dataset_test_root = os.path.join(root_dir, "datasets")
    exp_root = os.path.join(root_dir, "experiments")
    _exp_name = ""
    # exp_dir, ckpt_dir and logs_path are filled in by the exp_name setter.
    exp_dir = ""
    ckpt_dir = ""
    logs_path = ""
    ckpt_path = ""
    pretrained_weights = ""

    # Feature switches read by code outside this file.
    kd = True
    kd_weight = 1.
    reproduce_dire = False
    only_eps = False
    only_img = False

    @property
    def exp_name(self) -> str:
        """Name of the current experiment (a sub-directory of ``exp_root``)."""
        return self._exp_name

    @exp_name.setter
    def exp_name(self, value: str) -> None:
        """Set the experiment name and derive ``exp_dir``, ``ckpt_dir`` and ``logs_path``.

        In training mode (``isTrain`` is truthy) the experiment directory and its
        ``ckpt`` sub-directory are created if they do not already exist.
        """
        self._exp_name = value
        self.exp_dir: str = os.path.join(self.exp_root, self.exp_name)
        self.ckpt_dir: str = os.path.join(self.exp_dir, "ckpt")
        self.logs_path: str = os.path.join(self.exp_dir, "logs.txt")
        if self.isTrain:
            os.makedirs(self.exp_dir, exist_ok=True)
            os.makedirs(self.ckpt_dir, exist_ok=True)

    def to_dict(self) -> dict:
        """Return the public, non-callable settings as a ``{name: value}`` dict.

        Both class-level defaults and instance-level overrides are included, as
        is the ``exp_name`` property. Names starting with ``_`` (which covers
        dunders) and callables (methods) are skipped.
        """
        settings = {}
        for name in dir(self):
            # Filter on the name first so private/dunder attributes are never fetched.
            if name.startswith("_"):
                continue
            value = getattr(self, name)
            if not callable(value):
                settings[name] = value
        return settings
def args_list2dict(arg_list: list) -> dict:
    """Pair up a flat ``[key1, value1, key2, value2, ...]`` list into a dict.

    Used for the free-form ``opts`` overrides collected from the command line.
    If a key appears more than once, the last value wins.

    Raises:
        ValueError: if ``arg_list`` has an odd number of items.
    """
    # Explicit raise rather than ``assert``: asserts are stripped under ``python -O``.
    if len(arg_list) % 2 != 0:
        raise ValueError(f"Override list has odd length: {arg_list}; it must be a list of pairs")
    return dict(zip(arg_list[::2], arg_list[1::2]))
def str2bool(v: str) -> bool:
    """Parse a human-friendly boolean flag value (usable as an argparse ``type=``).

    * ``bool`` values are returned unchanged.
    * Strings are matched case-insensitively after stripping surrounding
      whitespace: true/yes/on/y/t/1 -> True, false/no/off/n/f/0 -> False.
    * Any other string falls back to Python truthiness (non-empty -> True).
    * Other non-string values fall back to ``bool(v)``.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        return bool(v)
    token = v.strip().lower()
    if token in ("true", "yes", "on", "y", "t", "1"):
        return True
    if token in ("false", "no", "off", "n", "f", "0"):
        return False
    # NOTE(review): an unrecognised string such as the typo "Flase" is treated as
    # True here; consider raising argparse.ArgumentTypeError instead.
    return bool(v)
def str2list(v: str, element_type=None) -> list:
    """Parse a list-like value into a Python ``list``.

    A string such as ``"[a, b, c]"`` or ``"a,b,c"`` is split on commas after
    removing surrounding whitespace and brackets. ``list``/``tuple``/``set``
    inputs are used item by item. String items are stripped of whitespace.
    Non-string items are left as they are. If ``element_type`` is given, it is
    applied to every item.
    """
    if not isinstance(v, (list, tuple, set)):
        # Strip outer whitespace first so " [a, b] " loses its brackets too.
        v = v.strip().lstrip("[").rstrip("]").split(",")
    # Only strings can be stripped; a list like [1, 2] must not crash here.
    items = [item.strip() if isinstance(item, str) else item for item in v]
    if element_type is not None:
        items = [element_type(item) for item in items]
    return items
# Type alias for a config *class* (DefaultConfigs or a subclass of it).
CONFIGCLASS = Type[DefaultConfigs]

# Command-line interface. The named flags below are copied onto ``cfg`` once it
# has been built. Any trailing tokens are collected into ``opts`` as
# ``key value`` pairs that override arbitrary DefaultConfigs attributes.
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", default="0", type=str)  # exported verbatim as CUDA_VISIBLE_DEVICES
parser.add_argument("--batch", default=256, type=int)  # -> cfg.batch_size
parser.add_argument("--epoch", default=100, type=int)  # -> cfg.nepoch (not cfg.epoch)
parser.add_argument("--exp_name", default="", type=str)
parser.add_argument("--datasets", default="", type=str)
parser.add_argument("--dataset_root", default="", type=str)  # "" -> <root_dir>/datasets/<datasets>
parser.add_argument("--datasets_test", default="", type=str)  # "" -> same as --datasets
parser.add_argument("--dataset_test_root", default="", type=str)  # "" -> <root_dir>/datasets/<datasets_test>
parser.add_argument("--pretrained_weights", default="", type=str)
parser.add_argument("--lr", default=0.00001, type=float)
parser.add_argument("--kd_weight", default=1., type=float)
parser.add_argument("--test", default=False, type=str2bool)  # True -> cfg.isTrain = False
parser.add_argument("--reproduce_dire", default=False, type=str2bool)
parser.add_argument("--only_eps", default=False, type=str2bool)
parser.add_argument("--only_img", default=False, type=str2bool)
parser.add_argument("--kd", default=True, type=str2bool)
parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)
args = parser.parse_args()
# --- Build ``cfg``: the experiment's own config.py if present, else the defaults. ---
_exp_config_dir = os.path.join(DefaultConfigs.exp_root, args.exp_name)
if os.path.exists(os.path.join(_exp_config_dir, "config.py")):
    # NOTE(review): ``config`` is resolved through sys.path, so a module already
    # registered as ``config`` in sys.modules would be reused instead — confirm.
    sys.path.insert(0, _exp_config_dir)
    from config import cfg
    cfg: DefaultConfigs  # an instance of DefaultConfigs (or a subclass)
else:
    cfg = DefaultConfigs()


def _coerce_override(current, raw):
    """Convert the override string ``raw`` to match the type of ``current``.

    bool -> parsed with ``str2bool``.
    list/tuple/set -> parsed with ``str2list``; items are converted to the type
    of an existing element if there is one, and the result keeps the original
    container type.
    None -> the raw string is kept.
    Anything else -> ``type(current)(raw)``.
    """
    original_type = type(current)
    if original_type is bool:
        return str2bool(raw)
    if original_type in (list, tuple, set):
        # Sets are not indexable and empty containers have no sample element,
        # so peek via iter() and fall back to keeping the items as strings.
        sample = next(iter(current), None)
        element_type = None if sample is None else type(sample)
        return original_type(str2list(raw, element_type))
    if current is None:
        return raw
    return original_type(raw)


# --- Free-form ``key value`` overrides from the trailing command-line tokens. ---
if args.opts:
    opts = args_list2dict(args.opts)
    for key, raw_value in opts.items():
        if not hasattr(cfg, key):
            raise ValueError(f"Unrecognized option: {key}")
        setattr(cfg, key, _coerce_override(getattr(cfg, key), raw_value))

# Only the GPUs given by --gpus are exposed to CUDA (cfg.gpus is not updated).
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

if args.test:
    # Must precede the exp_name assignment: in train mode the setter creates directories.
    cfg.isTrain = False

# NOTE(review): the named flags below are applied unconditionally (including
# their argparse defaults). They therefore override values for the same fields
# set via ``opts`` or an experiment config.py (e.g. batch_size, lr, kd) — confirm intent.
cfg.exp_name = args.exp_name
cfg.batch_size = args.batch
# datasets / datasets_test hold a single name string from here on (not a list).
cfg.datasets = args.datasets
cfg.datasets_test = args.datasets_test or args.datasets
cfg.pretrained_weights = args.pretrained_weights
cfg.lr = args.lr
cfg.nepoch = args.epoch

# Dataset roots default to <root_dir>/datasets/<name>; an explicit CLI path wins.
root_dir = DefaultConfigs.root_dir
cfg.dataset_root = args.dataset_root or os.path.join(root_dir, "datasets", cfg.datasets)
cfg.dataset_test_root = args.dataset_test_root or os.path.join(root_dir, "datasets", cfg.datasets_test)

cfg.kd = args.kd
cfg.reproduce_dire = args.reproduce_dire
cfg.only_eps = args.only_eps
cfg.only_img = args.only_img
cfg.kd_weight = args.kd_weight