| """ |
| OpenSeed Training Script based on MaskDINO. |
| """ |
| try: |
| from shapely.errors import ShapelyDeprecationWarning |
| import warnings |
| warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning) |
| except: |
| pass |
|
|
import sys
import copy
import itertools
import logging
import os
import time

from collections import OrderedDict
from typing import Any, Dict, List, Set
from fvcore.nn.precise_bn import get_bn_modules

import torch

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg, CfgNode
from detectron2.data import MetadataCatalog, build_detection_train_loader
from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.utils.logger import setup_logger
from detectron2.config import LazyConfig, instantiate

from utils.arguments import load_opt_command
from detectron2.utils.comm import get_world_size, is_main_process

from datasets import (
    build_train_dataloader,
    build_evaluator,
    build_eval_dataloader,
)
import random
from detectron2.engine import (
    DefaultTrainer,
    default_argument_parser,
    default_setup,
    hooks,
    launch,
    create_ddp_model,
    AMPTrainer,
    SimpleTrainer,
)
import weakref

from openseed import build_model
from openseed.BaseModel import BaseModel

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
|
class Trainer(DefaultTrainer):
    """
    Extension of detectron2's DefaultTrainer adapted to OpenSeed (MaskDINO-style) training.
    """

    def __init__(self, cfg):
        super(DefaultTrainer, self).__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        model = create_ddp_model(model, broadcast_buffers=False)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )
        self.scheduler = self.build_lr_scheduler(cfg, optimizer)

        kwargs = {
            'trainer': weakref.proxy(self),
        }
        self.checkpointer = DetectionCheckpointer(
            model,
            cfg['OUTPUT_DIR'],
            **kwargs,
        )
        self.start_iter = 0
        self.max_iter = cfg['SOLVER']['MAX_ITER']
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
|
|
    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Returns:
            list[HookBase]:
        """
        cfg = copy.deepcopy(self.cfg)
        cfg.DATALOADER.NUM_WORKERS = 0  # the copied config is only used by the hooks built below

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            None,  # placeholder (e.g. for a PreciseBN hook); None entries are ignored by register_hooks
        ]

        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # run evaluation after checkpointing, so a failed evaluation can be
        # debugged from the saved checkpoint
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # run writers last, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret
|
|
    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:

        It now calls :func:`openseed.build_model` and wraps the result in :class:`BaseModel`.
        Overwrite it if you'd like a different model.
        """
        model = BaseModel(cfg, build_model(cfg)).cuda()
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        return build_evaluator(cfg, dataset_name, output_folder=output_folder)

    @classmethod
    def build_train_loader(cls, cfg):
        return build_train_dataloader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        loader = build_eval_dataloader(cfg)
        return loader

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.projects.deeplab.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)
|
|
    @classmethod
    def build_optimizer(cls, cfg, model):
        cfg_solver = cfg['SOLVER']
        weight_decay_norm = cfg_solver['WEIGHT_DECAY_NORM']
        weight_decay_embed = cfg_solver['WEIGHT_DECAY_EMBED']
        weight_decay_bias = cfg_solver.get('WEIGHT_DECAY_BIAS', 0.0)

        defaults = {}
        defaults["lr"] = cfg_solver['BASE_LR']
        defaults["weight_decay"] = cfg_solver['WEIGHT_DECAY']

        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
        )

        lr_multiplier = cfg['SOLVER']['LR_MULTIPLIER']

        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        for module_name, module in model.named_modules():
            for module_param_name, value in module.named_parameters(recurse=False):
                if not value.requires_grad:
                    continue
                # avoid duplicating parameters shared between modules
                if value in memo:
                    continue
                memo.add(value)

                hyperparams = copy.copy(defaults)

                # scale the learning rate of any parameter whose full name matches
                # a key in SOLVER.LR_MULTIPLIER
                for key, lr_mul in lr_multiplier.items():
                    if key in "{}.{}".format(module_name, module_param_name):
                        hyperparams["lr"] = hyperparams["lr"] * lr_mul
                        if is_main_process():
                            logger.info("Modify learning rate of {}: {}".format(
                                "{}.{}".format(module_name, module_param_name), lr_mul))
                # no weight decay for relative position bias tables / absolute position embeddings
                if (
                    "relative_position_bias_table" in module_param_name
                    or "absolute_pos_embed" in module_param_name
                ):
                    hyperparams["weight_decay"] = 0.0
                if isinstance(module, norm_module_types):
                    hyperparams["weight_decay"] = weight_decay_norm
                if isinstance(module, torch.nn.Embedding):
                    hyperparams["weight_decay"] = weight_decay_embed
                if "bias" in module_param_name:
                    hyperparams["weight_decay"] = weight_decay_bias
                params.append({"params": [value], **hyperparams})
|
|
        def maybe_add_full_model_gradient_clipping(optim):
            # detectron2's built-in clipping does not cover the "full_model" clip type,
            # so emulate it by wrapping the optimizer's step()
            clip_norm_val = cfg_solver['CLIP_GRADIENTS']['CLIP_VALUE']
            enable = (
                cfg_solver['CLIP_GRADIENTS']['ENABLED']
                and cfg_solver['CLIP_GRADIENTS']['CLIP_TYPE'] == "full_model"
                and clip_norm_val > 0.0
            )

            class FullModelGradientClippingOptimizer(optim):
                def step(self, closure=None):
                    all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                    super().step(closure=closure)

            return FullModelGradientClippingOptimizer if enable else optim

        optimizer_type = cfg_solver['OPTIMIZER']
        if optimizer_type == "SGD":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                params, cfg_solver['BASE_LR'], momentum=cfg_solver['MOMENTUM']
            )
        elif optimizer_type == "ADAMW":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                params, cfg_solver['BASE_LR']
            )
        else:
            raise NotImplementedError(f"no optimizer type {optimizer_type}")
        return optimizer
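    # Illustrative sketch (keys and values below are assumptions, not from this repo):
    # SOLVER.LR_MULTIPLIER is a mapping from a substring of a parameter's full
    # "<module>.<param>" name to a multiplier, e.g.
    #   SOLVER:
    #     LR_MULTIPLIER: {"backbone": 0.1, "lang_encoder": 0.1}
    # would train every parameter whose name contains "backbone" or "lang_encoder"
    # at one tenth of SOLVER.BASE_LR, as implemented by the substring match above.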
|
|
    @staticmethod
    def auto_scale_workers(cfg, num_workers: int):
        """
        Returns:
            CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
        """
        old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
        if old_world_size == 0 or old_world_size == num_workers:
            return cfg
        cfg = copy.deepcopy(cfg)

        assert (
            cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
        ), "Invalid REFERENCE_WORLD_SIZE in config!"
        scale = num_workers / old_world_size
        bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
        lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
        max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
        warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
        cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
        cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
        cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
        cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers
        logger = logging.getLogger(__name__)
        logger.info(
            f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
            f"max_iter={max_iter}, warmup={warmup_iter}."
        )
        return cfg
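    # Worked example (illustrative numbers): with REFERENCE_WORLD_SIZE=16,
    # IMS_PER_BATCH=16, BASE_LR=0.1 and MAX_ITER=5000, launching on num_workers=8
    # gives scale=0.5, so the config becomes IMS_PER_BATCH=8, BASE_LR=0.05 and
    # MAX_ITER=10000, with WARMUP_ITERS, STEPS, EVAL_PERIOD and CHECKPOINT_PERIOD
    # stretched by the same factor.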
|
|
    @classmethod
    def test(cls, cfg, model, evaluators=None):
        from utils.misc import hook_metadata, hook_switcher, hook_opt
        from openseed.utils import get_class_names
        from detectron2.utils.logger import log_every_n_seconds
        import datetime

        dataloaders = cls.build_test_loader(cfg, dataset_name=None)
        dataset_names = cfg['DATASETS']['TEST']
        model = model.eval().cuda()
        # unwrap DDP so that per-dataset attributes can be set on the underlying model
        model_without_ddp = model
        if not isinstance(model, BaseModel):
            model_without_ddp = model.module

        results = OrderedDict()
|
|
        for dataloader, dataset_name in zip(dataloaders, dataset_names):
            # build evaluator
            evaluator = build_evaluator(cfg, dataset_name, cfg['OUTPUT_DIR'])
            evaluator.reset()
            with torch.no_grad():
                # set up metadata, class names and text embeddings for this dataset
                names = get_class_names(dataset_name, cfg['MODEL'].get('BACKGROUND', True))
                model_without_ddp.model.metadata = MetadataCatalog.get(dataset_name)
                eval_type = model_without_ddp.model.metadata.evaluator_type
                if 'background' in names:
                    model_without_ddp.model.sem_seg_head.num_classes = len(names) - 1
                else:
                    model_without_ddp.model.sem_seg_head.num_classes = len(names)
                model_without_ddp.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(names, is_eval=True)
                hook_switcher(model_without_ddp, dataset_name)

                task = 'seg'

                # timed inference loop (mirrors detectron2's inference_on_dataset)
                total = len(dataloader)
                num_warmup = min(5, total - 1)
                start_time = time.perf_counter()
                total_data_time = 0
                total_compute_time = 0
                total_eval_time = 0
                start_data_time = time.perf_counter()
|
|
                for idx, batch in enumerate(dataloader):
                    total_data_time += time.perf_counter() - start_data_time
                    if idx == num_warmup:
                        start_time = time.perf_counter()
                        total_data_time = 0
                        total_compute_time = 0
                        total_eval_time = 0
                    start_compute_time = time.perf_counter()

                    with torch.autocast(device_type='cuda', dtype=torch.float16):
                        outputs = model(batch, inference_task=task)

                    total_compute_time += time.perf_counter() - start_compute_time
                    start_eval_time = time.perf_counter()

                    evaluator.process(batch, outputs)
                    total_eval_time += time.perf_counter() - start_eval_time

                    iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
                    data_seconds_per_iter = total_data_time / iters_after_start
                    compute_seconds_per_iter = total_compute_time / iters_after_start
                    eval_seconds_per_iter = total_eval_time / iters_after_start
                    total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start

                    if is_main_process() and (idx >= num_warmup * 2 or compute_seconds_per_iter > 5):
                        eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1)))
                        log_every_n_seconds(
                            logging.INFO,
                            (
                                f"Inference done {idx + 1}/{total}. "
                                f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
                                f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
                                f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
                                f"Total: {total_seconds_per_iter:.4f} s/iter. "
                                f"ETA={eta}"
                            ),
                            n=5,
                        )
                    start_data_time = time.perf_counter()
|
|
            # store the results of this dataset, keyed by dataset name
            results[dataset_name] = evaluator.evaluate()

        # restore training mode before returning control to the training loop
        model = model.train().cuda()
        return results
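# Typical usage of the trainer (sketch, mirroring `main` below): build it from a
# config, optionally resume from OUTPUT_DIR, then run the training loop.
#   trainer = Trainer(cfg)
#   trainer.resume_or_load(resume=False)
#   trainer.train()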
|
|
|
|
def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)
    setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="maskdino")
    return cfg
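# Config overrides are taken from the command line and applied through LazyConfig;
# anything after the known flags is forwarded via `args.opts` as "key=value" pairs,
# e.g. (keys and values are illustrative, not from this repo):
#   --config-file configs/openseed.yaml SOLVER.BASE_LR=0.0001 SOLVER.MAX_ITER=90000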
|
|
|
|
def main(args=None):
    cfg = setup(args)
    print("Command cfg:", cfg)
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        if args.original_load:
            print("using original loading")
            model = model.from_pretrained(cfg.MODEL.WEIGHTS)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res

    trainer = Trainer(cfg)
    if len(args.lang_weight) > 0:
        # load the language-model weights first; the regular model weights are
        # loaded afterwards by resume_or_load
        weight = copy.deepcopy(trainer.cfg.MODEL.WEIGHTS)
        trainer.cfg.MODEL.WEIGHTS = args.lang_weight
        print("load original language weight")
        trainer._trainer.model.module = trainer._trainer.model.module.from_pretrained(trainer.cfg.MODEL.WEIGHTS)
        trainer.cfg.MODEL.WEIGHTS = weight
        print("load pretrained model weight")
    trainer.resume_or_load(resume=args.resume)
    if args.original_load:
        print("using original loading")
        try:
            trainer._trainer.model.module = trainer._trainer.model.module.from_pretrained(cfg.MODEL.WEIGHTS)
        except Exception:
            trainer._trainer.model = trainer._trainer.model.from_pretrained(cfg.MODEL.WEIGHTS)
    return trainer.train()
|
|
|
|
if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument('--eval_only', action='store_true')
    parser.add_argument('--original_load', action='store_true')
    parser.add_argument('--lang_weight', type=str, default='')
    parser.add_argument('--EVAL_FLAG', type=int, default=1)
    args = parser.parse_args()
    port = random.randint(1000, 20000)
    args.dist_url = 'tcp://127.0.0.1:' + str(port)
    print("Command Line Args:", args)
    print("pwd:", os.getcwd())
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
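# Example invocations (script name, config paths and GPU counts are illustrative,
# not taken from this repo):
#   training:   python train_net.py --num-gpus 8 --config-file configs/openseed.yaml
#   evaluation: python train_net.py --eval_only --num-gpus 1 --config-file configs/openseed.yaml \
#               MODEL.WEIGHTS=/path/to/checkpoint.pth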
|
|