Spaces:
Build error
Build error
| # -------------------------------------------------------- | |
| # Swin Transformer | |
| # Copyright (c) 2021 Microsoft | |
| # Licensed under The MIT License [see LICENSE for details] | |
| # Written by Ze Liu | |
| # -------------------------------------------------------- | |
| import os | |
| import yaml | |
| from yacs.config import CfgNode as CN | |
# Root config node: all default hyper-parameters live under this tree.
# Callers must use _C.clone() (see get_config) rather than mutating this.
_C = CN()
# Base config files
_C.BASE = ['']
# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 32
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = ''
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Dataset root folder
_C.DATA.DATASET_ROOT = None
# Input image size (square, pixels)
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Interpolation used for training-time resizing
_C.DATA.TRAIN_INTERPOLATION = 'bicubic'
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache Data in Memory, could be overwritten by command line argument
_C.DATA.CACHE_MODE = 'part'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 4
# hdfs data dir (train/val index paths; None means unused)
_C.DATA.TRAIN_PATH = None
_C.DATA.VAL_PATH = None
# arnold dataset parallel: number of parallel readers
_C.DATA.NUM_READERS = 4
# Meta-information settings (extra per-sample metadata fed to the model)
_C.DATA.ADD_META = False
# How meta info is fused with image features ('early' or presumably 'late' —
# confirm against the model code)
_C.DATA.FUSION = 'early'
# Probability of masking out meta info during training
_C.DATA.MASK_PROB = 0.0
# Masking strategy for meta info
_C.DATA.MASK_TYPE = 'constant'
# Layer index at which late fusion happens (-1 = last; only relevant when
# FUSION is not 'early')
_C.DATA.LATE_FUSION_LAYER = -1
# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = ''
# Model name
_C.MODEL.NAME = ''
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate (stochastic depth)
_C.MODEL.DROP_PATH_RATE = 0.1
# Label Smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1
# Pretrained checkpoint path (None = train from scratch)
_C.MODEL.PRETRAINED = None
# NOTE(review): 'DORP' is likely a typo for 'DROP', but the names are part of
# the config interface (yaml files reference them), so they are kept as-is.
# Whether to discard the classification head when loading a pretrained model.
_C.MODEL.DORP_HEAD = True
# Whether to discard the meta branch when loading a pretrained model.
_C.MODEL.DORP_META = True
# Freeze backbone weights during fine-tuning
_C.MODEL.FREEZE_BACKBONE = True
# Use only the last class token (vs. all tokens) — confirm against model code
_C.MODEL.ONLY_LAST_CLS = False
# Number of extra (non-patch) tokens, e.g. class token
_C.MODEL.EXTRA_TOKEN_NUM = 1
# Dimensions of each meta-info field
_C.MODEL.META_DIMS = []
# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 1e-4  # 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 1e-5  # 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False
# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'
# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True
# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
# overwritten by command line argument
_C.AMP_OPT_LEVEL = ''
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint (epochs)
_C.SAVE_FREQ = 1
# Frequency to logging info (iterations)
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0
def _update_config_from_file(config, cfg_file):
    """Merge a yaml config file into ``config``, recursively merging any
    parent files listed under its ``BASE`` key first (so the current file's
    values take precedence). The config is frozen on return.
    """
    config.defrost()
    with open(cfg_file, 'r') as handle:
        loaded = yaml.load(handle, Loader=yaml.FullLoader)

    # Resolve BASE entries relative to the directory of the current file.
    base_dir = os.path.dirname(cfg_file)
    for base_name in loaded.setdefault('BASE', ['']):
        if base_name:
            _update_config_from_file(config, os.path.join(base_dir, base_name))

    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()
def update_config(config, args):
    """Update ``config`` in place from a yaml file and command-line args.

    Merge order (later wins): ``args.cfg`` (and its BASE parents), then
    ``args.opts`` key/value pairs, then the individual flags below.
    Also derives ``config.OUTPUT`` as <output>/<model name>/<tag> and reads
    the distributed-training rank from the environment. The config is
    frozen on return.
    """
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    # merge from specific arguments
    if args.batch_size:
        config.DATA.BATCH_SIZE = args.batch_size
    if args.data_path:
        config.DATA.DATA_PATH = args.data_path
    if args.zip:
        config.DATA.ZIP_MODE = True
    if args.cache_mode:
        config.DATA.CACHE_MODE = args.cache_mode
    if args.resume:
        config.MODEL.RESUME = args.resume
    if args.accumulation_steps:
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if args.use_checkpoint:
        config.TRAIN.USE_CHECKPOINT = True
    if args.amp_opt_level:
        config.AMP_OPT_LEVEL = args.amp_opt_level
    if args.output:
        config.OUTPUT = args.output
    if args.tag:
        config.TAG = args.tag
    if args.eval:
        config.EVAL_MODE = True
    if args.throughput:
        config.THROUGHPUT_MODE = True
    if args.num_workers is not None:
        config.DATA.NUM_WORKERS = args.num_workers

    # set lr and weight decay
    if args.lr is not None:
        config.TRAIN.BASE_LR = args.lr
    if args.min_lr is not None:
        config.TRAIN.MIN_LR = args.min_lr
    if args.warmup_lr is not None:
        config.TRAIN.WARMUP_LR = args.warmup_lr
    if args.warmup_epochs is not None:
        config.TRAIN.WARMUP_EPOCHS = args.warmup_epochs
    if args.weight_decay is not None:
        config.TRAIN.WEIGHT_DECAY = args.weight_decay
    if args.epochs is not None:
        config.TRAIN.EPOCHS = args.epochs
    if args.dataset is not None:
        config.DATA.DATASET = args.dataset
    if args.lr_scheduler_name is not None:
        config.TRAIN.LR_SCHEDULER.NAME = args.lr_scheduler_name
    if args.pretrain is not None:
        config.MODEL.PRETRAINED = args.pretrain

    # set local rank for distributed training.
    # Fix: os.environ values are always strings and LOCAL_RANK is declared
    # as an int default above; also avoid a KeyError when the script is run
    # without a distributed launcher (fall back to rank 0).
    config.LOCAL_RANK = int(os.environ.get('LOCAL_RANK', 0))

    # output folder: <output>/<model name>/<tag>
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()
def get_config(args):
    """Get a yacs CfgNode populated from defaults plus ``args``.

    A clone of the module-level defaults is returned so ``_C`` itself is
    never mutated (the "local variable" use pattern).
    """
    cfg = _C.clone()
    update_config(cfg, args)
    return cfg
| ################### For Inferencing #################### | |
def update_inference_config(config, args):
    """Merge an inference yaml config into ``config`` and freeze it.

    Fix: ``get_inference_config`` passes a plain config *path*, while this
    function previously required an argparse-style namespace with a ``.cfg``
    attribute (a plain path raised AttributeError). Accept both: use
    ``args.cfg`` when present, otherwise treat ``args`` as the path itself.
    The redundant defrost()/freeze() no-op pair was removed —
    ``_update_config_from_file`` already freezes the config.
    """
    cfg_file = getattr(args, 'cfg', args)
    _update_config_from_file(config, cfg_file)
    config.freeze()
def get_inference_config(cfg_path):
    """Get a yacs CfgNode for inference, merged from ``cfg_path``.

    Returns a clone of the module-level defaults so ``_C`` is not altered
    (the "local variable" use pattern).
    """
    cfg = _C.clone()
    update_inference_config(cfg, cfg_path)
    return cfg
| ################### For Inferencing #################### | |