| | import os
|
| | import argparse
|
| | from mmengine import Config
|
| |
|
def create_deeplabv3plus_config(model_config_path, dataset_config_path, num_class,
                                work_dir, save_dir, batch_size, max_iters,
                                val_interval):
    """Build and dump a DeepLabV3+ training config for a custom dataset.

    The model config is loaded and the dataset config is merged over it.
    The following are then overridden: crop size, normalization layers, class
    count, batch size, schedule, logger/checkpoint hooks and the random seed.
    The result is written to ``<save_dir>/<run_name>.py``, and ``cfg.work_dir``
    is set to ``<work_dir>/<run_name>``.

    The run name is ``<dataset>_<model>``:
    * ``<dataset>`` is the dataset config's file name, without extension, up
      to the first ``_``.
    * ``<model>`` is the name of the directory that contains the model config.

    Args:
        model_config_path: Path to the DeepLabV3+ model config file.
        dataset_config_path: Path to the dataset config file.
        num_class: Number of segmentation classes.
        work_dir: Root directory for training outputs.
        save_dir: Directory that receives the generated config file; created
            if it does not exist.
        batch_size: Training batch size.
        max_iters: Total number of training iterations.
        val_interval: Validate every ``val_interval`` iterations.
    """
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)

    cfg.crop_size = (512, 512)
    cfg.model.data_preprocessor.size = cfg.crop_size

    # Reassigning cfg.norm_cfg does not touch the norm_cfg entries already
    # stored inside the model dicts, so each one is overwritten explicitly.
    cfg.norm_cfg = dict(type='BN', requires_grad=True)
    cfg.model.backbone.norm_cfg = cfg.norm_cfg
    cfg.model.decode_head.norm_cfg = cfg.norm_cfg
    cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg

    cfg.model.decode_head.num_classes = num_class
    cfg.model.auxiliary_head.num_classes = num_class

    cfg.train_dataloader.batch_size = batch_size

    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = 100
    cfg.default_hooks.checkpoint.interval = 2500
    cfg.default_hooks.checkpoint.max_keep_ckpts = 1
    cfg.default_hooks.checkpoint.save_best = 'mIoU'

    cfg['randomness'] = dict(seed=0)

    cfg.save_dir = save_dir
    # Derive the run name with os.path, so that relative ("./configs/..."),
    # absolute and bare-file paths all resolve to the config's parent
    # directory. Drop the extension so "cityscapes.py" yields "cityscapes".
    dataset_stem = os.path.splitext(os.path.basename(dataset_config_path))[0]
    model_dir = os.path.basename(os.path.dirname(os.path.abspath(model_config_path)))
    name = dataset_stem.split('_')[0] + "_" + model_dir

    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Make sure the destination directory exists before dumping into it.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
|
| |
|
def create_knet_config(model_config_path, dataset_config_path, num_class,
                       work_dir, save_dir, batch_size, max_iters, val_interval):
    """Build and dump a K-Net training config for a custom dataset.

    The model config is loaded and the dataset config is merged over it.
    The following are then overridden: the class count on every K-Net head,
    batch size, schedule, logger/checkpoint hooks and the random seed.
    The result is written to ``<save_dir>/<run_name>.py``, and ``cfg.work_dir``
    is set to ``<work_dir>/<run_name>``.

    The run name is ``<dataset>_<model>``:
    * ``<dataset>`` is the dataset config's file name, without extension, up
      to the first ``_``.
    * ``<model>`` is the name of the directory that contains the model config.

    Args:
        model_config_path: Path to the K-Net model config file.
        dataset_config_path: Path to the dataset config file.
        num_class: Number of segmentation classes.
        work_dir: Root directory for training outputs.
        save_dir: Directory that receives the generated config file; created
            if it does not exist.
        batch_size: Training batch size.
        max_iters: Total number of training iterations.
        val_interval: Validate every ``val_interval`` iterations.
    """
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)

    # NOTE(review): only the top-level cfg.norm_cfg is replaced here. Unlike
    # create_deeplabv3plus_config, it is not copied into the backbone/head
    # dicts, so it presumably does not change the built model. Confirm intent.
    cfg.norm_cfg = dict(type='BN', requires_grad=True)
    # crop_size comes from the loaded model/dataset config; it is not
    # overridden in this function.
    cfg.model.data_preprocessor.size = cfg.crop_size

    # K-Net stores a class count on the kernel-generating head, on each
    # kernel-update stage and on the auxiliary head; update all of them so
    # they stay in sync.
    cfg.model.decode_head.kernel_generate_head.num_classes = num_class
    update_heads = cfg.model.decode_head.get('kernel_update_head', [])
    if isinstance(update_heads, dict):
        # Tolerate a single stage given as a dict instead of a list.
        update_heads = [update_heads]
    for update_head in update_heads:
        update_head['num_classes'] = num_class
    cfg.model.auxiliary_head.num_classes = num_class

    cfg.train_dataloader.batch_size = batch_size

    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = 100
    cfg.default_hooks.checkpoint.interval = 2500
    cfg.default_hooks.checkpoint.max_keep_ckpts = 1
    cfg.default_hooks.checkpoint.save_best = 'mIoU'

    cfg['randomness'] = dict(seed=0)

    cfg.save_dir = save_dir
    # Derive the run name with os.path so that relative, absolute and
    # bare-file paths all resolve to the config's parent directory. Drop the
    # extension from the dataset file name.
    dataset_stem = os.path.splitext(os.path.basename(dataset_config_path))[0]
    model_dir = os.path.basename(os.path.dirname(os.path.abspath(model_config_path)))
    name = dataset_stem.split('_')[0] + "_" + model_dir

    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Make sure the destination directory exists before dumping into it.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
|
| |
|
def create_mask2former_config(model_config_path, dataset_config_path, num_class,
                              work_dir, save_dir, batch_size, max_iters,
                              val_interval):
    """Build and dump a Mask2Former training config for a custom dataset.

    The model config is loaded and the dataset config is merged over it.
    The following are then overridden: crop size, class count and
    classification-loss weights, batch size, schedule, logger/checkpoint hooks
    and the random seed. The result is written to ``<save_dir>/<run_name>.py``,
    and ``cfg.work_dir`` is set to ``<work_dir>/<run_name>``.

    The run name is ``<dataset>_<model>``:
    * ``<dataset>`` is the dataset config's file name, without extension, up
      to the first ``_``.
    * ``<model>`` is the name of the directory that contains the model config.

    Args:
        model_config_path: Path to the Mask2Former model config file.
        dataset_config_path: Path to the dataset config file.
        num_class: Number of segmentation classes.
        work_dir: Root directory for training outputs.
        save_dir: Directory that receives the generated config file; created
            if it does not exist.
        batch_size: Training batch size.
        max_iters: Total number of training iterations.
        val_interval: Validate every ``val_interval`` iterations.
    """
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)

    cfg.crop_size = (512, 512)
    cfg.model.data_preprocessor.size = cfg.crop_size

    # NOTE(review): only the top-level cfg.norm_cfg is replaced; it is not
    # copied into the model's sub-dicts here. Confirm whether that is intended.
    cfg.norm_cfg = dict(type='BN', requires_grad=True)

    cfg.model.decode_head.num_classes = num_class
    # One weight per real class plus a trailing 0.1 weight. The trailing
    # weight is presumably the extra "no object" class of Mask2Former's
    # classifier, so the list length is num_class + 1.
    cfg.model.decode_head.loss_cls.class_weight = [1.0] * num_class + [0.1]

    cfg.train_dataloader.batch_size = batch_size

    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = 100
    cfg.default_hooks.checkpoint.interval = 2500
    cfg.default_hooks.checkpoint.max_keep_ckpts = 1
    cfg.default_hooks.checkpoint.save_best = 'mIoU'

    cfg['randomness'] = dict(seed=0)

    cfg.save_dir = save_dir
    # Derive the run name with os.path so that relative, absolute and
    # bare-file paths all resolve to the config's parent directory. Drop the
    # extension from the dataset file name.
    dataset_stem = os.path.splitext(os.path.basename(dataset_config_path))[0]
    model_dir = os.path.basename(os.path.dirname(os.path.abspath(model_config_path)))
    name = dataset_stem.split('_')[0] + "_" + model_dir

    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Make sure the destination directory exists before dumping into it.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
|
| |
|
def create_segformer_config(model_config_path, dataset_config_path, num_class,
                            work_dir, save_dir, batch_size, max_iters,
                            val_interval):
    """Build and dump a SegFormer training config for a custom dataset.

    The model config is loaded and the dataset config is merged over it.
    The following are then overridden: the decode-head normalization and
    class count, batch size, schedule, logger/checkpoint hooks and the random
    seed. The result is written to ``<save_dir>/<run_name>.py``, and
    ``cfg.work_dir`` is set to ``<work_dir>/<run_name>``.

    The run name is ``<dataset>_<model>``:
    * ``<dataset>`` is the dataset config's file name, without extension, up
      to the first ``_``.
    * ``<model>`` is the name of the directory that contains the model config.

    Args:
        model_config_path: Path to the SegFormer model config file.
        dataset_config_path: Path to the dataset config file.
        num_class: Number of segmentation classes.
        work_dir: Root directory for training outputs.
        save_dir: Directory that receives the generated config file; created
            if it does not exist.
        batch_size: Training batch size.
        max_iters: Total number of training iterations.
        val_interval: Validate every ``val_interval`` iterations.
    """
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)

    cfg.norm_cfg = dict(type='BN', requires_grad=True)
    # crop_size comes from the loaded model/dataset config; it is not
    # overridden in this function.
    cfg.model.data_preprocessor.size = cfg.crop_size
    # Only the decode head's norm_cfg is replaced; the backbone keeps its own.
    cfg.model.decode_head.norm_cfg = cfg.norm_cfg

    cfg.model.decode_head.num_classes = num_class

    cfg.train_dataloader.batch_size = batch_size

    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = 100
    cfg.default_hooks.checkpoint.interval = 2500
    cfg.default_hooks.checkpoint.max_keep_ckpts = 1
    cfg.default_hooks.checkpoint.save_best = 'mIoU'

    cfg['randomness'] = dict(seed=0)

    cfg.save_dir = save_dir
    # Derive the run name with os.path so that relative, absolute and
    # bare-file paths all resolve to the config's parent directory. Drop the
    # extension from the dataset file name.
    dataset_stem = os.path.splitext(os.path.basename(dataset_config_path))[0]
    model_dir = os.path.basename(os.path.dirname(os.path.abspath(model_config_path)))
    name = dataset_stem.split('_')[0] + "_" + model_dir

    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Make sure the destination directory exists before dumping into it.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
|
| |
|
def main():
    """Parse command-line arguments and generate the requested model config.

    ``--model_name`` selects one of the ``create_*_config`` builders. Every
    builder receives the same keyword arguments, taken from the parsed CLI
    options.
    """
    # A single table maps each model name to its builder. It also provides the
    # argparse choices, so the CLI and the dispatch cannot drift apart.
    builders = {
        'deeplabv3plus': create_deeplabv3plus_config,
        'knet': create_knet_config,
        'mask2former': create_mask2former_config,
        'segformer': create_segformer_config,
    }

    parser = argparse.ArgumentParser(description='Train configuration setup for different models.')
    parser.add_argument('--model_name', type=str, required=True, choices=list(builders),
                        help='Model name to generate the config for.')
    parser.add_argument('-m', '--model_config', type=str, required=True, help="Path to the model config file")
    parser.add_argument('-d', '--dataset_config', type=str, required=True, help='Path to the dataset config file.')
    parser.add_argument('-c', '--num_class', type=int, required=True, help="Number of classes in the dataset")
    parser.add_argument('-w', '--work_dir', type=str, required=True, help='Directory to save the train result.')
    parser.add_argument('-s', '--save_dir', type=str, required=True, help="Directory to save the generated config file")
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
    parser.add_argument('--max_iters', type=int, default=20000, help='Number of training iterations.')
    parser.add_argument('--val_interval', type=int, default=500, help='Interval for validation during training.')

    args = parser.parse_args()

    # argparse has already rejected unknown names via `choices`, so the
    # lookup cannot fail.
    build_config = builders[args.model_name]
    build_config(
        model_config_path=args.model_config,
        dataset_config_path=args.dataset_config,
        num_class=args.num_class,
        work_dir=args.work_dir,
        save_dir=args.save_dir,
        batch_size=args.batch_size,
        max_iters=args.max_iters,
        val_interval=args.val_interval,
    )
|
| |
|
# Run the CLI only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()
|
| |
|