|
|
|
|
| import logging
|
| import os
|
| from collections import OrderedDict
|
| from typing import List, Optional, Union
|
| import torch
|
| from torch import nn
|
|
|
| from detectron2.checkpoint import DetectionCheckpointer
|
| from detectron2.config import CfgNode
|
| from detectron2.engine import DefaultTrainer
|
| from detectron2.evaluation import (
|
| DatasetEvaluator,
|
| DatasetEvaluators,
|
| inference_on_dataset,
|
| print_csv_format,
|
| )
|
| from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping
|
| from detectron2.utils import comm
|
| from detectron2.utils.events import EventWriter, get_event_storage
|
|
|
| from densepose import DensePoseDatasetMapperTTA, DensePoseGeneralizedRCNNWithTTA, load_from_cfg
|
| from densepose.data import (
|
| DatasetMapper,
|
| build_combined_loader,
|
| build_detection_test_loader,
|
| build_detection_train_loader,
|
| build_inference_based_loaders,
|
| has_inference_based_loaders,
|
| )
|
| from densepose.evaluation.d2_evaluator_adapter import Detectron2COCOEvaluatorAdapter
|
| from densepose.evaluation.evaluator import DensePoseCOCOEvaluator, build_densepose_evaluator_storage
|
| from densepose.modeling.cse import Embedder
|
|
|
|
|
class SampleCountingLoader:
    """
    Wrap a data loader and record, for every batch it yields, how many
    instances each source dataset contributed.

    For each batch the per-dataset totals are written to the current event
    storage as scalars named ``batch/<dataset_name>``. The batch itself is
    yielded unchanged.
    """

    def __init__(self, loader):
        # Iterable of batches. Each batch is iterated as a sequence of dicts
        # that provide a "dataset" key and a sized "instances" entry.
        self.loader = loader

    def __iter__(self):
        batches = iter(self.loader)
        # This is a generator, so the storage lookup happens only when
        # iteration actually starts (i.e. inside an active event storage).
        storage = get_event_storage()
        for batch in batches:
            num_inst_per_dataset = {}
            for data in batch:
                dataset_name = data["dataset"]
                num_inst_per_dataset[dataset_name] = num_inst_per_dataset.get(
                    dataset_name, 0
                ) + len(data["instances"])
            for dataset_name, num_inst in num_inst_per_dataset.items():
                storage.put_scalar(f"batch/{dataset_name}", num_inst)
            yield batch
|
|
|
|
|
class SampleCountMetricPrinter(EventWriter):
    """
    Event writer that logs the recent average of every ``batch/*`` scalar in
    the current event storage, such as the per-dataset instance counts that
    ``SampleCountingLoader`` records.
    """

    def __init__(self, window_size: int = 20):
        """
        Args:
            window_size (int): number of most recent values averaged for each
                ``batch/*`` key. The default of 20 matches the previously
                hard-coded window.
        """
        self.logger = logging.getLogger(__name__)
        self._window_size = window_size

    def write(self):
        storage = get_event_storage()
        batch_stats_strs = [
            f"{key} {buf.avg(self._window_size)}"
            for key, buf in storage.histories().items()
            if key.startswith("batch/")
        ]
        # No "batch/*" scalars have been recorded, e.g. because no counting
        # loader is in use. Skip logging instead of emitting an empty line on
        # every write period.
        if batch_stats_strs:
            self.logger.info(", ".join(batch_stats_strs))
|
|
|
|
|
class Trainer(DefaultTrainer):
    """
    ``DefaultTrainer`` specialised for DensePose.

    Provides DensePose data loaders, evaluators, optimizer parameter-group
    overrides, extra writers and test-time-augmentation evaluation.
    """

    @classmethod
    def extract_embedder_from_model(cls, model: nn.Module) -> Optional[Embedder]:
        """
        Return ``model.roi_heads.embedder`` if present, otherwise ``None``.

        A ``DistributedDataParallel`` wrapper is unwrapped first.
        """
        if isinstance(model, nn.parallel.DistributedDataParallel):
            model = model.module
        if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "embedder"):
            return model.roi_heads.embedder
        return None

    @classmethod
    def test(
        cls,
        cfg: CfgNode,
        model: nn.Module,
        evaluators: Optional[Union[DatasetEvaluator, List[DatasetEvaluator]]] = None,
    ):
        """
        Evaluate ``model`` on every dataset listed in ``cfg.DATASETS.TEST``.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`, passing the model's embedder (if any).
                Otherwise, must have the same length as ``cfg.DATASETS.TEST``.

        Returns:
            dict: a dict of result metrics keyed by dataset name. If there is
            exactly one test dataset, that dataset's result dict is returned
            directly instead.
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )
            embedder = None
        else:
            # The embedder depends only on the model, so look it up once
            # rather than once per test dataset.
            embedder = cls.extract_embedder_from_model(model)

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name, embedder=embedder)
                except NotImplementedError:
                    # Logger.warn is a deprecated alias; use Logger.warning.
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            # Without distributed inference only the main process runs the
            # evaluation; other ranks record an empty result.
            if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE or comm.is_main_process():
                results_i = inference_on_dataset(model, data_loader, evaluator)
            else:
                results_i = {}
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    @classmethod
    def build_evaluator(
        cls,
        cfg: CfgNode,
        dataset_name: str,
        output_folder: Optional[str] = None,
        embedder: Optional[Embedder] = None,
    ) -> DatasetEvaluators:
        """
        Build the evaluators for ``dataset_name``.

        A ``Detectron2COCOEvaluatorAdapter`` is always added. A
        ``DensePoseCOCOEvaluator`` is added as well when ``cfg.MODEL.DENSEPOSE_ON``.

        Args:
            cfg (CfgNode): config; ``DENSEPOSE_EVALUATION.*`` options are read here.
            dataset_name (str): name of the dataset to evaluate on.
            output_folder (str or None): output directory; defaults to
                ``<cfg.OUTPUT_DIR>/inference``.
            embedder (Embedder or None): forwarded to ``DensePoseCOCOEvaluator``.

        Returns:
            DatasetEvaluators: the combined evaluators.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluators = []
        distributed = cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE
        evaluators.append(
            Detectron2COCOEvaluatorAdapter(
                dataset_name, output_dir=output_folder, distributed=distributed
            )
        )
        if cfg.MODEL.DENSEPOSE_ON:
            storage = build_densepose_evaluator_storage(cfg, output_folder)
            evaluators.append(
                DensePoseCOCOEvaluator(
                    dataset_name,
                    distributed,
                    output_folder,
                    evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE,
                    min_iou_threshold=cfg.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD,
                    storage=storage,
                    embedder=embedder,
                    should_evaluate_mesh_alignment=cfg.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT,
                    mesh_alignment_mesh_names=cfg.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES,
                )
            )
        return DatasetEvaluators(evaluators)

    @classmethod
    def build_optimizer(cls, cfg: CfgNode, model: nn.Module):
        """
        Build an SGD optimizer, wrapped with gradient clipping if the config
        enables it.

        Parameter groups come from ``get_default_optimizer_params``. Parameters
        named ``features`` and ``embeddings`` get their learning rate scaled by
        the CSE ``FEATURES_LR_FACTOR`` / ``EMBEDDING_LR_FACTOR`` respectively.
        """
        params = get_default_optimizer_params(
            model,
            base_lr=cfg.SOLVER.BASE_LR,
            weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
            bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
            weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
            overrides={
                "features": {
                    "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR,
                },
                "embeddings": {
                    "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR,
                },
            },
        )
        optimizer = torch.optim.SGD(
            params,
            cfg.SOLVER.BASE_LR,
            momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
        return maybe_add_gradient_clipping(cfg, optimizer)

    @classmethod
    def build_test_loader(cls, cfg: CfgNode, dataset_name):
        """Build a test loader for ``dataset_name`` using a DensePose ``DatasetMapper``."""
        return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))

    @classmethod
    def build_train_loader(cls, cfg: CfgNode):
        """
        Build the training loader.

        If the config defines inference-based loaders, a bootstrap model is
        built, moved to ``cfg.BOOTSTRAP_MODEL.DEVICE`` and loaded from
        ``cfg.BOOTSTRAP_MODEL.WEIGHTS``. That model is used to create the extra
        loaders, which are combined with the regular loader (ratio 1.0). The
        result is wrapped in a ``SampleCountingLoader``.
        """
        data_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
        if not has_inference_based_loaders(cfg):
            return data_loader
        model = cls.build_model(cfg)
        model.to(cfg.BOOTSTRAP_MODEL.DEVICE)
        DetectionCheckpointer(model).resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False)
        inference_based_loaders, ratios = build_inference_based_loaders(cfg, model)
        loaders = [data_loader] + inference_based_loaders
        ratios = [1.0] + ratios
        combined_data_loader = build_combined_loader(cfg, loaders, ratios)
        sample_counting_loader = SampleCountingLoader(combined_data_loader)
        return sample_counting_loader

    def build_writers(self):
        """Return the default writers plus a ``SampleCountMetricPrinter``."""
        writers = super().build_writers()
        writers.append(SampleCountMetricPrinter())
        return writers

    @classmethod
    def test_with_TTA(cls, cfg: CfgNode, model):
        """
        Evaluate ``model`` wrapped in ``DensePoseGeneralizedRCNNWithTTA``.

        Evaluator output goes to ``<cfg.OUTPUT_DIR>/inference_TTA``. Every
        top-level key of the dict returned by :meth:`test` gets a ``_TTA``
        suffix.
        """
        logger = logging.getLogger("detectron2.trainer")
        logger.info("Running inference with test-time augmentation ...")
        # Extract the embedder from the original (unwrapped) model, so the TTA
        # evaluators receive it just like those built inside ``test``.
        embedder = cls.extract_embedder_from_model(model)
        transform_data = load_from_cfg(cfg)
        model = DensePoseGeneralizedRCNNWithTTA(
            cfg, model, transform_data, DensePoseDatasetMapperTTA(cfg)
        )
        evaluators = [
            cls.build_evaluator(
                cfg,
                name,
                output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA"),
                embedder=embedder,
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res
|
|
|