|
|
import argparse
|
|
|
import os
|
|
|
import sys
|
|
|
from abc import ABC
|
|
|
from typing import Type
|
|
|
|
|
|
|
|
|
class DefaultConfigs(ABC):
    """Default experiment settings, held as plain class attributes.

    Instances start out reading every value from the class. Assigning
    ``exp_name`` additionally derives ``exp_dir``, ``ckpt_dir`` and
    ``logs_path`` from ``exp_root`` and creates the two directories on disk.
    ``to_dict`` returns a snapshot of all public, non-callable attributes.

    NOTE(review): the list-valued defaults below are class attributes, so an
    in-place mutation (e.g. ``cfg.gpus.append(1)``) is visible to every
    instance; rebinding (``cfg.gpus = [...]``) only affects that instance.
    """

    # -- general run settings (meaning inferred from names; consumers live
    #    outside this file) --
    gpus = [0]
    seed = 3407
    arch = "resnet50"
    datasets = ["zhaolian_train"]
    datasets_test = ["adm_res_abs_ddim20s"]
    mode = "binary"
    class_bal = False
    batch_size = 64
    loadSize = 256
    cropSize = 224
    epoch = "latest"
    num_workers = 20
    serial_batches = False
    isTrain = True

    # -- augmentation-related settings (presumably resize interpolation,
    #    blur, JPEG and grayscale options -- TODO confirm against the
    #    data pipeline) --
    rz_interp = ["bilinear"]
    blur_prob = 0.1
    blur_sig = [0.5]
    jpg_prob = 0.1
    jpg_method = ["cv2"]
    jpg_qual = [75]
    gray_prob = 0.0
    aug_resize = True
    aug_crop = True
    aug_flip = True
    aug_norm = True

    # -- optimisation / schedule settings --
    warmup = False
    warmup_epoch = 3
    earlystop = True
    earlystop_epoch = 5
    optim = "adam"
    new_optim = False
    loss_freq = 400
    save_latest_freq = 2000
    save_epoch_freq = 20
    continue_train = False
    epoch_count = 1
    last_epoch = -1
    nepoch = 400
    beta1 = 0.9
    lr = 0.0001
    init_type = "normal"
    init_gain = 0.02
    pretrained = True

    # -- filesystem layout --
    # root_dir is the parent of the directory that contains this file.
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    dataset_root = os.path.join(root_dir, "data")
    exp_root = os.path.join(root_dir, "data", "exp")

    # Placeholders; the first three are filled in by the ``exp_name`` setter,
    # ``ckpt_path`` is left for the caller to assign.
    _exp_name = ""
    exp_dir = ""
    ckpt_dir = ""
    logs_path = ""
    ckpt_path = ""

    @property
    def exp_name(self):
        """Return the experiment name last assigned (``""`` by default)."""
        return self._exp_name

    @exp_name.setter
    def exp_name(self, value: str):
        """Store the experiment name and derive/create its directories.

        Sets ``exp_dir`` to ``<exp_root>/<value>``, ``ckpt_dir`` to
        ``<exp_dir>/ckpt`` and ``logs_path`` to ``<exp_dir>/logs.txt``, then
        creates ``exp_dir`` and ``ckpt_dir`` if they do not exist yet.
        """
        self._exp_name = value
        self.exp_dir = os.path.join(self.exp_root, value)
        self.ckpt_dir = os.path.join(self.exp_dir, "ckpt")
        self.logs_path = os.path.join(self.exp_dir, "logs.txt")

        # Side effect: the directories are created as soon as a name is set.
        os.makedirs(self.exp_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def to_dict(self):
        """Return ``{name: value}`` for every public, non-callable attribute.

        Names come from ``dir(self)``, so class-level defaults, instance
        overrides and the ``exp_name`` property are all included. Anything
        starting with an underscore, and anything callable, is skipped.
        """
        snapshot = {}
        for name in dir(self):
            # A single "_" prefix test also covers dunder names, and doing it
            # first avoids evaluating private attributes at all.
            if name.startswith("_"):
                continue
            value = getattr(self, name)
            if callable(value):
                continue
            snapshot[name] = value
        return snapshot
|
|
|
|
|
|
|
|
|
def args_list2dict(arg_list: list):
    """Turn a flat ``[key1, value1, key2, value2, ...]`` list into a dict.

    Args:
        arg_list: Flat list of alternating keys and values.

    Returns:
        A dict mapping each even-indexed item to the item that follows it.
        If a key repeats, the last value wins.

    Raises:
        ValueError: If ``arg_list`` has an odd number of items.
    """
    # Raise explicitly rather than ``assert`` so the check still runs when
    # Python is started with -O (which strips assert statements).
    if len(arg_list) % 2 != 0:
        raise ValueError(f"Override list has odd length: {arg_list}; it must be a list of pairs")
    return dict(zip(arg_list[::2], arg_list[1::2]))
|
|
|
|
|
|
|
|
|
def str2bool(v: str) -> bool:
    """Interpret a command-line style token as a boolean.

    Args:
        v: Usually a string; an actual ``bool`` is returned unchanged and any
            other non-string value is converted with ``bool(v)``.

    Returns:
        ``True`` for "true", "yes", "on", "y", "t", "1" and ``False`` for
        "false", "no", "off", "n", "f", "0" (case-insensitive, surrounding
        whitespace ignored). Any other string falls back to ``bool(v)``,
        i.e. ``True`` unless the string is empty.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Non-string input has no .lower(); use plain truthiness instead of
        # failing with AttributeError.
        return bool(v)

    # Normalise once; stripping makes " false " parse as False instead of
    # falling through to the truthiness fallback (which would give True).
    token = v.strip().lower()
    if token in ("true", "yes", "on", "y", "t", "1"):
        return True
    if token in ("false", "no", "off", "n", "f", "0"):
        return False
    # Unrecognised token: keep the historical behaviour of plain truthiness.
    return bool(v)
|
|
|
|
|
|
|
|
|
def str2list(v: str, element_type=None) -> list:
    """Parse a comma-separated string such as ``"[a, b, c]"`` into a list.

    Args:
        v: Either a string (optionally wrapped in ``[`` ``]``) or an existing
            list/tuple/set, which is copied into a list without parsing.
        element_type: Optional callable applied to every element, e.g.
            ``int`` or ``float``. ``None`` leaves the elements untouched.

    Returns:
        A new list of the (optionally converted) elements. An empty or
        bracket-only string yields ``[]``.

    Raises:
        Whatever ``element_type`` raises for an element it cannot convert.
    """
    if isinstance(v, (list, tuple, set)):
        # Always hand back a list, as the annotation promises.
        items = list(v)
    else:
        # Strip outer whitespace first so " [1,2] " still loses its brackets.
        body = v.strip().lstrip("[").rstrip("]").strip()
        # "".split(",") would give [""], which breaks numeric conversion;
        # an empty body means an empty list.
        items = [part.strip() for part in body.split(",")] if body else []

    if element_type is not None:
        items = [element_type(item) for item in items]
    return items
|
|
|
|
|
|
|
|
|
# Alias used below to annotate the imported configuration object.
# NOTE(review): Type[DefaultConfigs] describes the class itself, while the
# fallback branch below builds an *instance*; confirm what an experiment's
# config.py actually exports as ``cfg``.
CONFIGCLASS = Type[DefaultConfigs]

# Command line: a few named flags plus free-form "key value" override pairs
# collected in ``opts``. Parsing happens at import time.
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", default=[0], type=int, nargs="+")
parser.add_argument("--exp_name", default="", type=str)
parser.add_argument("--ckpt", default="model_epoch_latest.pth", type=str)
parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)
args = parser.parse_args()

# Prefer a config.py stored in the experiment directory; otherwise fall back
# to the defaults defined in this file.
if os.path.exists(os.path.join(DefaultConfigs.exp_root, args.exp_name, "config.py")):
    sys.path.insert(0, os.path.join(DefaultConfigs.exp_root, args.exp_name))
    from config import cfg

    cfg: CONFIGCLASS
else:
    cfg = DefaultConfigs()

# Apply "key value" overrides, coercing each value to the type of the
# attribute it replaces.
if args.opts:
    opts = args_list2dict(args.opts)
    for k, v in opts.items():
        if not hasattr(cfg, k):
            raise ValueError(f"Unrecognized option: {k}")
        _current = getattr(cfg, k)
        original_type = type(_current)
        if original_type == bool:
            setattr(cfg, k, str2bool(v))
        elif original_type in (list, tuple, set):
            # Infer the element type from any existing element. Indexing
            # with [0] would fail on an empty default (IndexError) and on a
            # set (not subscriptable), so iterate instead and fall back to
            # leaving the parsed strings unconverted.
            _element_type = type(next(iter(_current))) if _current else None
            if _element_type is bool:
                # bool("False") is True; parse boolean tokens properly.
                _element_type = str2bool
            setattr(cfg, k, str2list(v, _element_type))
        elif _current is None:
            # NoneType cannot be called with an argument; keep the raw string.
            setattr(cfg, k, v)
        else:
            setattr(cfg, k, original_type(v))

# --gpus always wins over any value from config.py or the overrides above.
cfg.gpus: list = args.gpus
# Bare comma-separated list, without padding spaces, which is the documented
# form of this variable.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in cfg.gpus])
# Assigning exp_name on a DefaultConfigs instance derives exp_dir / ckpt_dir /
# logs_path and creates the directories, so ckpt_dir is valid on the next line.
cfg.exp_name = args.exp_name
cfg.ckpt_path: str = os.path.join(cfg.ckpt_dir, args.ckpt)

# Accept a single comma-separated string for ``datasets`` as well as a list.
if isinstance(cfg.datasets, str):
    cfg.datasets = cfg.datasets.split(",")
|
|
|
|