| | import os |
| | import math |
| |
|
| |
|
class Config:
    """Experiment configuration assembled entirely inside ``__init__``.

    Most options are written as ``[choice_a, choice_b, ...][index]``: the list
    enumerates the available values and the trailing index selects the active
    one, so switching an option means editing that index.

    Constructing an instance touches the filesystem: it lists
    ``<data_root_dir>/<task>`` when that directory exists, and lists the
    current and parent directories looking for a file named ``train.sh``.

    ``save_last`` and ``save_step`` are parsed from that ``train.sh``. When no
    such file is found, these two attributes are not defined on the instance.

    Raises:
        IndexError: ``train.sh`` exists but has no usable entry for the
            active task.
        ValueError: the value found in ``train.sh`` is not an integer.
    """

    def __init__(self) -> None:
        # ---- Paths ----
        # Index 0 selects the user's home directory as the root.
        self.sys_home_dir = [os.path.expanduser("~"), "/mnt/data"][0]
        self.data_root_dir = os.path.join(self.sys_home_dir, "datasets/dis")

        # ---- Task and datasets ----
        self.task = ["DIS5K", "COD", "HRSOD", "General", "General-2K", "Matting"][0]
        # Comma-separated test-set names for the active task.
        self.testsets = {
            # The [:1] slice keeps only "DIS-VD".
            "DIS5K": ",".join(
                ["DIS-VD", "DIS-TE1", "DIS-TE2", "DIS-TE3", "DIS-TE4"][:1]
            ),
            "COD": ",".join(["CHAMELEON", "NC4K", "TE-CAMO", "TE-COD10K"]),
            "HRSOD": ",".join(
                ["DAVIS-S", "TE-HRSOD", "TE-UHRSD", "DUT-OMRON", "TE-DUTS"]
            ),
            "General": ",".join(["DIS-VD", "TE-P3M-500-NP"]),
            "General-2K": ",".join(["DIS-VD", "TE-P3M-500-NP"]),
            "Matting": ",".join(["TE-P3M-500-NP", "TE-AM-2k"]),
        }[self.task]

        # Every entry under <data_root_dir>/<task> that is not a test set,
        # joined with "+". Empty when the task directory does not exist.
        # The path and the test-set names are computed once, not per entry.
        task_dir = os.path.join(self.data_root_dir, self.task)
        testset_names = set(self.testsets.split(","))
        available = os.listdir(task_dir) if os.path.isdir(task_dir) else []
        datasets_all = "+".join(ds for ds in available if ds not in testset_names)

        # "+"-separated training-set names for the active task.
        self.training_set = {
            "DIS5K": ["DIS-TR", "DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4"][0],
            "COD": "TR-COD10K+TR-CAMO",
            "HRSOD": [
                "TR-DUTS",
                "TR-HRSOD",
                "TR-UHRSD",
                "TR-DUTS+TR-HRSOD",
                "TR-DUTS+TR-UHRSD",
                "TR-HRSOD+TR-UHRSD",
                "TR-DUTS+TR-HRSOD+TR-UHRSD",
            ][5],
            "General": datasets_all,
            "General-2K": datasets_all,
            "Matting": datasets_all,
        }[self.task]
        self.prompt4loc = ["dense", "sparse"][0]

        # ---- Runtime switches (interpreted by code outside this class) ----
        self.load_all = False
        self.compile = True
        self.precisionHigh = True

        # ---- Model structure options ----
        self.ms_supervision = True
        # Can only be enabled together with ms_supervision.
        self.out_ref = self.ms_supervision and True
        self.dec_ipt = True
        self.dec_ipt_split = True
        self.cxt_num = [0, 3][1]
        self.mul_scl_ipt = ["", "add", "cat"][2]
        self.dec_att = ["", "ASPP", "ASPPDeformable"][2]
        self.squeeze_block = [
            "",
            "BasicDecBlk_x1",
            "ResBlk_x4",
            "ASPP_x3",
            "ASPPDeformable_x3",
        ][1]
        self.dec_blk = ["BasicDecBlk", "ResBlk"][0]

        # ---- Training hyper-parameters ----
        self.batch_size = 4
        # Index 1 selects the per-task value; index 0 would give 0.
        self.finetune_last_epochs = [
            0,
            {
                "DIS5K": -40,
                "COD": -20,
                "HRSOD": -20,
                "General": -40,
                "General-2K": -20,
                "Matting": -20,
            }[self.task],
        ][1]
        # Base rate is 1e-4 when the task name contains "DIS5K", else 1e-5;
        # it is then scaled by sqrt(batch_size / 4).
        self.lr = (1e-4 if "DIS5K" in self.task else 1e-5) * math.sqrt(
            self.batch_size / 4
        )
        self.size = (
            (1024, 1024) if self.task not in ["General-2K"] else (2560, 1440)
        )
        # Never fewer than 4 workers.
        self.num_workers = max(4, self.batch_size)

        # ---- Backbone ----
        self.bb = [
            "vgg16",
            "vgg16bn",
            "resnet50",
            "swin_v1_t",
            "swin_v1_s",
            "swin_v1_b",
            "swin_v1_l",
            "pvt_v2_b0",
            "pvt_v2_b1",
            "pvt_v2_b2",
            "pvt_v2_b5",
        ][6]
        # Four channel counts per backbone, looked up for the active one.
        self.lateral_channels_in_collection = {
            "vgg16": [512, 256, 128, 64],
            "vgg16bn": [512, 256, 128, 64],
            "resnet50": [1024, 512, 256, 64],
            "pvt_v2_b2": [512, 320, 128, 64],
            "pvt_v2_b5": [512, 320, 128, 64],
            "swin_v1_b": [1024, 512, 256, 128],
            "swin_v1_l": [1536, 768, 384, 192],
            "swin_v1_t": [768, 384, 192, 96],
            "swin_v1_s": [768, 384, 192, 96],
            "pvt_v2_b0": [256, 160, 64, 32],
            "pvt_v2_b1": [512, 320, 128, 64],
        }[self.bb]
        # The "cat" multi-scale input mode doubles every channel count.
        if self.mul_scl_ipt == "cat":
            self.lateral_channels_in_collection = [
                channel * 2 for channel in self.lateral_channels_in_collection
            ]
        # Drop the first channel count, reverse the rest, and keep the last
        # cxt_num of them; an empty list when cxt_num is 0.
        self.cxt = (
            self.lateral_channels_in_collection[1:][::-1][-self.cxt_num :]
            if self.cxt_num
            else []
        )

        # ---- Further model options ----
        self.lat_blk = ["BasicLatBlk"][0]
        self.dec_channels_inter = ["fixed", "adap"][0]
        self.refine = ["", "itself", "RefUNet", "Refiner", "RefinerPVTInChannels4"][0]
        # The next three use `x and y`: while `refine` is the empty string they
        # all evaluate to "" (falsy) rather than to a bool or an int.
        self.progressive_ref = self.refine and True
        self.ender = self.progressive_ref and False
        self.scale = self.progressive_ref and 2
        self.auxiliary_classification = False
        self.refine_iteration = 1
        self.freeze_bb = False
        self.model = [
            "BiRefNet",
            "BiRefNetC2F",
        ][0]

        # ---- Augmentation, optimizer and loss weights ----
        # The [:4] slice leaves out "crop".
        self.preproc_methods = ["flip", "enhance", "rotate", "pepper", "crop"][:4]
        self.optimizer = ["Adam", "AdamW"][1]
        self.lr_decay_epochs = [1e5]
        self.lr_decay_rate = 0.5

        # Each loss weight is written as `weight * switch`; a switch of 0
        # makes the resulting weight zero.
        if self.task in ["Matting"]:
            self.lambdas_pix_last = {
                "bce": 30 * 1,
                "iou": 0.5 * 0,
                "iou_patch": 0.5 * 0,
                "mae": 100 * 1,
                "mse": 30 * 0,
                "triplet": 3 * 0,
                "reg": 100 * 0,
                "ssim": 10 * 1,
                "cnt": 5 * 0,
                "structure": 5 * 0,
            }
        elif self.task in ["General", "General-2K"]:
            self.lambdas_pix_last = {
                "bce": 30 * 1,
                "iou": 0.5 * 1,
                "iou_patch": 0.5 * 0,
                "mae": 100 * 1,
                "mse": 30 * 0,
                "triplet": 3 * 0,
                "reg": 100 * 0,
                "ssim": 10 * 1,
                "cnt": 5 * 0,
                "structure": 5 * 0,
            }
        else:
            self.lambdas_pix_last = {
                "bce": 30 * 1,
                "iou": 0.5 * 1,
                "iou_patch": 0.5 * 0,
                "mae": 30 * 0,
                "mse": 30 * 0,
                "triplet": 3 * 0,
                "reg": 100 * 0,
                "ssim": 10 * 1,
                "cnt": 5 * 0,
                "structure": 5 * 0,
            }
        self.lambdas_cls = {"ce": 5.0}

        # ---- Backbone weight files ----
        # Only the swin and pvt backbones have an entry here.
        self.weights_root_dir = os.path.join(self.sys_home_dir, "weights/cv")
        self.weights = {
            "pvt_v2_b2": os.path.join(self.weights_root_dir, "pvt_v2_b2.pth"),
            "pvt_v2_b5": os.path.join(
                self.weights_root_dir, ["pvt_v2_b5.pth", "pvt_v2_b5_22k.pth"][0]
            ),
            "swin_v1_b": os.path.join(
                self.weights_root_dir,
                [
                    "swin_base_patch4_window12_384_22kto1k.pth",
                    "swin_base_patch4_window12_384_22k.pth",
                ][0],
            ),
            "swin_v1_l": os.path.join(
                self.weights_root_dir,
                [
                    "swin_large_patch4_window12_384_22kto1k.pth",
                    "swin_large_patch4_window12_384_22k.pth",
                ][0],
            ),
            "swin_v1_t": os.path.join(
                self.weights_root_dir,
                ["swin_tiny_patch4_window7_224_22kto1k_finetune.pth"][0],
            ),
            "swin_v1_s": os.path.join(
                self.weights_root_dir,
                ["swin_small_patch4_window7_224_22kto1k_finetune.pth"][0],
            ),
            "pvt_v2_b0": os.path.join(self.weights_root_dir, ["pvt_v2_b0.pth"][0]),
            "pvt_v2_b1": os.path.join(self.weights_root_dir, ["pvt_v2_b1.pth"][0]),
        }

        # ---- Evaluation / misc switches ----
        self.verbose_eval = True
        self.only_S_MAE = False
        self.SDPA_enabled = False

        # Index 0 selects the value 0; the alternative is "cpu".
        self.device = [0, "cpu"][0]

        self.batch_size_valid = 1
        self.rand_seed = 7

        # ---- Values read from train.sh ----
        # A train.sh in the current directory wins over one in the parent.
        run_sh_file = [f for f in os.listdir(".") if "train.sh" == f] + [
            os.path.join("..", f) for f in os.listdir("..") if "train.sh" == f
        ]
        if run_sh_file:
            script_path = run_sh_file[0]
            with open(script_path, "r") as f:
                lines = f.readlines()
            self.save_last = self._read_train_sh_int(
                lines, self.task, "val_last=", script_path
            )
            self.save_step = self._read_train_sh_int(
                lines, self.task, "step=", script_path
            )

    @staticmethod
    def _read_train_sh_int(lines, task, key, source="train.sh") -> int:
        """Return the integer that follows ``key`` on the line for ``task``.

        The first line containing both ``'<task>')`` and ``key`` is used. The
        value is the whitespace-delimited token right after the last
        occurrence of ``key`` on that line.

        Args:
            lines: lines of the shell script.
            task: task name to look for, matched as ``'<task>')``.
            key: assignment prefix such as ``"val_last="`` or ``"step="``.
            source: file name used only in the error message.

        Raises:
            IndexError: no line matches, or nothing follows ``key``.
            ValueError: the token after ``key`` is not an integer.
        """
        task_marker = "'{}')".format(task)
        for line in lines:
            if task_marker in line and key in line:
                return int(line.strip().split(key)[-1].split()[0])
        # Still IndexError, as before, but now it says what was missing.
        raise IndexError(
            "no '{}' entry for task '{}' found in {}".format(key, task, source)
        )
| |
|
| |
|
| | |
# Command-line helper: prints selected Config attributes.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Only choose one argument to activate."
    )
    parser.add_argument("--print_task", action="store_true", help="print task name")
    parser.add_argument(
        "--print_testsets", action="store_true", help="print validation set"
    )
    args = parser.parse_args()

    config = Config()
    # Each "--print_<name>" flag maps onto the Config attribute "<name>".
    # vars() is the public way to read a Namespace as a dict, replacing the
    # private Namespace._get_kwargs() helper.
    for arg_name, arg_value in vars(args).items():
        if arg_value:
            print(getattr(config, arg_name[len("print_") :]))
| |
|