| |
|
|
| import logging |
| import os |
| from collections import OrderedDict |
| from typing import List, Optional, Union |
| import torch |
| from torch import nn |
|
|
| from detectron2.checkpoint import DetectionCheckpointer |
| from detectron2.config import CfgNode |
| from detectron2.engine import DefaultTrainer |
| from detectron2.evaluation import ( |
| DatasetEvaluator, |
| DatasetEvaluators, |
| inference_on_dataset, |
| print_csv_format, |
| ) |
| from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping |
| from detectron2.utils import comm |
| from detectron2.utils.events import EventWriter, get_event_storage |
|
|
| from densepose import DensePoseDatasetMapperTTA, DensePoseGeneralizedRCNNWithTTA, load_from_cfg |
| from densepose.data import ( |
| DatasetMapper, |
| build_combined_loader, |
| build_detection_test_loader, |
| build_detection_train_loader, |
| build_inference_based_loaders, |
| has_inference_based_loaders, |
| ) |
| from densepose.evaluation.d2_evaluator_adapter import Detectron2COCOEvaluatorAdapter |
| from densepose.evaluation.evaluator import DensePoseCOCOEvaluator, build_densepose_evaluator_storage |
| from densepose.modeling.cse import Embedder |
|
|
|
|
class SampleCountingLoader:
    """
    Wrap a batch loader and record per-dataset instance counts for each batch.

    For every batch yielded by the wrapped loader, the instance counts
    (``len(data["instances"])``) of all samples sharing the same
    ``data["dataset"]`` value are summed. Each sum is written to the current
    event storage as the scalar ``batch/<dataset_name>``. The batch itself is
    then yielded unchanged.
    """

    def __init__(self, loader):
        # Iterable of batches; each batch is iterated as a sequence of dicts
        # that carry at least the "dataset" and "instances" keys.
        self.loader = loader

    def __iter__(self):
        # Look up the storage before touching the wrapped loader, so a missing
        # storage fails before any loader work starts.
        storage = get_event_storage()
        # A plain ``for`` loop ends only when the loader is exhausted. Unlike a
        # manual next()/StopIteration loop, a StopIteration escaping from the
        # counting code below cannot silently end the stream.
        for batch in self.loader:
            num_inst_per_dataset = {}
            for data in batch:
                dataset_name = data["dataset"]
                num_inst_per_dataset[dataset_name] = num_inst_per_dataset.get(
                    dataset_name, 0
                ) + len(data["instances"])
            for dataset_name, num_inst in num_inst_per_dataset.items():
                storage.put_scalar(f"batch/{dataset_name}", num_inst)
            yield batch
|
|
|
|
class SampleCountMetricPrinter(EventWriter):
    """
    Event writer that logs the per-dataset batch statistics recorded under
    ``batch/*`` keys (as produced by ``SampleCountingLoader``).

    Each ``write`` call reports, for every ``batch/*`` history in the current
    event storage, its average over the last 20 recorded values.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    def write(self):
        storage = get_event_storage()
        batch_stats_strs = [
            f"{key} {buf.avg(20)}"
            for key, buf in storage.histories().items()
            if key.startswith("batch/")
        ]
        # When no "batch/*" scalars were recorded (e.g. the train loader was
        # not wrapped in SampleCountingLoader), there is nothing to report.
        # Skip the call instead of logging an empty line on every write.
        if not batch_stats_strs:
            return
        self.logger.info(", ".join(batch_stats_strs))
|
|
|
|
class Trainer(DefaultTrainer):
    """
    DensePose trainer built on ``DefaultTrainer``.

    Customizes evaluation (passing the model's embedder to the DensePose
    evaluator), optimizer parameter groups, train/test data loaders, event
    writers and test-time-augmentation inference.
    """

    @classmethod
    def extract_embedder_from_model(cls, model: nn.Module) -> Optional[Embedder]:
        """
        Return ``model.roi_heads.embedder`` if the model has one, else ``None``.

        A ``DistributedDataParallel`` wrapper is unwrapped (via ``.module``)
        before the lookup.
        """
        if isinstance(model, nn.parallel.DistributedDataParallel):
            model = model.module
        if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "embedder"):
            return model.roi_heads.embedder
        return None

    @classmethod
    def test(
        cls,
        cfg: CfgNode,
        model: nn.Module,
        evaluators: Optional[Union[DatasetEvaluator, List[DatasetEvaluator]]] = None,
    ):
        """
        Run evaluation of ``model`` on every dataset in ``cfg.DATASETS.TEST``.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`, passing the embedder extracted from ``model``.
                Otherwise, must have the same length as ``cfg.DATASETS.TEST``.

        Returns:
            dict: a dict of result metrics. With a single test dataset, this is
            that dataset's result dict; otherwise an ``OrderedDict`` keyed by
            dataset name.
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        # The embedder does not depend on the dataset, so extract it once.
        # It is only needed when evaluators are built here.
        embedder = cls.extract_embedder_from_model(model) if evaluators is None else None

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name, embedder=embedder)
                except NotImplementedError:
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            # Inference runs on all ranks only when distributed inference is
            # enabled; otherwise only the main process runs it.
            if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE or comm.is_main_process():
                results_i = inference_on_dataset(model, data_loader, evaluator)
            else:
                results_i = {}
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    @classmethod
    def build_evaluator(
        cls,
        cfg: CfgNode,
        dataset_name: str,
        output_folder: Optional[str] = None,
        embedder: Optional[Embedder] = None,
    ) -> DatasetEvaluators:
        """
        Build the evaluators for ``dataset_name``.

        Always includes a ``Detectron2COCOEvaluatorAdapter``. When
        ``cfg.MODEL.DENSEPOSE_ON`` is set, a ``DensePoseCOCOEvaluator`` is added,
        configured from ``cfg.DENSEPOSE_EVALUATION`` and given ``embedder``.

        Args:
            cfg: configuration.
            dataset_name: name of the test dataset.
            output_folder: where evaluators write outputs; defaults to
                ``<cfg.OUTPUT_DIR>/inference``.
            embedder: optional embedder forwarded to the DensePose evaluator.

        Returns:
            DatasetEvaluators combining the evaluators above.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluators = []
        distributed = cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE
        evaluators.append(
            Detectron2COCOEvaluatorAdapter(
                dataset_name, output_dir=output_folder, distributed=distributed
            )
        )
        if cfg.MODEL.DENSEPOSE_ON:
            storage = build_densepose_evaluator_storage(cfg, output_folder)
            evaluators.append(
                DensePoseCOCOEvaluator(
                    dataset_name,
                    distributed,
                    output_folder,
                    evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE,
                    min_iou_threshold=cfg.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD,
                    storage=storage,
                    embedder=embedder,
                    should_evaluate_mesh_alignment=cfg.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT,
                    mesh_alignment_mesh_names=cfg.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES,
                )
            )
        return DatasetEvaluators(evaluators)

    @classmethod
    def build_optimizer(cls, cfg: CfgNode, model: nn.Module):
        """
        Build an SGD optimizer (optionally with gradient clipping) for ``model``.

        Parameters overridden under the names ``"features"`` and
        ``"embeddings"`` get their learning rate scaled by
        ``CSE.FEATURES_LR_FACTOR`` / ``CSE.EMBEDDING_LR_FACTOR`` respectively.
        How names are matched follows ``get_default_optimizer_params``'
        ``overrides`` semantics.
        """
        params = get_default_optimizer_params(
            model,
            base_lr=cfg.SOLVER.BASE_LR,
            weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
            bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
            weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
            overrides={
                "features": {
                    "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR,
                },
                "embeddings": {
                    "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR,
                },
            },
        )
        optimizer = torch.optim.SGD(
            params,
            cfg.SOLVER.BASE_LR,
            momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
        return maybe_add_gradient_clipping(cfg, optimizer)

    @classmethod
    def build_test_loader(cls, cfg: CfgNode, dataset_name):
        """Build the test loader for ``dataset_name`` using a test-mode ``DatasetMapper``."""
        return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))

    @classmethod
    def build_train_loader(cls, cfg: CfgNode):
        """
        Build the training loader.

        Without inference-based loaders, this is the plain detection train
        loader. Otherwise a bootstrap model is built, moved to
        ``cfg.BOOTSTRAP_MODEL.DEVICE`` and loaded from
        ``cfg.BOOTSTRAP_MODEL.WEIGHTS``. It is used to create inference-based
        loaders, which are combined with the plain loader (ratio 1.0 for the
        plain loader). The combined loader is wrapped in ``SampleCountingLoader``.
        """
        data_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
        if not has_inference_based_loaders(cfg):
            return data_loader
        model = cls.build_model(cfg)
        model.to(cfg.BOOTSTRAP_MODEL.DEVICE)
        DetectionCheckpointer(model).resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False)
        inference_based_loaders, ratios = build_inference_based_loaders(cfg, model)
        loaders = [data_loader] + inference_based_loaders
        ratios = [1.0] + ratios
        combined_data_loader = build_combined_loader(cfg, loaders, ratios)
        sample_counting_loader = SampleCountingLoader(combined_data_loader)
        return sample_counting_loader

    def build_writers(self):
        """Return the default writers plus a ``SampleCountMetricPrinter``."""
        writers = super().build_writers()
        writers.append(SampleCountMetricPrinter())
        return writers

    @classmethod
    def test_with_TTA(cls, cfg: CfgNode, model):
        """
        Evaluate ``model`` with test-time augmentation.

        Results are written under ``<cfg.OUTPUT_DIR>/inference_TTA``. Every
        top-level key of the ``test`` result is suffixed with ``"_TTA"``.
        """
        logger = logging.getLogger("detectron2.trainer")
        logger.info("Running inference with test-time augmentation ...")
        # Extract the embedder from the original model before wrapping it for
        # TTA. The wrapper is not guaranteed to expose ``roi_heads``. This
        # keeps the evaluators consistent with those built by ``test``.
        embedder = cls.extract_embedder_from_model(model)
        transform_data = load_from_cfg(cfg)
        model = DensePoseGeneralizedRCNNWithTTA(
            cfg, model, transform_data, DensePoseDatasetMapperTTA(cfg)
        )
        evaluators = [
            cls.build_evaluator(
                cfg,
                name,
                output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA"),
                embedder=embedder,
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res
|
|