Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- CatVTON/densepose/data/datasets/__init__.py +7 -0
- CatVTON/densepose/data/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- CatVTON/densepose/data/datasets/__pycache__/builtin.cpython-39.pyc +0 -0
- CatVTON/densepose/data/datasets/__pycache__/chimpnsee.cpython-39.pyc +0 -0
- CatVTON/densepose/data/datasets/__pycache__/coco.cpython-39.pyc +0 -0
- CatVTON/densepose/data/datasets/__pycache__/dataset_type.cpython-39.pyc +0 -0
- CatVTON/densepose/data/datasets/__pycache__/lvis.cpython-39.pyc +0 -0
- CatVTON/densepose/data/datasets/builtin.py +18 -0
- CatVTON/densepose/data/datasets/chimpnsee.py +31 -0
- CatVTON/densepose/data/datasets/coco.py +434 -0
- CatVTON/densepose/data/datasets/dataset_type.py +13 -0
- CatVTON/densepose/data/datasets/lvis.py +259 -0
- CatVTON/densepose/data/samplers/__pycache__/__init__.cpython-39.pyc +0 -0
- CatVTON/densepose/data/samplers/__pycache__/densepose_base.cpython-39.pyc +0 -0
- CatVTON/densepose/data/samplers/__pycache__/densepose_confidence_based.cpython-39.pyc +0 -0
- CatVTON/densepose/data/samplers/__pycache__/densepose_cse_base.cpython-39.pyc +0 -0
- CatVTON/densepose/data/samplers/__pycache__/densepose_cse_confidence_based.cpython-39.pyc +0 -0
- CatVTON/densepose/data/samplers/__pycache__/densepose_cse_uniform.cpython-39.pyc +0 -0
- CatVTON/densepose/data/samplers/__pycache__/densepose_uniform.cpython-39.pyc +0 -0
- CatVTON/densepose/data/samplers/__pycache__/mask_from_densepose.cpython-39.pyc +0 -0
- CatVTON/densepose/data/samplers/__pycache__/prediction_to_gt.cpython-39.pyc +0 -0
- CatVTON/densepose/data/samplers/densepose_base.py +205 -0
- CatVTON/densepose/data/samplers/densepose_confidence_based.py +110 -0
- CatVTON/densepose/data/samplers/densepose_cse_uniform.py +14 -0
- CatVTON/densepose/data/samplers/mask_from_densepose.py +30 -0
- CatVTON/densepose/data/samplers/prediction_to_gt.py +100 -0
- CatVTON/densepose/data/transform/__init__.py +5 -0
- CatVTON/densepose/data/transform/__pycache__/__init__.cpython-39.pyc +0 -0
- CatVTON/densepose/data/transform/__pycache__/image.cpython-39.pyc +0 -0
- CatVTON/densepose/data/transform/image.py +41 -0
- CatVTON/detectron2/__init__.py +10 -0
- CatVTON/detectron2/checkpoint/__init__.py +10 -0
- CatVTON/detectron2/checkpoint/c2_model_loading.py +406 -0
- CatVTON/detectron2/checkpoint/catalog.py +115 -0
- CatVTON/detectron2/checkpoint/detection_checkpoint.py +143 -0
- CatVTON/detectron2/engine/__init__.py +19 -0
- CatVTON/detectron2/engine/defaults.py +719 -0
- CatVTON/detectron2/engine/hooks.py +690 -0
- CatVTON/detectron2/engine/launch.py +123 -0
- CatVTON/detectron2/engine/train_loop.py +530 -0
- CatVTON/detectron2/modeling/__init__.py +64 -0
- CatVTON/detectron2/modeling/anchor_generator.py +390 -0
- CatVTON/detectron2/modeling/box_regression.py +369 -0
- CatVTON/detectron2/modeling/matcher.py +127 -0
- CatVTON/detectron2/modeling/poolers.py +263 -0
- CatVTON/detectron2/projects/README.md +2 -0
- CatVTON/detectron2/projects/__init__.py +34 -0
- CatVTON/detectron2/solver/__init__.py +11 -0
- CatVTON/detectron2/solver/build.py +323 -0
- CatVTON/detectron2/solver/lr_scheduler.py +247 -0
CatVTON/densepose/data/datasets/__init__.py
ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from . import builtin  # ensure the builtin datasets are registered
+
+__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
CatVTON/densepose/data/datasets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (384 Bytes)
CatVTON/densepose/data/datasets/__pycache__/builtin.cpython-39.pyc
ADDED
Binary file (575 Bytes)
CatVTON/densepose/data/datasets/__pycache__/chimpnsee.cpython-39.pyc
ADDED
Binary file (1.03 kB)
CatVTON/densepose/data/datasets/__pycache__/coco.cpython-39.pyc
ADDED
Binary file (11.7 kB)
CatVTON/densepose/data/datasets/__pycache__/dataset_type.cpython-39.pyc
ADDED
Binary file (499 Bytes)
CatVTON/densepose/data/datasets/__pycache__/lvis.cpython-39.pyc
ADDED
Binary file (7.83 kB)
CatVTON/densepose/data/datasets/builtin.py
ADDED
@@ -0,0 +1,18 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+from .chimpnsee import register_dataset as register_chimpnsee_dataset
+from .coco import BASE_DATASETS as BASE_COCO_DATASETS
+from .coco import DATASETS as COCO_DATASETS
+from .coco import register_datasets as register_coco_datasets
+from .lvis import DATASETS as LVIS_DATASETS
+from .lvis import register_datasets as register_lvis_datasets
+
+DEFAULT_DATASETS_ROOT = "datasets"
+
+
+register_coco_datasets(COCO_DATASETS, DEFAULT_DATASETS_ROOT)
+register_coco_datasets(BASE_COCO_DATASETS, DEFAULT_DATASETS_ROOT)
+register_lvis_datasets(LVIS_DATASETS, DEFAULT_DATASETS_ROOT)
+
+register_chimpnsee_dataset(DEFAULT_DATASETS_ROOT)  # pyre-ignore[19]
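Because `builtin.py` runs at import time, merely importing the package registers every dataset above under the `datasets/` root. An illustrative sketch of that side effect (not part of the committed files; assumes detectron2 and this densepose package are importable):

from detectron2.data import DatasetCatalog

import densepose.data.datasets  # noqa: F401  -- importing runs builtin.py and registers everything

# All DensePose dataset names are now queryable from detectron2's global catalog.
densepose_names = [name for name in DatasetCatalog.list() if name.startswith("densepose_")]
print(densepose_names[:3])  # e.g. ['densepose_coco_2014_train', 'densepose_coco_2014_minival', ...]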
CatVTON/densepose/data/datasets/chimpnsee.py
ADDED
@@ -0,0 +1,31 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from typing import Optional
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+
+from ..utils import maybe_prepend_base_path
+from .dataset_type import DatasetType
+
+CHIMPNSEE_DATASET_NAME = "chimpnsee"
+
+
+def register_dataset(datasets_root: Optional[str] = None) -> None:
+    def empty_load_callback():
+        pass
+
+    video_list_fpath = maybe_prepend_base_path(
+        datasets_root,
+        "chimpnsee/cdna.eva.mpg.de/video_list.txt",
+    )
+    video_base_path = maybe_prepend_base_path(datasets_root, "chimpnsee/cdna.eva.mpg.de")
+
+    DatasetCatalog.register(CHIMPNSEE_DATASET_NAME, empty_load_callback)
+    MetadataCatalog.get(CHIMPNSEE_DATASET_NAME).set(
+        dataset_type=DatasetType.VIDEO_LIST,
+        video_list_fpath=video_list_fpath,
+        video_base_path=video_base_path,
+        category="chimpanzee",
+    )
CatVTON/densepose/data/datasets/coco.py
ADDED
@@ -0,0 +1,434 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+import contextlib
+import io
+import logging
+import os
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Dict, Iterable, List, Optional
+from fvcore.common.timer import Timer
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+from detectron2.utils.file_io import PathManager
+
+from ..utils import maybe_prepend_base_path
+
+DENSEPOSE_MASK_KEY = "dp_masks"
+DENSEPOSE_IUV_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_I", "dp_U", "dp_V"]
+DENSEPOSE_CSE_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_vertex", "ref_model"]
+DENSEPOSE_ALL_POSSIBLE_KEYS = set(
+    DENSEPOSE_IUV_KEYS_WITHOUT_MASK + DENSEPOSE_CSE_KEYS_WITHOUT_MASK + [DENSEPOSE_MASK_KEY]
+)
+DENSEPOSE_METADATA_URL_PREFIX = "https://dl.fbaipublicfiles.com/densepose/data/"
+
+
+@dataclass
+class CocoDatasetInfo:
+    name: str
+    images_root: str
+    annotations_fpath: str
+
+
+DATASETS = [
+    CocoDatasetInfo(
+        name="densepose_coco_2014_train",
+        images_root="coco/train2014",
+        annotations_fpath="coco/annotations/densepose_train2014.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_minival",
+        images_root="coco/val2014",
+        annotations_fpath="coco/annotations/densepose_minival2014.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_minival_100",
+        images_root="coco/val2014",
+        annotations_fpath="coco/annotations/densepose_minival2014_100.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_valminusminival",
+        images_root="coco/val2014",
+        annotations_fpath="coco/annotations/densepose_valminusminival2014.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_train_cse",
+        images_root="coco/train2014",
+        annotations_fpath="coco_cse/densepose_train2014_cse.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_minival_cse",
+        images_root="coco/val2014",
+        annotations_fpath="coco_cse/densepose_minival2014_cse.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_minival_100_cse",
+        images_root="coco/val2014",
+        annotations_fpath="coco_cse/densepose_minival2014_100_cse.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_valminusminival_cse",
+        images_root="coco/val2014",
+        annotations_fpath="coco_cse/densepose_valminusminival2014_cse.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_chimps",
+        images_root="densepose_chimps/images",
+        annotations_fpath="densepose_chimps/densepose_chimps_densepose.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_chimps_cse_train",
+        images_root="densepose_chimps/images",
+        annotations_fpath="densepose_chimps/densepose_chimps_cse_train.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_chimps_cse_val",
+        images_root="densepose_chimps/images",
+        annotations_fpath="densepose_chimps/densepose_chimps_cse_val.json",
+    ),
+    CocoDatasetInfo(
+        name="posetrack2017_train",
+        images_root="posetrack2017/posetrack_data_2017",
+        annotations_fpath="posetrack2017/densepose_posetrack_train2017.json",
+    ),
+    CocoDatasetInfo(
+        name="posetrack2017_val",
+        images_root="posetrack2017/posetrack_data_2017",
+        annotations_fpath="posetrack2017/densepose_posetrack_val2017.json",
+    ),
+    CocoDatasetInfo(
+        name="lvis_v05_train",
+        images_root="coco/train2017",
+        annotations_fpath="lvis/lvis_v0.5_plus_dp_train.json",
+    ),
+    CocoDatasetInfo(
+        name="lvis_v05_val",
+        images_root="coco/val2017",
+        annotations_fpath="lvis/lvis_v0.5_plus_dp_val.json",
+    ),
+]
+
+
+BASE_DATASETS = [
+    CocoDatasetInfo(
+        name="base_coco_2017_train",
+        images_root="coco/train2017",
+        annotations_fpath="coco/annotations/instances_train2017.json",
+    ),
+    CocoDatasetInfo(
+        name="base_coco_2017_val",
+        images_root="coco/val2017",
+        annotations_fpath="coco/annotations/instances_val2017.json",
+    ),
+    CocoDatasetInfo(
+        name="base_coco_2017_val_100",
+        images_root="coco/val2017",
+        annotations_fpath="coco/annotations/instances_val2017_100.json",
+    ),
+]
+
+
+def get_metadata(base_path: Optional[str]) -> Dict[str, Any]:
+    """
+    Returns metadata associated with COCO DensePose datasets
+
+    Args:
+        base_path: Optional[str]
+            Base path used to load metadata from
+
+    Returns:
+        Dict[str, Any]
+            Metadata in the form of a dictionary
+    """
+    meta = {
+        "densepose_transform_src": maybe_prepend_base_path(base_path, "UV_symmetry_transforms.mat"),
+        "densepose_smpl_subdiv": maybe_prepend_base_path(base_path, "SMPL_subdiv.mat"),
+        "densepose_smpl_subdiv_transform": maybe_prepend_base_path(
+            base_path,
+            "SMPL_SUBDIV_TRANSFORM.mat",
+        ),
+    }
+    return meta
+
+
+def _load_coco_annotations(json_file: str):
+    """
+    Load COCO annotations from a JSON file
+
+    Args:
+        json_file: str
+            Path to the file to load annotations from
+    Returns:
+        Instance of `pycocotools.coco.COCO` that provides access to annotations
+        data
+    """
+    from pycocotools.coco import COCO
+
+    logger = logging.getLogger(__name__)
+    timer = Timer()
+    with contextlib.redirect_stdout(io.StringIO()):
+        coco_api = COCO(json_file)
+    if timer.seconds() > 1:
+        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
+    return coco_api
+
+
+def _add_categories_metadata(dataset_name: str, categories: List[Dict[str, Any]]):
+    meta = MetadataCatalog.get(dataset_name)
+    meta.categories = {c["id"]: c["name"] for c in categories}
+    logger = logging.getLogger(__name__)
+    logger.info("Dataset {} categories: {}".format(dataset_name, meta.categories))
+
+
+def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]):
+    if "minival" in json_file:
+        # Skip validation on COCO2014 valminusminival and minival annotations
+        # The ratio of buggy annotations there is tiny and does not affect accuracy
+        # Therefore we explicitly white-list them
+        return
+    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
+    assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
+        json_file
+    )
+
+
+def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
+    if "bbox" not in ann_dict:
+        return
+    obj["bbox"] = ann_dict["bbox"]
+    obj["bbox_mode"] = BoxMode.XYWH_ABS
+
+
+def _maybe_add_segm(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
+    if "segmentation" not in ann_dict:
+        return
+    segm = ann_dict["segmentation"]
+    if not isinstance(segm, dict):
+        # filter out invalid polygons (< 3 points)
+        segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
+        if len(segm) == 0:
+            return
+    obj["segmentation"] = segm
+
+
+def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
+    if "keypoints" not in ann_dict:
+        return
+    keypts = ann_dict["keypoints"]  # list[int]
+    for idx, v in enumerate(keypts):
+        if idx % 3 != 2:
+            # COCO's segmentation coordinates are floating points in [0, H or W],
+            # but keypoint coordinates are integers in [0, H-1 or W-1]
+            # Therefore we assume the coordinates are "pixel indices" and
+            # add 0.5 to convert to floating point coordinates.
+            keypts[idx] = v + 0.5
+    obj["keypoints"] = keypts
+
+
+def _maybe_add_densepose(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
+    for key in DENSEPOSE_ALL_POSSIBLE_KEYS:
+        if key in ann_dict:
+            obj[key] = ann_dict[key]
+
+
+def _combine_images_with_annotations(
+    dataset_name: str,
+    image_root: str,
+    img_datas: Iterable[Dict[str, Any]],
+    ann_datas: Iterable[Iterable[Dict[str, Any]]],
+):
+
+    ann_keys = ["iscrowd", "category_id"]
+    dataset_dicts = []
+    contains_video_frame_info = False
+
+    for img_dict, ann_dicts in zip(img_datas, ann_datas):
+        record = {}
+        record["file_name"] = os.path.join(image_root, img_dict["file_name"])
+        record["height"] = img_dict["height"]
+        record["width"] = img_dict["width"]
+        record["image_id"] = img_dict["id"]
+        record["dataset"] = dataset_name
+        if "frame_id" in img_dict:
+            record["frame_id"] = img_dict["frame_id"]
+            record["video_id"] = img_dict.get("vid_id", None)
+            contains_video_frame_info = True
+        objs = []
+        for ann_dict in ann_dicts:
+            assert ann_dict["image_id"] == record["image_id"]
+            assert ann_dict.get("ignore", 0) == 0
+            obj = {key: ann_dict[key] for key in ann_keys if key in ann_dict}
+            _maybe_add_bbox(obj, ann_dict)
+            _maybe_add_segm(obj, ann_dict)
+            _maybe_add_keypoints(obj, ann_dict)
+            _maybe_add_densepose(obj, ann_dict)
+            objs.append(obj)
+        record["annotations"] = objs
+        dataset_dicts.append(record)
+    if contains_video_frame_info:
+        create_video_frame_mapping(dataset_name, dataset_dicts)
+    return dataset_dicts
+
+
+def get_contiguous_id_to_category_id_map(metadata):
+    cat_id_2_cont_id = metadata.thing_dataset_id_to_contiguous_id
+    cont_id_2_cat_id = {}
+    for cat_id, cont_id in cat_id_2_cont_id.items():
+        if cont_id in cont_id_2_cat_id:
+            continue
+        cont_id_2_cat_id[cont_id] = cat_id
+    return cont_id_2_cat_id
+
+
+def maybe_filter_categories_cocoapi(dataset_name, coco_api):
+    meta = MetadataCatalog.get(dataset_name)
+    cont_id_2_cat_id = get_contiguous_id_to_category_id_map(meta)
+    cat_id_2_cont_id = meta.thing_dataset_id_to_contiguous_id
+    # filter categories
+    cats = []
+    for cat in coco_api.dataset["categories"]:
+        cat_id = cat["id"]
+        if cat_id not in cat_id_2_cont_id:
+            continue
+        cont_id = cat_id_2_cont_id[cat_id]
+        if (cont_id in cont_id_2_cat_id) and (cont_id_2_cat_id[cont_id] == cat_id):
+            cats.append(cat)
+    coco_api.dataset["categories"] = cats
+    # filter annotations, if multiple categories are mapped to a single
+    # contiguous ID, use only one category ID and map all annotations to that category ID
+    anns = []
+    for ann in coco_api.dataset["annotations"]:
+        cat_id = ann["category_id"]
+        if cat_id not in cat_id_2_cont_id:
+            continue
+        cont_id = cat_id_2_cont_id[cat_id]
+        ann["category_id"] = cont_id_2_cat_id[cont_id]
+        anns.append(ann)
+    coco_api.dataset["annotations"] = anns
+    # recreate index
+    coco_api.createIndex()
+
+
+def maybe_filter_and_map_categories_cocoapi(dataset_name, coco_api):
+    meta = MetadataCatalog.get(dataset_name)
+    category_id_map = meta.thing_dataset_id_to_contiguous_id
+    # map categories
+    cats = []
+    for cat in coco_api.dataset["categories"]:
+        cat_id = cat["id"]
+        if cat_id not in category_id_map:
+            continue
+        cat["id"] = category_id_map[cat_id]
+        cats.append(cat)
+    coco_api.dataset["categories"] = cats
+    # map annotation categories
+    anns = []
+    for ann in coco_api.dataset["annotations"]:
+        cat_id = ann["category_id"]
+        if cat_id not in category_id_map:
+            continue
+        ann["category_id"] = category_id_map[cat_id]
+        anns.append(ann)
+    coco_api.dataset["annotations"] = anns
+    # recreate index
+    coco_api.createIndex()
+
+
+def create_video_frame_mapping(dataset_name, dataset_dicts):
+    mapping = defaultdict(dict)
+    for d in dataset_dicts:
+        video_id = d.get("video_id")
+        if video_id is None:
+            continue
+        mapping[video_id].update({d["frame_id"]: d["file_name"]})
+    MetadataCatalog.get(dataset_name).set(video_frame_mapping=mapping)
+
+
+def load_coco_json(annotations_json_file: str, image_root: str, dataset_name: str):
+    """
+    Loads a JSON file with annotations in COCO instances format.
+    Replaces `detectron2.data.datasets.coco.load_coco_json` to handle metadata
+    in a more flexible way. Postpones category mapping to a later stage to be
+    able to combine several datasets with different (but coherent) sets of
+    categories.
+
+    Args:
+
+        annotations_json_file: str
+            Path to the JSON file with annotations in COCO instances format.
+        image_root: str
+            directory that contains all the images
+        dataset_name: str
+            the name that identifies a dataset, e.g. "densepose_coco_2014_train"
+        extra_annotation_keys: Optional[List[str]]
+            If provided, these keys are used to extract additional data from
+            the annotations.
+    """
+    coco_api = _load_coco_annotations(PathManager.get_local_path(annotations_json_file))
+    _add_categories_metadata(dataset_name, coco_api.loadCats(coco_api.getCatIds()))
+    # sort indices for reproducible results
+    img_ids = sorted(coco_api.imgs.keys())
+    # imgs is a list of dicts, each looks something like:
+    # {'license': 4,
+    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
+    #  'file_name': 'COCO_val2014_000000001268.jpg',
+    #  'height': 427,
+    #  'width': 640,
+    #  'date_captured': '2013-11-17 05:57:24',
+    #  'id': 1268}
+    imgs = coco_api.loadImgs(img_ids)
+    logger = logging.getLogger(__name__)
+    logger.info("Loaded {} images in COCO format from {}".format(len(imgs), annotations_json_file))
+    # anns is a list[list[dict]], where each dict is an annotation
+    # record for an object. The inner list enumerates the objects in an image
+    # and the outer list enumerates over images.
+    anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
+    _verify_annotations_have_unique_ids(annotations_json_file, anns)
+    dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
+    return dataset_records
+
+
+def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None):
+    """
+    Registers provided COCO DensePose dataset
+
+    Args:
+        dataset_data: CocoDatasetInfo
+            Dataset data
+        datasets_root: Optional[str]
+            Datasets root folder (default: None)
+    """
+    annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath)
+    images_root = maybe_prepend_base_path(datasets_root, dataset_data.images_root)
+
+    def load_annotations():
+        return load_coco_json(
+            annotations_json_file=annotations_fpath,
+            image_root=images_root,
+            dataset_name=dataset_data.name,
+        )
+
+    DatasetCatalog.register(dataset_data.name, load_annotations)
+    MetadataCatalog.get(dataset_data.name).set(
+        json_file=annotations_fpath,
+        image_root=images_root,
+        **get_metadata(DENSEPOSE_METADATA_URL_PREFIX)
+    )
+
+
+def register_datasets(
+    datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None
+):
+    """
+    Registers provided COCO DensePose datasets
+
+    Args:
+        datasets_data: Iterable[CocoDatasetInfo]
+            An iterable of dataset datas
+        datasets_root: Optional[str]
+            Datasets root folder (default: None)
+    """
+    for dataset_data in datasets_data:
+        register_dataset(dataset_data, datasets_root)
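Registration stores only a lazy loader, so `load_coco_json` runs the first time a dataset is pulled from the catalog. A hedged usage sketch (not part of the committed files; assumes the COCO DensePose annotations have been downloaded under the `datasets/` root that `DEFAULT_DATASETS_ROOT` expects):

from detectron2.data import DatasetCatalog, MetadataCatalog

import densepose.data.datasets  # noqa: F401  -- triggers the registrations above

# Fetching a dataset invokes the registered load_annotations -> load_coco_json.
records = DatasetCatalog.get("densepose_coco_2014_minival_100")
print(len(records), records[0]["file_name"])

# Metadata set at registration time is available without loading any annotation data.
meta = MetadataCatalog.get("densepose_coco_2014_minival_100")
print(meta.json_file, meta.densepose_transform_src)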
CatVTON/densepose/data/datasets/dataset_type.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from enum import Enum
+
+
+class DatasetType(Enum):
+    """
+    Dataset type, mostly used for datasets that contain data to bootstrap models on
+    """
+
+    VIDEO_LIST = "video_list"
CatVTON/densepose/data/datasets/lvis.py
ADDED
@@ -0,0 +1,259 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+import logging
+import os
+from typing import Any, Dict, Iterable, List, Optional
+from fvcore.common.timer import Timer
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets.lvis import get_lvis_instances_meta
+from detectron2.structures import BoxMode
+from detectron2.utils.file_io import PathManager
+
+from ..utils import maybe_prepend_base_path
+from .coco import (
+    DENSEPOSE_ALL_POSSIBLE_KEYS,
+    DENSEPOSE_METADATA_URL_PREFIX,
+    CocoDatasetInfo,
+    get_metadata,
+)
+
+DATASETS = [
+    CocoDatasetInfo(
+        name="densepose_lvis_v1_ds1_train_v1",
+        images_root="coco_",
+        annotations_fpath="lvis/densepose_lvis_v1_ds1_train_v1.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_lvis_v1_ds1_val_v1",
+        images_root="coco_",
+        annotations_fpath="lvis/densepose_lvis_v1_ds1_val_v1.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_lvis_v1_ds2_train_v1",
+        images_root="coco_",
+        annotations_fpath="lvis/densepose_lvis_v1_ds2_train_v1.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_lvis_v1_ds2_val_v1",
+        images_root="coco_",
+        annotations_fpath="lvis/densepose_lvis_v1_ds2_val_v1.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_lvis_v1_ds1_val_animals_100",
+        images_root="coco_",
+        annotations_fpath="lvis/densepose_lvis_v1_val_animals_100_v2.json",
+    ),
+]
+
+
+def _load_lvis_annotations(json_file: str):
+    """
+    Load LVIS annotations from a JSON file
+
+    Args:
+        json_file: str
+            Path to the file to load annotations from
+    Returns:
+        Instance of `lvis.LVIS` that provides access to annotations
+        data
+    """
+    from lvis import LVIS
+
+    json_file = PathManager.get_local_path(json_file)
+    logger = logging.getLogger(__name__)
+    timer = Timer()
+    lvis_api = LVIS(json_file)
+    if timer.seconds() > 1:
+        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
+    return lvis_api
+
+
+def _add_categories_metadata(dataset_name: str) -> None:
+    metadict = get_lvis_instances_meta(dataset_name)
+    categories = metadict["thing_classes"]
+    metadata = MetadataCatalog.get(dataset_name)
+    metadata.categories = {i + 1: categories[i] for i in range(len(categories))}
+    logger = logging.getLogger(__name__)
+    logger.info(f"Dataset {dataset_name} has {len(categories)} categories")
+
+
+def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]) -> None:
+    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
+    assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
+        json_file
+    )
+
+
+def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
+    if "bbox" not in ann_dict:
+        return
+    obj["bbox"] = ann_dict["bbox"]
+    obj["bbox_mode"] = BoxMode.XYWH_ABS
+
+
+def _maybe_add_segm(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
+    if "segmentation" not in ann_dict:
+        return
+    segm = ann_dict["segmentation"]
+    if not isinstance(segm, dict):
+        # filter out invalid polygons (< 3 points)
+        segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
+        if len(segm) == 0:
+            return
+    obj["segmentation"] = segm
+
+
+def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
+    if "keypoints" not in ann_dict:
+        return
+    keypts = ann_dict["keypoints"]  # list[int]
+    for idx, v in enumerate(keypts):
+        if idx % 3 != 2:
+            # COCO's segmentation coordinates are floating points in [0, H or W],
+            # but keypoint coordinates are integers in [0, H-1 or W-1]
+            # Therefore we assume the coordinates are "pixel indices" and
+            # add 0.5 to convert to floating point coordinates.
+            keypts[idx] = v + 0.5
+    obj["keypoints"] = keypts
+
+
+def _maybe_add_densepose(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
+    for key in DENSEPOSE_ALL_POSSIBLE_KEYS:
+        if key in ann_dict:
+            obj[key] = ann_dict[key]
+
+
+def _combine_images_with_annotations(
+    dataset_name: str,
+    image_root: str,
+    img_datas: Iterable[Dict[str, Any]],
+    ann_datas: Iterable[Iterable[Dict[str, Any]]],
+):
+
+    dataset_dicts = []
+
+    def get_file_name(img_root, img_dict):
+        # Determine the path including the split folder ("train2017", "val2017", "test2017") from
+        # the coco_url field. Example:
+        #   'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg'
+        split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
+        return os.path.join(img_root + split_folder, file_name)
+
+    for img_dict, ann_dicts in zip(img_datas, ann_datas):
+        record = {}
+        record["file_name"] = get_file_name(image_root, img_dict)
+        record["height"] = img_dict["height"]
+        record["width"] = img_dict["width"]
+        record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", [])
+        record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
+        record["image_id"] = img_dict["id"]
+        record["dataset"] = dataset_name
+
+        objs = []
+        for ann_dict in ann_dicts:
+            assert ann_dict["image_id"] == record["image_id"]
+            obj = {}
+            _maybe_add_bbox(obj, ann_dict)
+            obj["iscrowd"] = ann_dict.get("iscrowd", 0)
+            obj["category_id"] = ann_dict["category_id"]
+            _maybe_add_segm(obj, ann_dict)
+            _maybe_add_keypoints(obj, ann_dict)
+            _maybe_add_densepose(obj, ann_dict)
+            objs.append(obj)
+        record["annotations"] = objs
+        dataset_dicts.append(record)
+    return dataset_dicts
+
+
+def load_lvis_json(annotations_json_file: str, image_root: str, dataset_name: str):
+    """
+    Loads a JSON file with annotations in LVIS instances format.
+    Replaces `detectron2.data.datasets.lvis.load_lvis_json` to handle metadata
+    in a more flexible way. Postpones category mapping to a later stage to be
+    able to combine several datasets with different (but coherent) sets of
+    categories.
+
+    Args:
+
+        annotations_json_file: str
+            Path to the JSON file with annotations in LVIS instances format.
+        image_root: str
+            directory that contains all the images
+        dataset_name: str
+            the name that identifies a dataset, e.g. "densepose_lvis_v1_ds1_train_v1"
+        extra_annotation_keys: Optional[List[str]]
+            If provided, these keys are used to extract additional data from
+            the annotations.
+    """
+    lvis_api = _load_lvis_annotations(PathManager.get_local_path(annotations_json_file))
+
+    _add_categories_metadata(dataset_name)
+
+    # sort indices for reproducible results
+    img_ids = sorted(lvis_api.imgs.keys())
+    # imgs is a list of dicts, each looks something like:
+    # {'license': 4,
+    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
+    #  'file_name': 'COCO_val2014_000000001268.jpg',
+    #  'height': 427,
+    #  'width': 640,
+    #  'date_captured': '2013-11-17 05:57:24',
+    #  'id': 1268}
+    imgs = lvis_api.load_imgs(img_ids)
+    logger = logging.getLogger(__name__)
+    logger.info("Loaded {} images in LVIS format from {}".format(len(imgs), annotations_json_file))
+    # anns is a list[list[dict]], where each dict is an annotation
+    # record for an object. The inner list enumerates the objects in an image
+    # and the outer list enumerates over images.
+    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
+
+    _verify_annotations_have_unique_ids(annotations_json_file, anns)
+    dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
+    return dataset_records
+
+
+def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None) -> None:
+    """
+    Registers provided LVIS DensePose dataset
+
+    Args:
+        dataset_data: CocoDatasetInfo
+            Dataset data
+        datasets_root: Optional[str]
+            Datasets root folder (default: None)
+    """
+    annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath)
+    images_root = maybe_prepend_base_path(datasets_root, dataset_data.images_root)
+
+    def load_annotations():
+        return load_lvis_json(
+            annotations_json_file=annotations_fpath,
+            image_root=images_root,
+            dataset_name=dataset_data.name,
+        )
+
+    DatasetCatalog.register(dataset_data.name, load_annotations)
+    MetadataCatalog.get(dataset_data.name).set(
+        json_file=annotations_fpath,
+        image_root=images_root,
+        evaluator_type="lvis",
+        **get_metadata(DENSEPOSE_METADATA_URL_PREFIX),
+    )
+
+
+def register_datasets(
+    datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None
+) -> None:
+    """
+    Registers provided LVIS DensePose datasets
+
+    Args:
+        datasets_data: Iterable[CocoDatasetInfo]
+            An iterable of dataset datas
+        datasets_root: Optional[str]
+            Datasets root folder (default: None)
+    """
+    for dataset_data in datasets_data:
+        register_dataset(dataset_data, datasets_root)
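The `images_root="coco_"` values above only make sense together with `get_file_name`, which appends the split folder parsed from each image's `coco_url`. A self-contained illustration of that path logic (hypothetical image record, not part of the committed files):

import os

img_root = "coco_"  # the images_root used by the LVIS dataset infos above
img_dict = {"coco_url": "http://images.cocodataset.org/train2017/000000155379.jpg"}

# Last two URL components give the split folder and the file name.
split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
path = os.path.join(img_root + split_folder, file_name)
print(path)  # coco_train2017/000000155379.jpg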
CatVTON/densepose/data/samplers/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (615 Bytes)
CatVTON/densepose/data/samplers/__pycache__/densepose_base.cpython-39.pyc
ADDED
Binary file (6.17 kB)
CatVTON/densepose/data/samplers/__pycache__/densepose_confidence_based.cpython-39.pyc
ADDED
Binary file (4.55 kB)
CatVTON/densepose/data/samplers/__pycache__/densepose_cse_base.cpython-39.pyc
ADDED
Binary file (5.01 kB)
CatVTON/densepose/data/samplers/__pycache__/densepose_cse_confidence_based.cpython-39.pyc
ADDED
Binary file (4.92 kB)
CatVTON/densepose/data/samplers/__pycache__/densepose_cse_uniform.cpython-39.pyc
ADDED
Binary file (539 Bytes)
CatVTON/densepose/data/samplers/__pycache__/densepose_uniform.cpython-39.pyc
ADDED
Binary file (1.8 kB)
CatVTON/densepose/data/samplers/__pycache__/mask_from_densepose.cpython-39.pyc
ADDED
Binary file (1.3 kB)
CatVTON/densepose/data/samplers/__pycache__/prediction_to_gt.cpython-39.pyc
ADDED
Binary file (3.42 kB)
CatVTON/densepose/data/samplers/densepose_base.py
ADDED
@@ -0,0 +1,205 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from typing import Any, Dict, List, Tuple
+import torch
+from torch.nn import functional as F
+
+from detectron2.structures import BoxMode, Instances
+
+from densepose.converters import ToChartResultConverter
+from densepose.converters.base import IntTupleBox, make_int_box
+from densepose.structures import DensePoseDataRelative, DensePoseList
+
+
+class DensePoseBaseSampler:
+    """
+    Base DensePose sampler to produce DensePose data from DensePose predictions.
+    Samples for each class are drawn according to some distribution over all pixels estimated
+    to belong to that class.
+    """
+
+    def __init__(self, count_per_class: int = 8):
+        """
+        Constructor
+
+        Args:
+            count_per_class (int): the sampler produces at most `count_per_class`
+                samples for each category
+        """
+        self.count_per_class = count_per_class
+
+    def __call__(self, instances: Instances) -> DensePoseList:
+        """
+        Convert DensePose predictions (an instance of `DensePoseChartPredictorOutput`)
+        into DensePose annotations data (an instance of `DensePoseList`)
+        """
+        boxes_xyxy_abs = instances.pred_boxes.tensor.clone().cpu()
+        boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+        dp_datas = []
+        for i in range(len(boxes_xywh_abs)):
+            annotation_i = self._sample(instances[i], make_int_box(boxes_xywh_abs[i]))
+            annotation_i[DensePoseDataRelative.S_KEY] = self._resample_mask(  # pyre-ignore[6]
+                instances[i].pred_densepose
+            )
+            dp_datas.append(DensePoseDataRelative(annotation_i))
+        # create densepose annotations on CPU
+        dp_list = DensePoseList(dp_datas, boxes_xyxy_abs, instances.image_size)
+        return dp_list
+
+    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
+        """
+        Sample DensePoseDataRelative from estimation results
+        """
+        labels, dp_result = self._produce_labels_and_results(instance)
+        annotation = {
+            DensePoseDataRelative.X_KEY: [],
+            DensePoseDataRelative.Y_KEY: [],
+            DensePoseDataRelative.U_KEY: [],
+            DensePoseDataRelative.V_KEY: [],
+            DensePoseDataRelative.I_KEY: [],
+        }
+        n, h, w = dp_result.shape
+        for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
+            # indices - tuple of 3 1D tensors of size k
+            # 0: index along the first dimension N
+            # 1: index along H dimension
+            # 2: index along W dimension
+            indices = torch.nonzero(labels.expand(n, h, w) == part_id, as_tuple=True)
+            # values - an array of size [n, k]
+            # n: number of channels (U, V, confidences)
+            # k: number of points labeled with part_id
+            values = dp_result[indices].view(n, -1)
+            k = values.shape[1]
+            count = min(self.count_per_class, k)
+            if count <= 0:
+                continue
+            index_sample = self._produce_index_sample(values, count)
+            sampled_values = values[:, index_sample]
+            sampled_y = indices[1][index_sample] + 0.5
+            sampled_x = indices[2][index_sample] + 0.5
+            # prepare / normalize data
+            x = (sampled_x / w * 256.0).cpu().tolist()
+            y = (sampled_y / h * 256.0).cpu().tolist()
+            u = sampled_values[0].clamp(0, 1).cpu().tolist()
+            v = sampled_values[1].clamp(0, 1).cpu().tolist()
+            fine_segm_labels = [part_id] * count
+            # extend annotations
+            annotation[DensePoseDataRelative.X_KEY].extend(x)
+            annotation[DensePoseDataRelative.Y_KEY].extend(y)
+            annotation[DensePoseDataRelative.U_KEY].extend(u)
+            annotation[DensePoseDataRelative.V_KEY].extend(v)
+            annotation[DensePoseDataRelative.I_KEY].extend(fine_segm_labels)
+        return annotation
+
+    def _produce_index_sample(self, values: torch.Tensor, count: int):
+        """
+        Abstract method to produce a sample of indices to select data
+        To be implemented in descendants
+
+        Args:
+            values (torch.Tensor): an array of size [n, k] that contains
+                estimated values (U, V, confidences);
+                n: number of channels (U, V, confidences)
+                k: number of points labeled with part_id
+            count (int): number of samples to produce, should be positive and <= k
+
+        Return:
+            list(int): indices of values (along axis 1) selected as a sample
+        """
+        raise NotImplementedError
+
+    def _produce_labels_and_results(self, instance: Instances) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Method to get labels and DensePose results from an instance
+
+        Args:
+            instance (Instances): an instance of `DensePoseChartPredictorOutput`
+
+        Return:
+            labels (torch.Tensor): shape [H, W], DensePose segmentation labels
+            dp_result (torch.Tensor): shape [2, H, W], stacked DensePose results u and v
+        """
+        converter = ToChartResultConverter
+        chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
+        labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
+        return labels, dp_result
+
+    def _resample_mask(self, output: Any) -> torch.Tensor:
+        """
+        Convert DensePose predictor output to segmentation annotation - tensors of size
+        (256, 256) and type `int64`.
+
+        Args:
+            output: DensePose predictor output with the following attributes:
+                - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
+                  segmentation scores
+                - fine_segm: tensor of size [N, C, H, W] with unnormalized fine
+                  segmentation scores
+        Return:
+            Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
+            where S = DensePoseDataRelative.MASK_SIZE
+        """
+        sz = DensePoseDataRelative.MASK_SIZE
+        S = (
+            F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
+            .argmax(dim=1)
+            .long()
+        )
+        I = (
+            (
+                F.interpolate(
+                    output.fine_segm,
+                    (sz, sz),
+                    mode="bilinear",
+                    align_corners=False,
+                ).argmax(dim=1)
+                * (S > 0).long()
+            )
+            .squeeze()
+            .cpu()
+        )
+        # Map fine segmentation results to coarse segmentation ground truth
+        # TODO: extract this into separate classes
+        # coarse segmentation: 1 = Torso, 2 = Right Hand, 3 = Left Hand,
+        # 4 = Left Foot, 5 = Right Foot, 6 = Upper Leg Right, 7 = Upper Leg Left,
+        # 8 = Lower Leg Right, 9 = Lower Leg Left, 10 = Upper Arm Left,
+        # 11 = Upper Arm Right, 12 = Lower Arm Left, 13 = Lower Arm Right,
+        # 14 = Head
+        # fine segmentation: 1, 2 = Torso, 3 = Right Hand, 4 = Left Hand,
+        # 5 = Left Foot, 6 = Right Foot, 7, 9 = Upper Leg Right,
+        # 8, 10 = Upper Leg Left, 11, 13 = Lower Leg Right,
+        # 12, 14 = Lower Leg Left, 15, 17 = Upper Arm Left,
+        # 16, 18 = Upper Arm Right, 19, 21 = Lower Arm Left,
+        # 20, 22 = Lower Arm Right, 23, 24 = Head
+        FINE_TO_COARSE_SEGMENTATION = {
+            1: 1,
+            2: 1,
+            3: 2,
+            4: 3,
+            5: 4,
+            6: 5,
+            7: 6,
+            8: 7,
+            9: 6,
+            10: 7,
+            11: 8,
+            12: 9,
+            13: 8,
+            14: 9,
+            15: 10,
+            16: 11,
+            17: 10,
+            18: 11,
+            19: 12,
+            20: 13,
+            21: 12,
+            22: 13,
+            23: 14,
+            24: 14,
+        }
+        mask = torch.zeros((sz, sz), dtype=torch.int64, device=torch.device("cpu"))
+        for i in range(DensePoseDataRelative.N_PART_LABELS):
+            mask[I == i + 1] = FINE_TO_COARSE_SEGMENTATION[i + 1]
+        return mask
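The core of `_sample` is the broadcasted label match followed by `view(n, -1)`: because `labels` is expanded across all `n` channels, `torch.nonzero` enumerates channel 0's matching pixels first, then channel 1's, and so on, so the flat gather reshapes cleanly into an `[n, k]` array with one row per channel. A toy reproduction with made-up shapes (not part of the committed files):

import torch

labels = torch.tensor([[0, 1, 1], [2, 1, 0]])  # [H, W] per-pixel part labels
dp_result = torch.rand(2, 2, 3)                # [n, H, W]; channels are (U, V)

n, h, w = dp_result.shape
part_id = 1
# Tuple of 3 index tensors, channel-major, one entry per matching (channel, y, x).
indices = torch.nonzero(labels.expand(n, h, w) == part_id, as_tuple=True)
values = dp_result[indices].view(n, -1)        # [n, k]; here k = 3 pixels carry label 1
print(values.shape)                            # torch.Size([2, 3])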
CatVTON/densepose/data/samplers/densepose_confidence_based.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from densepose.converters import ToChartResultConverterWithConfidences
|
| 10 |
+
|
| 11 |
+
from .densepose_base import DensePoseBaseSampler
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DensePoseConfidenceBasedSampler(DensePoseBaseSampler):
|
+    """
+    Samples DensePose data from DensePose predictions.
+    Samples for each class are drawn using confidence value estimates.
+    """
+
+    def __init__(
+        self,
+        confidence_channel: str,
+        count_per_class: int = 8,
+        search_count_multiplier: Optional[float] = None,
+        search_proportion: Optional[float] = None,
+    ):
+        """
+        Constructor
+
+        Args:
+          confidence_channel (str): confidence channel to use for sampling;
+            possible values:
+              "sigma_2": confidences for UV values
+              "fine_segm_confidence": confidences for fine segmentation
+              "coarse_segm_confidence": confidences for coarse segmentation
+            (default: "sigma_2")
+          count_per_class (int): the sampler produces at most `count_per_class`
+              samples for each category (default: 8)
+          search_count_multiplier (float or None): if not None, the total number
+              of the most confident estimates of a given class to consider is
+              defined as `min(search_count_multiplier * count_per_class, N)`,
+              where `N` is the total number of estimates of the class; cannot be
+              specified together with `search_proportion` (default: None)
+          search_proportion (float or None): if not None, the total number
+              of the most confident estimates of a given class to consider is
+              defined as `min(max(search_proportion * N, count_per_class), N)`,
+              where `N` is the total number of estimates of the class; cannot be
+              specified together with `search_count_multiplier` (default: None)
+        """
+        super().__init__(count_per_class)
+        self.confidence_channel = confidence_channel
+        self.search_count_multiplier = search_count_multiplier
+        self.search_proportion = search_proportion
+        assert (search_count_multiplier is None) or (search_proportion is None), (
+            f"Cannot specify both search_count_multiplier (={search_count_multiplier}) "
+            f"and search_proportion (={search_proportion})"
+        )
+
+    def _produce_index_sample(self, values: torch.Tensor, count: int):
+        """
+        Produce a sample of indices to select data based on confidences
+
+        Args:
+            values (torch.Tensor): an array of size [n, k] that contains
+                estimated values (U, V, confidences);
+                n: number of channels (U, V, confidences)
+                k: number of points labeled with part_id
+            count (int): number of samples to produce, should be positive and <= k
+
+        Return:
+            list(int): indices of values (along axis 1) selected as a sample
+        """
+        k = values.shape[1]
+        if k == count:
+            index_sample = list(range(k))
+        else:
+            # take the best count * search_count_multiplier pixels,
+            # sample from them uniformly
+            # (here best = smallest variance)
+            _, sorted_confidence_indices = torch.sort(values[2])
+            if self.search_count_multiplier is not None:
+                search_count = min(int(count * self.search_count_multiplier), k)
+            elif self.search_proportion is not None:
+                search_count = min(max(int(k * self.search_proportion), count), k)
+            else:
+                search_count = min(count, k)
+            sample_from_top = random.sample(range(search_count), count)
+            index_sample = sorted_confidence_indices[:search_count][sample_from_top]
+        return index_sample
+
+    def _produce_labels_and_results(self, instance) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Method to get labels and DensePose results from an instance, with confidences
+
+        Args:
+            instance (Instances): an instance of `DensePoseChartPredictorOutputWithConfidences`
+
+        Return:
+            labels (torch.Tensor): shape [H, W], DensePose segmentation labels
+            dp_result (torch.Tensor): shape [3, H, W], DensePose results u and v
+                stacked with the confidence channel
+        """
+        converter = ToChartResultConverterWithConfidences
+        chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
+        labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
+        dp_result = torch.cat(
+            (dp_result, getattr(chart_result, self.confidence_channel)[None].cpu())
+        )
+
+        return labels, dp_result
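
The `else` branch above reduces to a two-stage pick: sort points by the confidence channel (smallest variance first), keep the top `search_count`, then draw `count` of them uniformly. A minimal self-contained sketch of that heuristic with hypothetical values (assumes only torch):

import random

import torch

# hypothetical [3, k] tensor: rows 0/1 are U/V, row 2 is the confidence
# channel, where smaller means more confident (e.g. "sigma_2" variances)
values = torch.rand(3, 20)
count, search_count_multiplier = 4, 2.0

k = values.shape[1]
_, sorted_confidence_indices = torch.sort(values[2])         # most confident first
search_count = min(int(count * search_count_multiplier), k)  # here: 8
sample_from_top = random.sample(range(search_count), count)  # 4 of the top 8
index_sample = sorted_confidence_indices[:search_count][sample_from_top]
print(index_sample)  # 4 point indices biased toward high confidence
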
CatVTON/densepose/data/samplers/densepose_cse_uniform.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from .densepose_cse_base import DensePoseCSEBaseSampler
+from .densepose_uniform import DensePoseUniformSampler
+
+
+class DensePoseCSEUniformSampler(DensePoseCSEBaseSampler, DensePoseUniformSampler):
+    """
+    Uniform Sampler for CSE
+    """
+
+    pass
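
The class body is just `pass`: the behavior comes entirely from Python's method resolution order, which lets `DensePoseCSEBaseSampler` supply the CSE-specific parts while `DensePoseUniformSampler` supplies the index sampling. A tiny stand-in sketch of the same mixin pattern (all class names here are hypothetical):

class Base:
    def sample(self):
        return self._produce_index_sample()

class CSEBase(Base):                 # stands in for DensePoseCSEBaseSampler
    pass

class Uniform(Base):                 # stands in for DensePoseUniformSampler
    def _produce_index_sample(self):
        return "uniform indices"

class CSEUniform(CSEBase, Uniform):  # mirrors DensePoseCSEUniformSampler
    pass

print(CSEUniform().sample())                     # -> uniform indices
print([c.__name__ for c in CSEUniform.__mro__])  # CSEUniform, CSEBase, Uniform, Base, object
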
CatVTON/densepose/data/samplers/mask_from_densepose.py
ADDED
@@ -0,0 +1,30 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from detectron2.structures import BitMasks, Instances
+
+from densepose.converters import ToMaskConverter
+
+
+class MaskFromDensePoseSampler:
+    """
+    Produce mask GT from DensePose predictions.
+    This sampler simply converts DensePose predictions to BitMasks
+    that contain a bool tensor of the size of the input image
+    """
+
+    def __call__(self, instances: Instances) -> BitMasks:
+        """
+        Converts predicted data from `instances` into the GT mask data
+
+        Args:
+            instances (Instances): predicted results, expected to have `pred_densepose` field
+
+        Returns:
+            Boolean Tensor of the size of the input image that has non-zero
+            values at pixels that are estimated to belong to the detected object
+        """
+        return ToMaskConverter.convert(
+            instances.pred_densepose, instances.pred_boxes, instances.image_size
+        )
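
A hedged usage sketch: in a bootstrapping pipeline this callable typically fills mask ground truth from a DensePose model's predictions. It assumes a densepose/detectron2 install and an `instances` object produced elsewhere that carries `pred_densepose`, `pred_boxes`, and `image_size`:

from densepose.data.samplers import MaskFromDensePoseSampler

sampler = MaskFromDensePoseSampler()

def instances_to_gt_masks(instances):
    # BitMasks over the full image; .tensor is a bool tensor of shape [N, H, W]
    bitmasks = sampler(instances)
    return bitmasks.tensor
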
CatVTON/densepose/data/samplers/prediction_to_gt.py
ADDED
@@ -0,0 +1,100 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional
+
+from detectron2.structures import Instances
+
+ModelOutput = Dict[str, Any]
+SampledData = Dict[str, Any]
+
+
+@dataclass
+class _Sampler:
+    """
+    Sampler registry entry that contains:
+     - src (str): source field to sample from (deleted after sampling)
+     - dst (Optional[str]): destination field to sample to, if not None
+     - func (Optional[Callable: Any -> Any]): function that performs sampling,
+       if None, reference copy is performed
+    """
+
+    src: str
+    dst: Optional[str]
+    func: Optional[Callable[[Any], Any]]
+
+
+class PredictionToGroundTruthSampler:
+    """
+    Sampler implementation that converts predictions to GT using registered
+    samplers for different fields of `Instances`.
+    """
+
+    def __init__(self, dataset_name: str = ""):
+        self.dataset_name = dataset_name
+        self._samplers = {}
+        self.register_sampler("pred_boxes", "gt_boxes", None)
+        self.register_sampler("pred_classes", "gt_classes", None)
+        # delete scores
+        self.register_sampler("scores")
+
+    def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
+        """
+        Transform model output into ground truth data through sampling
+
+        Args:
+            model_output (List[Dict[str, Any]]): model output
+        Returns:
+            List[Dict[str, Any]]: sampled data
+        """
+        for model_output_i in model_output:
+            instances: Instances = model_output_i["instances"]
+            # transform data in each field
+            for _, sampler in self._samplers.items():
+                if not instances.has(sampler.src) or sampler.dst is None:
+                    continue
+                if sampler.func is None:
+                    instances.set(sampler.dst, instances.get(sampler.src))
+                else:
+                    instances.set(sampler.dst, sampler.func(instances))
+            # delete model output data that was transformed
+            for _, sampler in self._samplers.items():
+                if sampler.src != sampler.dst and instances.has(sampler.src):
+                    instances.remove(sampler.src)
+            model_output_i["dataset"] = self.dataset_name
+        return model_output
+
+    def register_sampler(
+        self,
+        prediction_attr: str,
+        gt_attr: Optional[str] = None,
+        func: Optional[Callable[[Any], Any]] = None,
+    ):
+        """
+        Register sampler for a field
+
+        Args:
+            prediction_attr (str): field to replace with a sampled value
+            gt_attr (Optional[str]): field to store the sampled value to, if not None
+            func (Optional[Callable: Any -> Any]): sampler function
+        """
+        self._samplers[(prediction_attr, gt_attr)] = _Sampler(
+            src=prediction_attr, dst=gt_attr, func=func
+        )
+
+    def remove_sampler(
+        self,
+        prediction_attr: str,
+        gt_attr: Optional[str] = None,
+    ):
+        """
+        Remove sampler for a field
+
+        Args:
+            prediction_attr (str): field to replace with a sampled value
+            gt_attr (Optional[str]): field to store the sampled value to, if not None
+        """
+        assert (prediction_attr, gt_attr) in self._samplers
+        del self._samplers[(prediction_attr, gt_attr)]
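
A hedged end-to-end sketch of the registry above (assumes detectron2 is installed; the boxes, classes, and scores are dummy values): `pred_*` fields are copied by reference into `gt_*` fields, and `scores` has no destination so it is simply deleted.

import torch
from detectron2.structures import Boxes, Instances

from densepose.data.samplers import PredictionToGroundTruthSampler

inst = Instances((480, 640))
inst.pred_boxes = Boxes(torch.tensor([[10.0, 10.0, 100.0, 200.0]]))
inst.pred_classes = torch.tensor([0])
inst.scores = torch.tensor([0.9])

sampler = PredictionToGroundTruthSampler(dataset_name="demo")
out = sampler([{"instances": inst}])

# pred_boxes/pred_classes were renamed to gt_boxes/gt_classes, scores dropped
print(out[0]["instances"].get_fields().keys())  # gt_boxes, gt_classes
print(out[0]["dataset"])                        # demo
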
CatVTON/densepose/data/transform/__init__.py
ADDED
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from .image import ImageResizeTransform
CatVTON/densepose/data/transform/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (219 Bytes).
CatVTON/densepose/data/transform/__pycache__/image.cpython-39.pyc
ADDED
Binary file (1.68 kB).
CatVTON/densepose/data/transform/image.py
ADDED
@@ -0,0 +1,41 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+import torch
+
+
+class ImageResizeTransform:
+    """
+    Transform that resizes images loaded from a dataset
+    (BGR data in NCHW channel order, typically uint8) to a format ready to be
+    consumed by DensePose training (BGR float32 data in NCHW channel order)
+    """
+
+    def __init__(self, min_size: int = 800, max_size: int = 1333):
+        self.min_size = min_size
+        self.max_size = max_size
+
+    def __call__(self, images: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            images (torch.Tensor): tensor of size [N, 3, H, W] that contains
+                BGR data (typically in uint8)
+        Returns:
+            images (torch.Tensor): tensor of size [N, 3, H1, W1] where
+                H1 and W1 are chosen to respect the specified min and max sizes
+                and preserve the original aspect ratio, the data channels
+                follow BGR order and the data type is `torch.float32`
+        """
+        # resize with min size
+        images = images.float()
+        min_size = min(images.shape[-2:])
+        max_size = max(images.shape[-2:])
+        scale = min(self.min_size / min_size, self.max_size / max_size)
+        images = torch.nn.functional.interpolate(
+            images,
+            scale_factor=scale,
+            mode="bilinear",
+            align_corners=False,
+        )
+        return images
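
The scale is the largest factor that brings the short side up to `min_size` without pushing the long side past `max_size`; whichever constraint binds wins the `min`. A quick worked check with the defaults (assumes the densepose package layout above): for a 480x640 input, scale = min(800/480, 1333/640) ≈ 1.667, so the output is about 800x1066.

import torch

from densepose.data.transform import ImageResizeTransform

transform = ImageResizeTransform(min_size=800, max_size=1333)
images = torch.randint(0, 256, (1, 3, 480, 640), dtype=torch.uint8)

out = transform(images)  # the transform converts to float internally
print(out.shape, out.dtype)  # torch.Size([1, 3, 800, 1066]) torch.float32
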
CatVTON/detectron2/__init__.py
ADDED
@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+from .utils.env import setup_environment
+
+setup_environment()
+
+
+# This line will be programmatically read/written by setup.py.
+# Leave it at the bottom of this file and don't touch it.
+__version__ = "0.6"
CatVTON/detectron2/checkpoint/__init__.py
ADDED
@@ -0,0 +1,10 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+# File:
+
+
+from . import catalog as _UNUSED  # register the handler
+from .detection_checkpoint import DetectionCheckpointer
+from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
+
+__all__ = ["Checkpointer", "PeriodicCheckpointer", "DetectionCheckpointer"]
CatVTON/detectron2/checkpoint/c2_model_loading.py
ADDED
@@ -0,0 +1,406 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import logging
+import re
+from typing import Dict, List
+import torch
+
+
+def convert_basic_c2_names(original_keys):
+    """
+    Apply some basic name conversion to names in C2 weights.
+    It only deals with typical backbone models.
+
+    Args:
+        original_keys (list[str]):
+    Returns:
+        list[str]: The same number of strings matching those in original_keys.
+    """
+    layer_keys = copy.deepcopy(original_keys)
+    layer_keys = [
+        {"pred_b": "linear_b", "pred_w": "linear_w"}.get(k, k) for k in layer_keys
+    ]  # some hard-coded mappings
+
+    layer_keys = [k.replace("_", ".") for k in layer_keys]
+    layer_keys = [re.sub("\\.b$", ".bias", k) for k in layer_keys]
+    layer_keys = [re.sub("\\.w$", ".weight", k) for k in layer_keys]
+    # Uniform both bn and gn names to "norm"
+    layer_keys = [re.sub("bn\\.s$", "norm.weight", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.bias$", "norm.bias", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.rm", "norm.running_mean", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.running.mean$", "norm.running_mean", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.riv$", "norm.running_var", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.running.var$", "norm.running_var", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.gamma$", "norm.weight", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.beta$", "norm.bias", k) for k in layer_keys]
+    layer_keys = [re.sub("gn\\.s$", "norm.weight", k) for k in layer_keys]
+    layer_keys = [re.sub("gn\\.bias$", "norm.bias", k) for k in layer_keys]
+
+    # stem
+    layer_keys = [re.sub("^res\\.conv1\\.norm\\.", "conv1.norm.", k) for k in layer_keys]
+    # to avoid mis-matching with "conv1" in other components (e.g. detection head)
+    layer_keys = [re.sub("^conv1\\.", "stem.conv1.", k) for k in layer_keys]
+
+    # layer1-4 is used by torchvision, however we follow the C2 naming strategy (res2-5)
+    # layer_keys = [re.sub("^res2.", "layer1.", k) for k in layer_keys]
+    # layer_keys = [re.sub("^res3.", "layer2.", k) for k in layer_keys]
+    # layer_keys = [re.sub("^res4.", "layer3.", k) for k in layer_keys]
+    # layer_keys = [re.sub("^res5.", "layer4.", k) for k in layer_keys]
+
+    # blocks
+    layer_keys = [k.replace(".branch1.", ".shortcut.") for k in layer_keys]
+    layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
+    layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
+    layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
+
+    # DensePose substitutions
+    layer_keys = [re.sub("^body.conv.fcn", "body_conv_fcn", k) for k in layer_keys]
+    layer_keys = [k.replace("AnnIndex.lowres", "ann_index_lowres") for k in layer_keys]
+    layer_keys = [k.replace("Index.UV.lowres", "index_uv_lowres") for k in layer_keys]
+    layer_keys = [k.replace("U.lowres", "u_lowres") for k in layer_keys]
+    layer_keys = [k.replace("V.lowres", "v_lowres") for k in layer_keys]
+    return layer_keys
+
+
+def convert_c2_detectron_names(weights):
+    """
+    Map Caffe2 Detectron weight names to Detectron2 names.
+
+    Args:
+        weights (dict): name -> tensor
+
+    Returns:
+        dict: detectron2 names -> tensor
+        dict: detectron2 names -> C2 names
+    """
+    logger = logging.getLogger(__name__)
+    logger.info("Renaming Caffe2 weights ......")
+    original_keys = sorted(weights.keys())
+    layer_keys = copy.deepcopy(original_keys)
+
+    layer_keys = convert_basic_c2_names(layer_keys)
+
+    # --------------------------------------------------------------------------
+    # RPN hidden representation conv
+    # --------------------------------------------------------------------------
+    # FPN case
+    # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
+    # shared for all other levels, hence the appearance of "fpn2"
+    layer_keys = [
+        k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
+    ]
+    # Non-FPN case
+    layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]
+
+    # --------------------------------------------------------------------------
+    # RPN box transformation conv
+    # --------------------------------------------------------------------------
+    # FPN case (see note above about "fpn2")
+    layer_keys = [
+        k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
+        for k in layer_keys
+    ]
+    layer_keys = [
+        k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
+        for k in layer_keys
+    ]
+    # Non-FPN case
+    layer_keys = [
+        k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys
+    ]
+    layer_keys = [
+        k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
+        for k in layer_keys
+    ]
+
+    # --------------------------------------------------------------------------
+    # Fast R-CNN box head
+    # --------------------------------------------------------------------------
+    layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
+    layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
+    layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
+    layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
+    # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
+    layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]
+
+    # --------------------------------------------------------------------------
+    # FPN lateral and output convolutions
+    # --------------------------------------------------------------------------
+    def fpn_map(name):
+        """
+        Look for keys with the following patterns:
+        1) Starts with "fpn.inner."
+           Example: "fpn.inner.res2.2.sum.lateral.weight"
+           Meaning: These are lateral pathway convolutions
+        2) Starts with "fpn.res"
+           Example: "fpn.res2.2.sum.weight"
+           Meaning: These are FPN output convolutions
+        """
+        splits = name.split(".")
+        norm = ".norm" if "norm" in splits else ""
+        if name.startswith("fpn.inner."):
+            # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
+            stage = int(splits[2][len("res") :])
+            return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
+        elif name.startswith("fpn.res"):
+            # splits example: ['fpn', 'res2', '2', 'sum', 'weight']
+            stage = int(splits[1][len("res") :])
+            return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
+        return name
+
+    layer_keys = [fpn_map(k) for k in layer_keys]
+
+    # --------------------------------------------------------------------------
+    # Mask R-CNN mask head
+    # --------------------------------------------------------------------------
+    # roi_heads.StandardROIHeads case
+    layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
+    layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
+    layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
+    # roi_heads.Res5ROIHeads case
+    layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]
+
+    # --------------------------------------------------------------------------
+    # Keypoint R-CNN head
+    # --------------------------------------------------------------------------
+    # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
+    layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
+    layer_keys = [
+        k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys
+    ]
+    layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]
+
+    # --------------------------------------------------------------------------
+    # Done with replacements
+    # --------------------------------------------------------------------------
+    assert len(set(layer_keys)) == len(layer_keys)
+    assert len(original_keys) == len(layer_keys)
+
+    new_weights = {}
+    new_keys_to_original_keys = {}
+    for orig, renamed in zip(original_keys, layer_keys):
+        new_keys_to_original_keys[renamed] = orig
+        if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
+            # remove the meaningless prediction weight for background class
+            new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
+            new_weights[renamed] = weights[orig][new_start_idx:]
+            logger.info(
+                "Remove prediction weight for background class in {}. The shape changes from "
+                "{} to {}.".format(
+                    renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
+                )
+            )
+        elif renamed.startswith("cls_score."):
+            # move weights of bg class from original index 0 to last index
+            logger.info(
+                "Move classification weights for background class in {} from index 0 to "
+                "index {}.".format(renamed, weights[orig].shape[0] - 1)
+            )
+            new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
+        else:
+            new_weights[renamed] = weights[orig]
+
+    return new_weights, new_keys_to_original_keys
+
+
+# Note the current matching is not symmetric.
+# It assumes model_state_dict will have longer names.
+def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
+    """
+    Match names between the two state dicts, and return a new ckpt_state_dict with names
+    converted to match model_state_dict with heuristics. The returned dict can be later
+    loaded with fvcore checkpointer.
+    If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
+    model and will be renamed at first.
+
+    Strategy: suppose that the models that we will create will have prefixes appended
+    to each of its keys, for example due to an extra level of nesting that the original
+    pre-trained weights from ImageNet won't contain. For example, model.state_dict()
+    might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
+    res2.conv1.weight. We thus want to match both parameters together.
+    For that, we look for each model weight, look among all loaded keys if there is one
+    that is a suffix of the current weight name, and use it if that's the case.
+    If multiple matches exist, take the one with longest size
+    of the corresponding name. For example, for the same model as before, the pretrained
+    weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
+    we want to match backbone[0].body.conv1.weight to conv1.weight, and
+    backbone[0].body.res2.conv1.weight to res2.conv1.weight.
+    """
+    model_keys = sorted(model_state_dict.keys())
+    if c2_conversion:
+        ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict)
+        # original_keys: the name in the original dict (before renaming)
+    else:
+        original_keys = {x: x for x in ckpt_state_dict.keys()}
+    ckpt_keys = sorted(ckpt_state_dict.keys())
+
+    def match(a, b):
+        # Matched ckpt_key should be a complete (starts with '.') suffix.
+        # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
+        # but matches whatever_conv1 or mesh_head.whatever_conv1.
+        return a == b or a.endswith("." + b)
+
+    # get a matrix of string matches, where each (i, j) entry corresponds to the size of the
+    # ckpt_key string, if it matches
+    match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
+    match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
+    # use the matched one with longest size in case of multiple matches
+    max_match_size, idxs = match_matrix.max(1)
+    # remove indices that correspond to no-match
+    idxs[max_match_size == 0] = -1
+
+    logger = logging.getLogger(__name__)
+    # matched_pairs (matched checkpoint key --> matched model key)
+    matched_keys = {}
+    result_state_dict = {}
+    for idx_model, idx_ckpt in enumerate(idxs.tolist()):
+        if idx_ckpt == -1:
+            continue
+        key_model = model_keys[idx_model]
+        key_ckpt = ckpt_keys[idx_ckpt]
+        value_ckpt = ckpt_state_dict[key_ckpt]
+        shape_in_model = model_state_dict[key_model].shape
+
+        if shape_in_model != value_ckpt.shape:
+            logger.warning(
+                "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
+                    key_ckpt, value_ckpt.shape, key_model, shape_in_model
+                )
+            )
+            logger.warning(
+                "{} will not be loaded. Please double check and see if this is desired.".format(
+                    key_ckpt
+                )
+            )
+            continue
+
+        assert key_model not in result_state_dict
+        result_state_dict[key_model] = value_ckpt
+        if key_ckpt in matched_keys:  # already added to matched_keys
+            logger.error(
+                "Ambiguity found for {} in checkpoint! "
+                "It matches at least two keys in the model ({} and {}).".format(
+                    key_ckpt, key_model, matched_keys[key_ckpt]
+                )
+            )
+            raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")
+
+        matched_keys[key_ckpt] = key_model
+
+    # logging:
+    matched_model_keys = sorted(matched_keys.values())
+    if len(matched_model_keys) == 0:
+        logger.warning("No weights in checkpoint matched with model.")
+        return ckpt_state_dict
+    common_prefix = _longest_common_prefix(matched_model_keys)
+    rev_matched_keys = {v: k for k, v in matched_keys.items()}
+    original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}
+
+    model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
+    table = []
+    memo = set()
+    for key_model in matched_model_keys:
+        if key_model in memo:
+            continue
+        if key_model in model_key_groups:
+            group = model_key_groups[key_model]
+            memo |= set(group)
+            shapes = [tuple(model_state_dict[k].shape) for k in group]
+            table.append(
+                (
+                    _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
+                    _group_str([original_keys[k] for k in group]),
+                    " ".join([str(x).replace(" ", "") for x in shapes]),
+                )
+            )
+        else:
+            key_checkpoint = original_keys[key_model]
+            shape = str(tuple(model_state_dict[key_model].shape))
+            table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
+    submodule_str = common_prefix[:-1] if common_prefix else "model"
+    logger.info(
+        f"Following weights matched with submodule {submodule_str} - Total num: {len(table)}"
+    )
+
+    unmatched_ckpt_keys = [k for k in ckpt_keys if k not in set(matched_keys.keys())]
+    for k in unmatched_ckpt_keys:
+        result_state_dict[k] = ckpt_state_dict[k]
+    return result_state_dict
+
+
+def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
+    """
+    Params in the same submodule are grouped together.
+
+    Args:
+        keys: names of all parameters
+        original_names: mapping from parameter name to their name in the checkpoint
+
+    Returns:
+        dict[name -> all other names in the same group]
+    """
+
+    def _submodule_name(key):
+        pos = key.rfind(".")
+        if pos < 0:
+            return None
+        prefix = key[: pos + 1]
+        return prefix
+
+    all_submodules = [_submodule_name(k) for k in keys]
+    all_submodules = [x for x in all_submodules if x]
+    all_submodules = sorted(all_submodules, key=len)
+
+    ret = {}
+    for prefix in all_submodules:
+        group = [k for k in keys if k.startswith(prefix)]
+        if len(group) <= 1:
+            continue
+        original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group])
+        if len(original_name_lcp) == 0:
+            # don't group weights if original names don't share prefix
+            continue
+
+        for k in group:
+            if k in ret:
+                continue
+            ret[k] = group
+    return ret
+
+
+def _longest_common_prefix(names: List[str]) -> str:
+    """
+    ["abc.zfg", "abc.zef"] -> "abc."
+    """
+    names = [n.split(".") for n in names]
+    m1, m2 = min(names), max(names)
+    ret = [a for a, b in zip(m1, m2) if a == b]
+    ret = ".".join(ret) + "." if len(ret) else ""
+    return ret
+
+
+def _longest_common_prefix_str(names: List[str]) -> str:
+    m1, m2 = min(names), max(names)
+    lcp = []
+    for a, b in zip(m1, m2):
+        if a == b:
+            lcp.append(a)
+        else:
+            break
+    lcp = "".join(lcp)
+    return lcp
+
+
+def _group_str(names: List[str]) -> str:
+    """
+    Turn "common1", "common2", "common3" into "common{1,2,3}"
+    """
+    lcp = _longest_common_prefix_str(names)
+    rest = [x[len(lcp) :] for x in names]
+    rest = "{" + ",".join(rest) + "}"
+    ret = lcp + rest
+
+    # add some simplification for BN specifically
+    ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*")
+    ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*")
+    return ret
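
The core of `align_and_update_state_dicts` is the '.'-aligned suffix match plus a longest-name tie-break, and it can be illustrated without any real checkpoint. A self-contained sketch with toy key names (plain Python, mirroring the `match` rule above):

def match(model_key: str, ckpt_key: str) -> bool:
    # a checkpoint key matches if it equals the model key or is a '.'-aligned suffix
    return model_key == ckpt_key or model_key.endswith("." + ckpt_key)

model_keys = [
    "backbone.bottom_up.stem.conv1.weight",
    "backbone.bottom_up.res2.0.conv1.weight",
]
ckpt_keys = ["stem.conv1.weight", "res2.0.conv1.weight", "conv1.weight"]

for mk in model_keys:
    candidates = [ck for ck in ckpt_keys if match(mk, ck)]
    best = max(candidates, key=len)  # longest match wins, as in the match matrix
    print(f"{mk}  <-  {best}")
# stem.conv1.weight beats the ambiguous conv1.weight for the stem parameter,
# and res2.0.conv1.weight is picked for the res2 parameter
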
CatVTON/detectron2/checkpoint/catalog.py
ADDED
@@ -0,0 +1,115 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+
+from detectron2.utils.file_io import PathHandler, PathManager
+
+
+class ModelCatalog:
+    """
+    Store mappings from names to third-party models.
+    """
+
+    S3_C2_DETECTRON_PREFIX = "https://dl.fbaipublicfiles.com/detectron"
+
+    # MSRA models have STRIDE_IN_1X1=True. False otherwise.
+    # NOTE: all BN models here have fused BN into an affine layer.
+    # As a result, you should only load them to a model with "FrozenBN".
+    # Loading them to a model with regular BN or SyncBN is wrong.
+    # Even when loaded to FrozenBN, it is still different from affine by an epsilon,
+    # which should be negligible for training.
+    # NOTE: all models here use PIXEL_STD=[1,1,1]
+    # NOTE: Most of the BN models here are no longer used. We use the
+    # re-converted pre-trained models under detectron2 model zoo instead.
+    C2_IMAGENET_MODELS = {
+        "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl",
+        "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl",
+        "FAIR/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl",
+        "FAIR/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl",
+        "FAIR/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl",
+        "FAIR/X-101-64x4d": "ImageNetPretrained/FBResNeXt/X-101-64x4d.pkl",
+        "FAIR/X-152-32x8d-IN5k": "ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl",
+    }
+
+    C2_DETECTRON_PATH_FORMAT = (
+        "{prefix}/{url}/output/train/{dataset}/{type}/model_final.pkl"  # noqa B950
+    )
+
+    C2_DATASET_COCO = "coco_2014_train%3Acoco_2014_valminusminival"
+    C2_DATASET_COCO_KEYPOINTS = "keypoints_coco_2014_train%3Akeypoints_coco_2014_valminusminival"
+
+    # format: {model_name} -> part of the url
+    C2_DETECTRON_MODELS = {
+        "35857197/e2e_faster_rcnn_R-50-C4_1x": "35857197/12_2017_baselines/e2e_faster_rcnn_R-50-C4_1x.yaml.01_33_49.iAX0mXvW",  # noqa B950
+        "35857345/e2e_faster_rcnn_R-50-FPN_1x": "35857345/12_2017_baselines/e2e_faster_rcnn_R-50-FPN_1x.yaml.01_36_30.cUF7QR7I",  # noqa B950
+        "35857890/e2e_faster_rcnn_R-101-FPN_1x": "35857890/12_2017_baselines/e2e_faster_rcnn_R-101-FPN_1x.yaml.01_38_50.sNxI7sX7",  # noqa B950
+        "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "36761737/12_2017_baselines/e2e_faster_rcnn_X-101-32x8d-FPN_1x.yaml.06_31_39.5MIHi1fZ",  # noqa B950
+        "35858791/e2e_mask_rcnn_R-50-C4_1x": "35858791/12_2017_baselines/e2e_mask_rcnn_R-50-C4_1x.yaml.01_45_57.ZgkA7hPB",  # noqa B950
+        "35858933/e2e_mask_rcnn_R-50-FPN_1x": "35858933/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml.01_48_14.DzEQe4wC",  # noqa B950
+        "35861795/e2e_mask_rcnn_R-101-FPN_1x": "35861795/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_1x.yaml.02_31_37.KqyEK4tT",  # noqa B950
+        "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "36761843/12_2017_baselines/e2e_mask_rcnn_X-101-32x8d-FPN_1x.yaml.06_35_59.RZotkLKI",  # noqa B950
+        "48616381/e2e_mask_rcnn_R-50-FPN_2x_gn": "GN/48616381/04_2018_gn_baselines/e2e_mask_rcnn_R-50-FPN_2x_gn_0416.13_23_38.bTlTI97Q",  # noqa B950
+        "37697547/e2e_keypoint_rcnn_R-50-FPN_1x": "37697547/12_2017_baselines/e2e_keypoint_rcnn_R-50-FPN_1x.yaml.08_42_54.kdzV35ao",  # noqa B950
+        "35998355/rpn_R-50-C4_1x": "35998355/12_2017_baselines/rpn_R-50-C4_1x.yaml.08_00_43.njH5oD9L",  # noqa B950
+        "35998814/rpn_R-50-FPN_1x": "35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179",  # noqa B950
+        "36225147/fast_R-50-FPN_1x": "36225147/12_2017_baselines/fast_rcnn_R-50-FPN_1x.yaml.08_39_09.L3obSdQ2",  # noqa B950
+    }
+
+    @staticmethod
+    def get(name):
+        if name.startswith("Caffe2Detectron/COCO"):
+            return ModelCatalog._get_c2_detectron_baseline(name)
+        if name.startswith("ImageNetPretrained/"):
+            return ModelCatalog._get_c2_imagenet_pretrained(name)
+        raise RuntimeError("model not present in the catalog: {}".format(name))
+
+    @staticmethod
+    def _get_c2_imagenet_pretrained(name):
+        prefix = ModelCatalog.S3_C2_DETECTRON_PREFIX
+        name = name[len("ImageNetPretrained/") :]
+        name = ModelCatalog.C2_IMAGENET_MODELS[name]
+        url = "/".join([prefix, name])
+        return url
+
+    @staticmethod
+    def _get_c2_detectron_baseline(name):
+        name = name[len("Caffe2Detectron/COCO/") :]
+        url = ModelCatalog.C2_DETECTRON_MODELS[name]
+        if "keypoint_rcnn" in name:
+            dataset = ModelCatalog.C2_DATASET_COCO_KEYPOINTS
+        else:
+            dataset = ModelCatalog.C2_DATASET_COCO
+
+        if "35998355/rpn_R-50-C4_1x" in name:
+            # this one model is somehow different from others ..
+            type = "rpn"
+        else:
+            type = "generalized_rcnn"
+
+        # Detectron C2 models are stored in the structure defined in `C2_DETECTRON_PATH_FORMAT`.
+        url = ModelCatalog.C2_DETECTRON_PATH_FORMAT.format(
+            prefix=ModelCatalog.S3_C2_DETECTRON_PREFIX, url=url, type=type, dataset=dataset
+        )
+        return url
+
+
+class ModelCatalogHandler(PathHandler):
+    """
+    Resolve URL like catalog://.
+    """
+
+    PREFIX = "catalog://"
+
+    def _get_supported_prefixes(self):
+        return [self.PREFIX]
+
+    def _get_local_path(self, path, **kwargs):
+        logger = logging.getLogger(__name__)
+        catalog_path = ModelCatalog.get(path[len(self.PREFIX) :])
+        logger.info("Catalog entry {} points to {}".format(path, catalog_path))
+        return PathManager.get_local_path(catalog_path, **kwargs)
+
+    def _open(self, path, mode="r", **kwargs):
+        return PathManager.open(self._get_local_path(path), mode, **kwargs)
+
+
+PathManager.register_handler(ModelCatalogHandler())
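
A hedged sketch of how a catalog name expands: `get` dispatches on the prefix and joins the template pieces, and registering `ModelCatalogHandler` teaches `PathManager` the `catalog://` scheme.

from detectron2.checkpoint.catalog import ModelCatalog

# ImageNet-pretrained entries are a simple prefix + relative-path join
url = ModelCatalog.get("ImageNetPretrained/MSRA/R-50")
print(url)
# https://dl.fbaipublicfiles.com/detectron/ImageNetPretrained/MSRA/R-50.pkl

# once the handler is registered, the same lookup works through PathManager:
#   PathManager.get_local_path("catalog://ImageNetPretrained/MSRA/R-50")
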
CatVTON/detectron2/checkpoint/detection_checkpoint.py
ADDED
@@ -0,0 +1,143 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import os
+import pickle
+from urllib.parse import parse_qs, urlparse
+import torch
+from fvcore.common.checkpoint import Checkpointer
+from torch.nn.parallel import DistributedDataParallel
+
+import detectron2.utils.comm as comm
+from detectron2.utils.file_io import PathManager
+
+from .c2_model_loading import align_and_update_state_dicts
+
+
+class DetectionCheckpointer(Checkpointer):
+    """
+    Same as :class:`Checkpointer`, but is able to:
+    1. handle models in detectron & detectron2 model zoo, and apply conversions for legacy models.
+    2. correctly load checkpoints that are only available on the master worker
+    """
+
+    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
+        is_main_process = comm.is_main_process()
+        super().__init__(
+            model,
+            save_dir,
+            save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
+            **checkpointables,
+        )
+        self.path_manager = PathManager
+        self._parsed_url_during_load = None
+
+    def load(self, path, *args, **kwargs):
+        assert self._parsed_url_during_load is None
+        need_sync = False
+        logger = logging.getLogger(__name__)
+        logger.info("[DetectionCheckpointer] Loading from {} ...".format(path))
+
+        if path and isinstance(self.model, DistributedDataParallel):
+            path = self.path_manager.get_local_path(path)
+            has_file = os.path.isfile(path)
+            all_has_file = comm.all_gather(has_file)
+            if not all_has_file[0]:
+                raise OSError(f"File {path} not found on main worker.")
+            if not all(all_has_file):
+                logger.warning(
+                    f"Not all workers can read checkpoint {path}. "
+                    "Training may fail to fully resume."
+                )
+                # TODO: broadcast the checkpoint file contents from main
+                # worker, and load from it instead.
+                need_sync = True
+            if not has_file:
+                path = None  # don't load if not readable
+
+        if path:
+            parsed_url = urlparse(path)
+            self._parsed_url_during_load = parsed_url
+            path = parsed_url._replace(query="").geturl()  # remove query from filename
+            path = self.path_manager.get_local_path(path)
+        ret = super().load(path, *args, **kwargs)
+
+        if need_sync:
+            logger.info("Broadcasting model states from main worker ...")
+            self.model._sync_params_and_buffers()
+        self._parsed_url_during_load = None  # reset to None
+        return ret
+
+    def _load_file(self, filename):
+        if filename.endswith(".pkl"):
+            with PathManager.open(filename, "rb") as f:
+                data = pickle.load(f, encoding="latin1")
+            if "model" in data and "__author__" in data:
+                # file is in Detectron2 model zoo format
+                self.logger.info("Reading a file from '{}'".format(data["__author__"]))
+                return data
+            else:
+                # assume file is from Caffe2 / Detectron1 model zoo
+                if "blobs" in data:
+                    # Detection models have "blobs", but ImageNet models don't
+                    data = data["blobs"]
+                data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
+                return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
+        elif filename.endswith(".pyth"):
+            # assume file is from pycls; no one else seems to use the ".pyth" extension
+            with PathManager.open(filename, "rb") as f:
+                data = torch.load(f)
+            assert (
+                "model_state" in data
+            ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'."
+            model_state = {
+                k: v
+                for k, v in data["model_state"].items()
+                if not k.endswith("num_batches_tracked")
+            }
+            return {"model": model_state, "__author__": "pycls", "matching_heuristics": True}
+
+        loaded = self._torch_load(filename)
+        if "model" not in loaded:
+            loaded = {"model": loaded}
+        assert self._parsed_url_during_load is not None, "`_load_file` must be called inside `load`"
+        parsed_url = self._parsed_url_during_load
+        queries = parse_qs(parsed_url.query)
+        if queries.pop("matching_heuristics", "False") == ["True"]:
+            loaded["matching_heuristics"] = True
+        if len(queries) > 0:
+            raise ValueError(
+                f"Unsupported query remaining: {queries}, original filename: {parsed_url.geturl()}"
+            )
+        return loaded
+
+    def _torch_load(self, f):
+        return super()._load_file(f)
+
+    def _load_model(self, checkpoint):
+        if checkpoint.get("matching_heuristics", False):
+            self._convert_ndarray_to_tensor(checkpoint["model"])
+            # convert weights by name-matching heuristics
+            checkpoint["model"] = align_and_update_state_dicts(
+                self.model.state_dict(),
+                checkpoint["model"],
+                c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
+            )
+        # for non-caffe2 models, use standard ways to load it
+        incompatible = super()._load_model(checkpoint)
+
+        model_buffers = dict(self.model.named_buffers(recurse=False))
+        for k in ["pixel_mean", "pixel_std"]:
+            # Ignore missing key message about pixel_mean/std.
+            # Though they may be missing in old checkpoints, they will be correctly
+            # initialized from config anyway.
+            if k in model_buffers:
+                try:
+                    incompatible.missing_keys.remove(k)
+                except ValueError:
+                    pass
+        for k in incompatible.unexpected_keys[:]:
+            # Ignore unexpected keys about cell anchors. They exist in old checkpoints
+            # but now they are non-persistent buffers and will not be in new checkpoints.
+            if "anchor_generator.cell_anchors" in k:
+                incompatible.unexpected_keys.remove(k)
+        return incompatible
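
A hedged usage sketch tying the loading paths together (`model` is assumed to be any torch.nn.Module built elsewhere, e.g. with `build_model(cfg)`): a `.pkl` file takes the Caffe2/Detectron2 pickle branch of `_load_file`, while the `matching_heuristics=True` query forces the name-matching conversion in `_load_model` for ordinary torch checkpoints.

from detectron2.checkpoint import DetectionCheckpointer

checkpointer = DetectionCheckpointer(model, save_dir="./output")

# .pkl checkpoints are unpickled and flagged for heuristic name matching
checkpointer.load("catalog://ImageNetPretrained/MSRA/R-50")

# for a regular torch checkpoint, the query string opts into the same heuristics
checkpointer.load("weights.pth?matching_heuristics=True")
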
CatVTON/detectron2/engine/__init__.py
ADDED
@@ -0,0 +1,19 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+from .launch import *
+from .train_loop import *
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
+
+
+# prefer to let hooks and defaults live in separate namespaces (therefore not in __all__)
+# but still make them available here
+from .hooks import *
+from .defaults import (
+    create_ddp_model,
+    default_argument_parser,
+    default_setup,
+    default_writers,
+    DefaultPredictor,
+    DefaultTrainer,
+)
CatVTON/detectron2/engine/defaults.py
ADDED
|
@@ -0,0 +1,719 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

"""
This file contains components with some default boilerplate logic users may need
in training / testing. They will not work for everyone, but many users may find them useful.

The behavior of functions/classes in this file is subject to change,
since they are meant to represent the "common default behavior" people need in their projects.
"""

import argparse
import logging
import os
import sys
import weakref
from collections import OrderedDict
from typing import Optional
import torch
from fvcore.nn.precise_bn import get_bn_modules
from omegaconf import OmegaConf
from torch.nn.parallel import DistributedDataParallel

import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import CfgNode, LazyConfig
from detectron2.data import (
    MetadataCatalog,
    build_detection_test_loader,
    build_detection_train_loader,
)
from detectron2.evaluation import (
    DatasetEvaluator,
    inference_on_dataset,
    print_csv_format,
    verify_results,
)
from detectron2.modeling import build_model
from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.utils import comm
from detectron2.utils.collect_env import collect_env_info
from detectron2.utils.env import seed_all_rng
from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger

from . import hooks
from .train_loop import AMPTrainer, SimpleTrainer, TrainerBase

__all__ = [
    "create_ddp_model",
    "default_argument_parser",
    "default_setup",
    "default_writers",
    "DefaultPredictor",
    "DefaultTrainer",
]


def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Create a DistributedDataParallel model if there are >1 processes.

    Args:
        model: a torch.nn.Module
        fp16_compression: add fp16 compression hooks to the ddp object.
            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
    """  # noqa
    if comm.get_world_size() == 1:
        return model
    if "device_ids" not in kwargs:
        kwargs["device_ids"] = [comm.get_local_rank()]
    ddp = DistributedDataParallel(model, **kwargs)
    if fp16_compression:
        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

        ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return ddp
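

# A minimal usage sketch (the `_example_*` helpers here are illustrative additions,
# not detectron2 API): wrap a freshly built model for the current launch configuration.
# Because `create_ddp_model` is a no-op in a single-process run, the same code path
# works with and without distributed training.
def _example_wrap_model(cfg):
    model = build_model(cfg)
    # fp16 gradient compression is optional; it only reduces communication volume
    return create_ddp_model(model, broadcast_buffers=False, fp16_compression=True)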


def default_argument_parser(epilog=None):
    """
    Create a parser with some common arguments used by detectron2 users.

    Args:
        epilog (str): epilog passed to ArgumentParser describing the usage.

    Returns:
        argparse.ArgumentParser:
    """
    parser = argparse.ArgumentParser(
        epilog=epilog
        or f"""
Examples:

Run on single machine:
    $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml

Change some config options:
    $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001

Run on multiple machines:
    (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
    (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Whether to attempt to resume from the checkpoint directory. "
        "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
    )
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )

    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    port = 2**15 + 2**14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2**14
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(port),
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
        """.strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
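

# A minimal sketch (illustrative, not detectron2 API) of the typical entry-point wiring:
# parse the common flags above, then hand a `main(args)` function to
# `detectron2.engine.launch` so the same script runs on one or many GPUs.
def _example_entry_point(main):
    from detectron2.engine import launch

    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )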


def _try_get_key(cfg, *keys, default=None):
    """
    Try to select keys from cfg until the first key that exists. Otherwise return default.
    """
    if isinstance(cfg, CfgNode):
        cfg = OmegaConf.create(cfg.dump())
    for k in keys:
        none = object()
        p = OmegaConf.select(cfg, k, default=none)
        if p is not none:
            return p
    return default


def _highlight(code, filename):
    try:
        import pygments
    except ImportError:
        return code

    from pygments.lexers import Python3Lexer, YamlLexer
    from pygments.formatters import Terminal256Formatter

    lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
    code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
    return code


def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the detectron2 logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory

    Args:
        cfg (CfgNode or omegaconf.DictConfig): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file,
                _highlight(PathManager.open(args.config_file, "r").read(), args.config_file),
            )
        )

    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        if isinstance(cfg, CfgNode):
            logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
            with PathManager.open(path, "w") as f:
                f.write(cfg.dump())
        else:
            LazyConfig.save(cfg, path)
        logger.info("Full config saved to {}".format(path))

    # make sure each worker has a different, yet deterministic seed if specified
    seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
    seed_all_rng(None if seed < 0 else seed + rank)

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = _try_get_key(
            cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
        )


def default_writers(output_dir: str, max_iter: Optional[int] = None):
    """
    Build a list of :class:`EventWriter` to be used.
    It now consists of a :class:`CommonMetricPrinter`,
    :class:`TensorboardXWriter` and :class:`JSONWriter`.

    Args:
        output_dir: directory to store JSON metrics and tensorboard events
        max_iter: the total number of iterations

    Returns:
        list[EventWriter]: a list of :class:`EventWriter` objects.
    """
    PathManager.mkdirs(output_dir)
    return [
        # It may not always print what you want to see, since it prints "common" metrics only.
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(output_dir, "metrics.json")),
        TensorboardXWriter(output_dir),
    ]
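

# A minimal sketch (illustrative, not detectron2 API) of driving these writers from a
# hand-written loop: writers read from the active EventStorage, so scalars must be
# logged inside its context before `write()` is called.
def _example_use_writers(output_dir, max_iter):
    from detectron2.utils.events import EventStorage

    writers = default_writers(output_dir, max_iter)
    with EventStorage(0) as storage:
        for _ in range(max_iter):
            storage.put_scalar("loss", 0.0)  # placeholder metric
            for writer in writers:
                writer.write()
            storage.step()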


class DefaultPredictor:
    """
    Create a simple end-to-end predictor with the given config that runs on
    a single device for a single input image.

    Compared to using the model directly, this class does the following additions:

    1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
    2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
    3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
    4. Take one input image and produce a single output, instead of a batch.

    This is meant for simple demo purposes, so it does the above steps automatically.
    This is not meant for benchmarks or running complicated inference logic.
    If you'd like to do anything more complicated, please refer to its source code as
    examples to build and use the model manually.

    Attributes:
        metadata (Metadata): the metadata of the underlying dataset, obtained from
            cfg.DATASETS.TEST.

    Examples:
    ::
        pred = DefaultPredictor(cfg)
        inputs = cv2.imread("input.jpg")
        outputs = pred(inputs)
    """

    def __init__(self, cfg):
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = build_model(self.cfg)
        self.model.eval()
        if len(cfg.DATASETS.TEST):
            self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).

        Returns:
            predictions (dict):
                the output of the model for one image only.
                See :doc:`/tutorials/models` for details about the format.
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]
            image = self.aug.get_transform(original_image).apply_image(original_image)
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
            image = image.to(self.cfg.MODEL.DEVICE)  # Tensor.to() is out-of-place; keep the assignment (the original line discarded the result)

            inputs = {"image": image, "height": height, "width": width}

            predictions = self.model([inputs])[0]
            return predictions
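

# A minimal sketch (illustrative, not detectron2 API) of the manual, batched inference
# the docstring above points to: build and load the model once, then feed a list of
# per-image dicts so one forward pass handles the whole batch. Resizing with
# T.ResizeShortestEdge is omitted for brevity; `images_bgr` is assumed to be a list of
# HWC uint8 arrays matching cfg.INPUT.FORMAT.
def _example_batched_inference(cfg, images_bgr):
    model = build_model(cfg)
    model.eval()
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    inputs = [
        {
            "image": torch.as_tensor(img.astype("float32").transpose(2, 0, 1)),
            "height": img.shape[0],
            "width": img.shape[1],
        }
        for img in images_bgr
    ]
    with torch.no_grad():
        return model(inputs)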


class DefaultTrainer(TrainerBase):
    """
    A trainer with default training logic. It does the following:

    1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
       defined by the given config. Create a LR scheduler defined by the config.
    2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
       `resume_or_load` is called.
    3. Register a few common hooks defined by the config.

    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.

    The code of this class has been annotated about restrictive assumptions it makes.
    When they do not work for you, you're encouraged to:

    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop similar to `tools/plain_train_net.py`.

    See the :doc:`/tutorials/training` tutorials for more details.

    Note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in detectron2.
    To obtain more stable behavior, write your own training logic with other public APIs.

    Examples:
    ::
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()

    Attributes:
        scheduler:
        checkpointer (DetectionCheckpointer):
        cfg (CfgNode):
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        model = create_ddp_model(model, broadcast_buffers=False)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            trainer=weakref.proxy(self),
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
        a `last_checkpoint` file), resume from the file. Resuming means loading all
        available states (eg. optimizer and scheduler) and updating the iteration counter
        from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.

        Otherwise, this is considered as an independent training. The method will load model
        weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
        from iteration 0.

        Args:
            resume (bool): whether to do resume or not
        """
        self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        if resume and self.checkpointer.has_checkpoint():
            # The checkpoint stores the training iteration that just finished, thus we start
            # at the next iteration
            self.start_iter = self.iter + 1

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            (
                hooks.PreciseBN(
                    # Run at the same freq as (but before) evaluation.
                    cfg.TEST.EVAL_PERIOD,
                    self.model,
                    # Build a new data loader to not affect training
                    self.build_train_loader(cfg),
                    cfg.TEST.PRECISE_BN.NUM_ITER,
                )
                if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
                else None
            ),
        ]

        # Do PreciseBN before checkpointer, because it updates the model and needs to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # Here the default print/log frequency of each writer is used.
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret

    def build_writers(self):
        """
        Build a list of writers to be used using :func:`default_writers()`.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        """
        return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)

    def train(self):
        """
        Run training.

        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        super().train(self.start_iter, self.max_iter)
        if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
            assert hasattr(
                self, "_last_eval_results"
            ), "No evaluation results obtained during training!"
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    def run_step(self):
        self._trainer.iter = self.iter
        self._trainer.run_step()

    def state_dict(self):
        ret = super().state_dict()
        ret["_trainer"] = self._trainer.state_dict()
        return ret

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self._trainer.load_state_dict(state_dict["_trainer"])

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:

        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:

        It now calls :func:`detectron2.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """
        Returns:
            DatasetEvaluator or None

        It is not implemented by default.
        """
        raise NotImplementedError(
            """
If you want DefaultTrainer to automatically run evaluation,
please implement `build_evaluator()` in subclasses (see train_net.py for example).
Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
"""
        )

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Evaluate the given model. The given model is expected to already contain
        weights to evaluate.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                ``cfg.DATASETS.TEST``.

        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    @staticmethod
    def auto_scale_workers(cfg, num_workers: int):
        """
        When the config is defined for certain number of workers (according to
        ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
        workers currently in use, returns a new cfg where the total batch size
        is scaled so that the per-GPU batch size stays the same as the
        original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.

        Other config options are also scaled accordingly:
        * training steps and warmup steps are scaled inverse proportionally.
        * learning rates are scaled proportionally, following :paper:`ImageNet in 1h`.

        For example, with the original config like the following:

        .. code-block:: yaml

            IMS_PER_BATCH: 16
            BASE_LR: 0.1
            REFERENCE_WORLD_SIZE: 8
            MAX_ITER: 5000
            STEPS: (4000,)
            CHECKPOINT_PERIOD: 1000

        When this config is used on 16 GPUs instead of the reference number 8,
        calling this method will return a new config with:

        .. code-block:: yaml

            IMS_PER_BATCH: 32
            BASE_LR: 0.2
            REFERENCE_WORLD_SIZE: 16
            MAX_ITER: 2500
            STEPS: (2000,)
            CHECKPOINT_PERIOD: 500

        Note that both the original config and this new config can be trained on 16 GPUs.
        It's up to the user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).

        Returns:
            CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
        """
        old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
        if old_world_size == 0 or old_world_size == num_workers:
            return cfg
        cfg = cfg.clone()
        frozen = cfg.is_frozen()
        cfg.defrost()

        assert (
            cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
        ), "Invalid REFERENCE_WORLD_SIZE in config!"
        scale = num_workers / old_world_size
        bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
        lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
        max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
        warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
        cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
        cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
        cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
        cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers  # maintain invariant
        logger = logging.getLogger(__name__)
        logger.info(
            f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
            f"max_iter={max_iter}, warmup={warmup_iter}."
        )

        if frozen:
            cfg.freeze()
        return cfg
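

# A worked sketch (illustrative, not detectron2 API): with REFERENCE_WORLD_SIZE=8 and a
# 16-GPU run, scale = 16 / 8 = 2, so IMS_PER_BATCH 16 -> 32, BASE_LR 0.1 -> 0.2 and
# MAX_ITER 5000 -> 2500, exactly the yaml example in the docstring above.
def _example_auto_scale(cfg):
    return DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())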


# Access basic attributes from the underlying trainer
for _attr in ["model", "data_loader", "optimizer"]:
    setattr(
        DefaultTrainer,
        _attr,
        property(
            # getter
            lambda self, x=_attr: getattr(self._trainer, x),
            # setter
            lambda self, value, x=_attr: setattr(self._trainer, x, value),
        ),
    )
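
# The loop above exposes `model`, `data_loader` and `optimizer` as properties that
# delegate to the wrapped `self._trainer`; binding `x=_attr` as a default argument
# freezes the attribute name per property (a bare closure over `_attr` would see
# only the last loop value).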
CatVTON/detectron2/engine/hooks.py
ADDED
@@ -0,0 +1,690 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import datetime
import itertools
import logging
import math
import operator
import os
import tempfile
import time
import warnings
from collections import Counter
import torch
from fvcore.common.checkpoint import Checkpointer
from fvcore.common.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from fvcore.common.param_scheduler import ParamScheduler
from fvcore.common.timer import Timer
from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats

import detectron2.utils.comm as comm
from detectron2.evaluation.testing import flatten_results_dict
from detectron2.solver import LRMultiplier
from detectron2.solver import LRScheduler as _LRScheduler
from detectron2.utils.events import EventStorage, EventWriter
from detectron2.utils.file_io import PathManager

from .train_loop import HookBase

__all__ = [
    "CallbackHook",
    "IterationTimer",
    "PeriodicWriter",
    "PeriodicCheckpointer",
    "BestCheckpointer",
    "LRScheduler",
    "AutogradProfiler",
    "EvalHook",
    "PreciseBN",
    "TorchProfiler",
    "TorchMemoryStats",
]


"""
Implement some common hooks.
"""


class CallbackHook(HookBase):
    """
    Create a hook using callback functions provided by the user.
    """

    def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
        """
        Each argument is a function that takes one argument: the trainer.
        """
        self._before_train = before_train
        self._before_step = before_step
        self._after_step = after_step
        self._after_train = after_train

    def before_train(self):
        if self._before_train:
            self._before_train(self.trainer)

    def after_train(self):
        if self._after_train:
            self._after_train(self.trainer)
        # The functions may be closures that hold reference to the trainer
        # Therefore, delete them to avoid circular reference.
        del self._before_train, self._after_train
        del self._before_step, self._after_step

    def before_step(self):
        if self._before_step:
            self._before_step(self.trainer)

    def after_step(self):
        if self._after_step:
            self._after_step(self.trainer)
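

# A minimal sketch (the `_example_*` helpers here are illustrative additions, not
# detectron2 API): log the iteration number after every step through a CallbackHook,
# registered like any other hook.
def _example_register_callback(trainer):
    trainer.register_hooks(
        [CallbackHook(after_step=lambda t: print("finished iter", t.iter))]
    )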


class IterationTimer(HookBase):
    """
    Track the time spent for each iteration (each run_step call in the trainer).
    Print a summary at the end of training.

    This hook uses the time between the call to its :meth:`before_step`
    and :meth:`after_step` methods.
    Under the convention that :meth:`before_step` of all hooks should only
    take negligible amount of time, the :class:`IterationTimer` hook should be
    placed at the beginning of the list of hooks to obtain accurate timing.
    """

    def __init__(self, warmup_iter=3):
        """
        Args:
            warmup_iter (int): the number of iterations at the beginning to exclude
                from timing.
        """
        self._warmup_iter = warmup_iter
        self._step_timer = Timer()
        self._start_time = time.perf_counter()
        self._total_timer = Timer()

    def before_train(self):
        self._start_time = time.perf_counter()
        self._total_timer.reset()
        self._total_timer.pause()

    def after_train(self):
        logger = logging.getLogger(__name__)
        total_time = time.perf_counter() - self._start_time
        total_time_minus_hooks = self._total_timer.seconds()
        hook_time = total_time - total_time_minus_hooks

        num_iter = self.trainer.storage.iter + 1 - self.trainer.start_iter - self._warmup_iter

        if num_iter > 0 and total_time_minus_hooks > 0:
            # Speed is meaningful only after warmup
            # NOTE this format is parsed by grep in some scripts
            logger.info(
                "Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
                    num_iter,
                    str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
                    total_time_minus_hooks / num_iter,
                )
            )

        logger.info(
            "Total training time: {} ({} on hooks)".format(
                str(datetime.timedelta(seconds=int(total_time))),
                str(datetime.timedelta(seconds=int(hook_time))),
            )
        )

    def before_step(self):
        self._step_timer.reset()
        self._total_timer.resume()

    def after_step(self):
        # +1 because we're in after_step, the current step is done
        # but not yet counted
        iter_done = self.trainer.storage.iter - self.trainer.start_iter + 1
        if iter_done >= self._warmup_iter:
            sec = self._step_timer.seconds()
            self.trainer.storage.put_scalars(time=sec)
        else:
            self._start_time = time.perf_counter()
            self._total_timer.reset()

        self._total_timer.pause()


class PeriodicWriter(HookBase):
    """
    Write events to EventStorage (by calling ``writer.write()``) periodically.

    It is executed every ``period`` iterations and after the last iteration.
    Note that ``period`` does not affect how data is smoothed by each writer.
    """

    def __init__(self, writers, period=20):
        """
        Args:
            writers (list[EventWriter]): a list of EventWriter objects
            period (int):
        """
        self._writers = writers
        for w in writers:
            assert isinstance(w, EventWriter), w
        self._period = period

    def after_step(self):
        if (self.trainer.iter + 1) % self._period == 0 or (
            self.trainer.iter == self.trainer.max_iter - 1
        ):
            for writer in self._writers:
                writer.write()

    def after_train(self):
        for writer in self._writers:
            # If any new data is found (e.g. produced by other after_train),
            # write them before closing
            writer.write()
            writer.close()


class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
    """
    Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook.

    Note that when used as a hook,
    it is unable to save additional data other than what's defined
    by the given `checkpointer`.

    It is executed every ``period`` iterations and after the last iteration.
    """

    def before_train(self):
        self.max_iter = self.trainer.max_iter

    def after_step(self):
        # No way to use **kwargs
        self.step(self.trainer.iter)


class BestCheckpointer(HookBase):
    """
    Checkpoints the best weights based on a given metric.

    This hook should be used in conjunction with and executed after the hook
    that produces the metric, e.g. `EvalHook`.
    """

    def __init__(
        self,
        eval_period: int,
        checkpointer: Checkpointer,
        val_metric: str,
        mode: str = "max",
        file_prefix: str = "model_best",
    ) -> None:
        """
        Args:
            eval_period (int): the period `EvalHook` is set to run.
            checkpointer: the checkpointer object used to save checkpoints.
            val_metric (str): validation metric to track for best checkpoint, e.g. "bbox/AP50"
            mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be
                maximized or minimized, e.g. for "bbox/AP50" it should be "max"
            file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best"
        """
        self._logger = logging.getLogger(__name__)
        self._period = eval_period
        self._val_metric = val_metric
        assert mode in [
            "max",
            "min",
        ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.'
        if mode == "max":
            self._compare = operator.gt
        else:
            self._compare = operator.lt
        self._checkpointer = checkpointer
        self._file_prefix = file_prefix
        self.best_metric = None
        self.best_iter = None

    def _update_best(self, val, iteration):
        if math.isnan(val) or math.isinf(val):
            return False
        self.best_metric = val
        self.best_iter = iteration
        return True

    def _best_checking(self):
        metric_tuple = self.trainer.storage.latest().get(self._val_metric)
        if metric_tuple is None:
            self._logger.warning(
                f"Given val metric {self._val_metric} does not seem to be computed/stored. "
                "Will not be checkpointing based on it."
            )
            return
        else:
            latest_metric, metric_iter = metric_tuple

        if self.best_metric is None:
            if self._update_best(latest_metric, metric_iter):
                additional_state = {"iteration": metric_iter}
                self._checkpointer.save(f"{self._file_prefix}", **additional_state)
                self._logger.info(
                    f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps"
                )
        elif self._compare(latest_metric, self.best_metric):
            additional_state = {"iteration": metric_iter}
            self._checkpointer.save(f"{self._file_prefix}", **additional_state)
            self._logger.info(
                f"Saved best model as latest eval score for {self._val_metric} is "
                f"{latest_metric:0.5f}, better than last best score "
                f"{self.best_metric:0.5f} @ iteration {self.best_iter}."
            )
            self._update_best(latest_metric, metric_iter)
        else:
            self._logger.info(
                f"Not saving as latest eval score for {self._val_metric} is {latest_metric:0.5f}, "
                f"not better than best score {self.best_metric:0.5f} @ iteration {self.best_iter}."
            )

    def after_step(self):
        # same conditions as `EvalHook`
        next_iter = self.trainer.iter + 1
        if (
            self._period > 0
            and next_iter % self._period == 0
            and next_iter != self.trainer.max_iter
        ):
            self._best_checking()

    def after_train(self):
        # same conditions as `EvalHook`
        if self.trainer.iter + 1 >= self.trainer.max_iter:
            self._best_checking()
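

# A minimal sketch (illustrative, not detectron2 API), assuming a trainer that exposes
# `.cfg` and `.checkpointer` (e.g. a DefaultTrainer): register EvalHook before
# BestCheckpointer so the tracked metric ("bbox/AP50" here, a COCO-style key produced
# by the evaluator) already exists in storage when the checkpoint decision runs.
def _example_best_checkpointing(trainer, eval_fn):
    period = trainer.cfg.TEST.EVAL_PERIOD
    trainer.register_hooks(
        [
            EvalHook(period, eval_fn),
            BestCheckpointer(period, trainer.checkpointer, "bbox/AP50", mode="max"),
        ]
    )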


class LRScheduler(HookBase):
    """
    A hook which executes a torch builtin LR scheduler and summarizes the LR.
    It is executed after every iteration.
    """

    def __init__(self, optimizer=None, scheduler=None):
        """
        Args:
            optimizer (torch.optim.Optimizer):
            scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler):
                if a :class:`ParamScheduler` object, it defines the multiplier over the base LR
                in the optimizer.

        If any argument is not given, will try to obtain it from the trainer.
        """
        self._optimizer = optimizer
        self._scheduler = scheduler

    def before_train(self):
        self._optimizer = self._optimizer or self.trainer.optimizer
        if isinstance(self.scheduler, ParamScheduler):
            self._scheduler = LRMultiplier(
                self._optimizer,
                self.scheduler,
                self.trainer.max_iter,
                last_iter=self.trainer.iter - 1,
            )
        self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer)

    @staticmethod
    def get_best_param_group_id(optimizer):
        # NOTE: some heuristics on what LR to summarize
        # summarize the param group with most parameters
        largest_group = max(len(g["params"]) for g in optimizer.param_groups)

        if largest_group == 1:
            # If all groups have one parameter,
            # then find the most common initial LR, and use it for summary
            lr_count = Counter([g["lr"] for g in optimizer.param_groups])
            lr = lr_count.most_common()[0][0]
            for i, g in enumerate(optimizer.param_groups):
                if g["lr"] == lr:
                    return i
        else:
            for i, g in enumerate(optimizer.param_groups):
                if len(g["params"]) == largest_group:
                    return i

    def after_step(self):
        lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
        self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
        self.scheduler.step()

    @property
    def scheduler(self):
        return self._scheduler or self.trainer.scheduler

    def state_dict(self):
        if isinstance(self.scheduler, _LRScheduler):
            return self.scheduler.state_dict()
        return {}

    def load_state_dict(self, state_dict):
        if isinstance(self.scheduler, _LRScheduler):
            logger = logging.getLogger(__name__)
            logger.info("Loading scheduler from state_dict ...")
            self.scheduler.load_state_dict(state_dict)
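

# A minimal sketch (illustrative, not detectron2 API): pair the hook with an fvcore
# ParamScheduler; `before_train` wraps it into an LRMultiplier over the trainer's
# max_iter, so the multiplier here decays the base LR from 1x to 0 over training.
def _example_cosine_lr_hook():
    from fvcore.common.param_scheduler import CosineParamScheduler

    return LRScheduler(scheduler=CosineParamScheduler(1.0, 0.0))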


class TorchProfiler(HookBase):
    """
    A hook which runs `torch.profiler.profile`.

    Examples:
    ::
        hooks.TorchProfiler(
             lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
        )

    The above example will run the profiler for iteration 10~20 and dump
    results to ``OUTPUT_DIR``. We did not profile the first few iterations
    because they are typically slower than the rest.
    The result files can be loaded in the ``chrome://tracing`` page in chrome browser,
    and the tensorboard visualizations can be visualized using
    ``tensorboard --logdir OUTPUT_DIR/log``
    """

    def __init__(self, enable_predicate, output_dir, *, activities=None, save_tensorboard=True):
        """
        Args:
            enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
                and returns whether to enable the profiler.
                It will be called once every step, and can be used to select which steps to profile.
            output_dir (str): the output directory to dump tracing files.
            activities (iterable): same as in `torch.profiler.profile`.
            save_tensorboard (bool): whether to save tensorboard visualizations at (output_dir)/log/
        """
        self._enable_predicate = enable_predicate
        self._activities = activities
        self._output_dir = output_dir
        self._save_tensorboard = save_tensorboard

    def before_step(self):
        if self._enable_predicate(self.trainer):
            if self._save_tensorboard:
                on_trace_ready = torch.profiler.tensorboard_trace_handler(
                    os.path.join(
                        self._output_dir,
                        "log",
                        "profiler-tensorboard-iter{}".format(self.trainer.iter),
                    ),
                    f"worker{comm.get_rank()}",
                )
            else:
                on_trace_ready = None
            self._profiler = torch.profiler.profile(
                activities=self._activities,
                on_trace_ready=on_trace_ready,
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
                with_flops=True,
            )
            self._profiler.__enter__()
        else:
            self._profiler = None

    def after_step(self):
        if self._profiler is None:
            return
        self._profiler.__exit__(None, None, None)
        if not self._save_tensorboard:
            PathManager.mkdirs(self._output_dir)
            out_file = os.path.join(
                self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter)
            )
            if "://" not in out_file:
                self._profiler.export_chrome_trace(out_file)
            else:
                # Support non-posix filesystems
                with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
                    tmp_file = os.path.join(d, "tmp.json")
                    self._profiler.export_chrome_trace(tmp_file)
                    with open(tmp_file) as f:
                        content = f.read()
                with PathManager.open(out_file, "w") as f:
                    f.write(content)


class AutogradProfiler(TorchProfiler):
    """
    A hook which runs `torch.autograd.profiler.profile`.

    Examples:
    ::
        hooks.AutogradProfiler(
             lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
        )

    The above example will run the profiler for iteration 10~20 and dump
    results to ``OUTPUT_DIR``. We did not profile the first few iterations
    because they are typically slower than the rest.
    The result files can be loaded in the ``chrome://tracing`` page in chrome browser.

    Note:
        When used together with NCCL on older version of GPUs,
        autograd profiler may cause deadlock because it unnecessarily allocates
        memory on every device it sees. The memory management calls, if
        interleaved with NCCL calls, lead to deadlock on GPUs that do not
        support ``cudaLaunchCooperativeKernelMultiDevice``.
    """

    def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
        """
        Args:
            enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
                and returns whether to enable the profiler.
                It will be called once every step, and can be used to select which steps to profile.
            output_dir (str): the output directory to dump tracing files.
            use_cuda (bool): same as in `torch.autograd.profiler.profile`.
        """
        warnings.warn("AutogradProfiler has been deprecated in favor of TorchProfiler.")
        self._enable_predicate = enable_predicate
        self._use_cuda = use_cuda
        self._output_dir = output_dir

    def before_step(self):
        if self._enable_predicate(self.trainer):
            self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
            self._profiler.__enter__()
        else:
            self._profiler = None


class EvalHook(HookBase):
    """
    Run an evaluation function periodically, and at the end of training.

    It is executed every ``eval_period`` iterations and after the last iteration.
    """

    def __init__(self, eval_period, eval_function, eval_after_train=True):
        """
        Args:
            eval_period (int): the period to run `eval_function`. Set to 0 to
                not evaluate periodically (but still evaluate after the last iteration
                if `eval_after_train` is True).
            eval_function (callable): a function which takes no arguments, and
                returns a nested dict of evaluation metrics.
            eval_after_train (bool): whether to evaluate after the last iteration

        Note:
            This hook must be enabled in all workers or none.
            If you would like only certain workers to perform evaluation,
            give other workers a no-op function (`eval_function=lambda: None`).
        """
        self._period = eval_period
        self._func = eval_function
        self._eval_after_train = eval_after_train

    def _do_eval(self):
        results = self._func()

        if results:
            assert isinstance(
                results, dict
            ), "Eval function must return a dict. Got {} instead.".format(results)

            flattened_results = flatten_results_dict(results)
            for k, v in flattened_results.items():
                try:
                    v = float(v)
                except Exception as e:
                    raise ValueError(
                        "[EvalHook] eval_function should return a nested dict of float. "
                        "Got '{}: {}' instead.".format(k, v)
                    ) from e
            self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)

        # Evaluation may take different time among workers.
        # A barrier makes them start the next iteration together.
        comm.synchronize()

    def after_step(self):
        next_iter = self.trainer.iter + 1
        if self._period > 0 and next_iter % self._period == 0:
            # do the last eval in after_train
            if next_iter != self.trainer.max_iter:
                self._do_eval()

    def after_train(self):
        # This condition is to prevent the eval from running after a failed training
        if self._eval_after_train and self.trainer.iter + 1 >= self.trainer.max_iter:
            self._do_eval()
        # func is likely a closure that holds reference to the trainer
        # therefore we clean it to avoid circular reference in the end
        del self._func
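

# A minimal sketch (illustrative, not detectron2 API) of a valid `eval_function`: it
# takes no arguments and returns a nested dict of floats, which `_do_eval` flattens
# into keys such as "segm/AP" before logging them to the trainer's EventStorage.
def _example_eval_function():
    return {"segm": {"AP": 0.0, "AP50": 0.0}}  # placeholder numbers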
| 565 |
+
|
| 566 |
+
|
| 567 |
+
class PreciseBN(HookBase):
|
| 568 |
+
"""
|
| 569 |
+
The standard implementation of BatchNorm uses EMA in inference, which is
|
| 570 |
+
sometimes suboptimal.
|
| 571 |
+
This class computes the true average of statistics rather than the moving average,
|
| 572 |
+
and put true averages to every BN layer in the given model.
|
| 573 |
+
|
| 574 |
+
It is executed every ``period`` iterations and after the last iteration.
|
| 575 |
+
"""
|
| 576 |
+
|
| 577 |
+
def __init__(self, period, model, data_loader, num_iter):
|
| 578 |
+
"""
|
| 579 |
+
Args:
|
| 580 |
+
period (int): the period this hook is run, or 0 to not run during training.
|
| 581 |
+
The hook will always run in the end of training.
|
| 582 |
+
model (nn.Module): a module whose all BN layers in training mode will be
|
| 583 |
+
updated by precise BN.
|
| 584 |
+
Note that user is responsible for ensuring the BN layers to be
|
| 585 |
+
updated are in training mode when this hook is triggered.
|
| 586 |
+
data_loader (iterable): it will produce data to be run by `model(data)`.
|
| 587 |
+
num_iter (int): number of iterations used to compute the precise
|
| 588 |
+
statistics.
|
| 589 |
+
"""
|
| 590 |
+
self._logger = logging.getLogger(__name__)
|
| 591 |
+
if len(get_bn_modules(model)) == 0:
|
| 592 |
+
self._logger.info(
|
| 593 |
+
"PreciseBN is disabled because model does not contain BN layers in training mode."
|
| 594 |
+
)
|
| 595 |
+
self._disabled = True
|
| 596 |
+
return
|
| 597 |
+
|
| 598 |
+
self._model = model
|
| 599 |
+
self._data_loader = data_loader
|
| 600 |
+
self._num_iter = num_iter
|
| 601 |
+
self._period = period
|
| 602 |
+
self._disabled = False
|
| 603 |
+
|
| 604 |
+
self._data_iter = None
|
| 605 |
+
|
| 606 |
+
def after_step(self):
|
| 607 |
+
next_iter = self.trainer.iter + 1
|
| 608 |
+
is_final = next_iter == self.trainer.max_iter
|
| 609 |
+
if is_final or (self._period > 0 and next_iter % self._period == 0):
|
| 610 |
+
self.update_stats()
|
| 611 |
+
|
| 612 |
+
def update_stats(self):
|
| 613 |
+
"""
|
| 614 |
+
Update the model with precise statistics. Users can manually call this method.
|
| 615 |
+
"""
|
| 616 |
+
if self._disabled:
|
| 617 |
+
return
|
| 618 |
+
|
| 619 |
+
if self._data_iter is None:
|
| 620 |
+
self._data_iter = iter(self._data_loader)
|
| 621 |
+
|
| 622 |
+
def data_loader():
|
| 623 |
+
for num_iter in itertools.count(1):
|
| 624 |
+
if num_iter % 100 == 0:
|
| 625 |
+
self._logger.info(
|
| 626 |
+
"Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
|
| 627 |
+
)
|
| 628 |
+
# This way we can reuse the same iterator
|
| 629 |
+
yield next(self._data_iter)
|
| 630 |
+
|
| 631 |
+
with EventStorage(): # capture events in a new storage to discard them
|
| 632 |
+
self._logger.info(
|
| 633 |
+
"Running precise-BN for {} iterations... ".format(self._num_iter)
|
| 634 |
+
+ "Note that this could produce different statistics every time."
|
| 635 |
+
)
|
| 636 |
+
update_bn_stats(self._model, data_loader(), self._num_iter)
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
class TorchMemoryStats(HookBase):
|
| 640 |
+
"""
|
| 641 |
+
Writes pytorch's cuda memory statistics periodically.
|
| 642 |
+
"""
|
| 643 |
+
|
| 644 |
+
def __init__(self, period=20, max_runs=10):
|
| 645 |
+
"""
|
| 646 |
+
Args:
|
| 647 |
+
period (int): Output stats each 'period' iterations
|
| 648 |
+
max_runs (int): Stop the logging after 'max_runs'
|
| 649 |
+
"""
|
| 650 |
+
|
| 651 |
+
self._logger = logging.getLogger(__name__)
|
| 652 |
+
self._period = period
|
| 653 |
+
self._max_runs = max_runs
|
| 654 |
+
self._runs = 0
|
| 655 |
+
|
| 656 |
+
def after_step(self):
|
| 657 |
+
if self._runs > self._max_runs:
|
| 658 |
+
return
|
| 659 |
+
|
| 660 |
+
if (self.trainer.iter + 1) % self._period == 0 or (
|
| 661 |
+
self.trainer.iter == self.trainer.max_iter - 1
|
| 662 |
+
):
|
| 663 |
+
if torch.cuda.is_available():
|
| 664 |
+
max_reserved_mb = torch.cuda.max_memory_reserved() / 1024.0 / 1024.0
|
| 665 |
+
reserved_mb = torch.cuda.memory_reserved() / 1024.0 / 1024.0
|
| 666 |
+
max_allocated_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
|
| 667 |
+
allocated_mb = torch.cuda.memory_allocated() / 1024.0 / 1024.0
|
| 668 |
+
|
| 669 |
+
self._logger.info(
|
| 670 |
+
(
|
| 671 |
+
" iter: {} "
|
| 672 |
+
" max_reserved_mem: {:.0f}MB "
|
| 673 |
+
" reserved_mem: {:.0f}MB "
|
| 674 |
+
" max_allocated_mem: {:.0f}MB "
|
| 675 |
+
" allocated_mem: {:.0f}MB "
|
| 676 |
+
).format(
|
| 677 |
+
self.trainer.iter,
|
| 678 |
+
max_reserved_mb,
|
| 679 |
+
reserved_mb,
|
| 680 |
+
max_allocated_mb,
|
| 681 |
+
allocated_mb,
|
| 682 |
+
)
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
self._runs += 1
|
| 686 |
+
if self._runs == self._max_runs:
|
| 687 |
+
mem_summary = torch.cuda.memory_summary()
|
| 688 |
+
self._logger.info("\n" + mem_summary)
|
| 689 |
+
|
| 690 |
+
torch.cuda.reset_peak_memory_stats()
|
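
For context, a minimal self-contained sketch (not part of the diff) of registering these hooks with a trainer; `TinyModel` and `loader` are illustrative stand-ins for project-specific code:

# Sketch: wire EvalHook, PreciseBN, and TorchMemoryStats into a SimpleTrainer.
# TinyModel and loader are illustrative placeholders, not part of this diff.
import torch
from torch import nn

from detectron2.engine import SimpleTrainer, hooks


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(4)  # gives PreciseBN something to update
        self.linear = nn.Linear(4, 1)

    def forward(self, batch):
        x, y = batch
        return {"loss": nn.functional.mse_loss(self.linear(self.bn(x)), y)}


def loader():
    # SimpleTrainer pulls batches with next(); keep the iterator infinite.
    while True:
        x = torch.randn(8, 4)
        yield x, x.mean(dim=1, keepdim=True)


model = TinyModel()
trainer = SimpleTrainer(model, loader(), torch.optim.SGD(model.parameters(), lr=0.01))
trainer.register_hooks([
    # eval functions must return a (possibly nested) dict of floats, as _do_eval enforces
    hooks.EvalHook(50, lambda: {"toy/metric": 0.0}),
    # recompute true BN statistics over 20 batches; a no-op if the model has no BN layers
    hooks.PreciseBN(50, model, loader(), num_iter=20),
    # log CUDA memory for the first 10 logging periods (skipped without CUDA)
    hooks.TorchMemoryStats(period=20, max_runs=10),
])
trainer.train(0, 100)
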
CatVTON/detectron2/engine/launch.py
ADDED
@@ -0,0 +1,123 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from datetime import timedelta
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from detectron2.utils import comm

__all__ = ["DEFAULT_TIMEOUT", "launch"]

DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def launch(
    main_func,
    # Should be num_processes_per_machine, but kept for compatibility.
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    dist_url=None,
    args=(),
    timeout=DEFAULT_TIMEOUT,
):
    """
    Launch multi-process or distributed training.
    This function must be called on all machines involved in the training.
    It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.

    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): number of processes per machine. When
            using GPUs, this should be the number of GPUs.
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs, including protocol
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost
        timeout (timedelta): timeout of the distributed workers
        args (tuple): arguments passed to main_func
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        # https://github.com/pytorch/pytorch/pull/14391
        # TODO prctl in spawned processes

        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        if num_machines > 1 and dist_url.startswith("file://"):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )

        mp.start_processes(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                dist_url,
                args,
                timeout,
            ),
            daemon=False,
        )
    else:
        main_func(*args)


def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    args,
    timeout=DEFAULT_TIMEOUT,
):
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        assert num_gpus_per_machine <= torch.cuda.device_count()
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL" if has_gpu else "GLOO",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise e

    # Setup the local process group.
    comm.create_local_process_group(num_gpus_per_machine)
    if has_gpu:
        torch.cuda.set_device(local_rank)

    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    main_func(*args)
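
For context, a minimal sketch (not part of the diff) of driving `launch` from a training script; `main` and its `lr` argument are illustrative:

# Sketch: spawn one worker process per GPU on a single machine.
from detectron2.engine import launch


def main(lr):
    # Each spawned worker lands here after init_process_group has run,
    # with the local process group and device already set up.
    print("worker starting with lr =", lr)


if __name__ == "__main__":
    launch(
        main,
        num_gpus_per_machine=2,
        num_machines=1,
        machine_rank=0,
        dist_url="auto",  # pick a free local port; single-machine only
        args=(0.02,),
    )

With `world_size == 1`, `launch` simply calls `main_func(*args)` in-process, so the same script also works for single-process debugging.
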
CatVTON/detectron2/engine/train_loop.py
ADDED
@@ -0,0 +1,530 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import concurrent.futures
import logging
import numpy as np
import time
import weakref
from typing import List, Mapping, Optional
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel

import detectron2.utils.comm as comm
from detectron2.utils.events import EventStorage, get_event_storage
from detectron2.utils.logger import _log_api_usage

__all__ = ["HookBase", "TrainerBase", "SimpleTrainer", "AMPTrainer"]


class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.

    Each hook can implement 4 methods. The way they are called is demonstrated
    in the following snippet:
    ::
        hook.before_train()
        for iter in range(start_iter, max_iter):
            hook.before_step()
            trainer.run_step()
            hook.after_step()
        iter += 1
        hook.after_train()

    Notes:
        1. In the hook method, users can access ``self.trainer`` to access more
           properties about the context (e.g., model, current iteration, or config
           if using :class:`DefaultTrainer`).

        2. A hook that does something in :meth:`before_step` can often be
           implemented equivalently in :meth:`after_step`.
           If the hook takes non-trivial time, it is strongly recommended to
           implement the hook in :meth:`after_step` instead of :meth:`before_step`.
           The convention is that :meth:`before_step` should only take negligible time.

           Following this convention will allow hooks that do care about the difference
           between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
           function properly.

    """

    trainer: "TrainerBase" = None
    """
    A weak reference to the trainer object. Set by the trainer when the hook is registered.
    """

    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_backward(self):
        """
        Called after the backward pass of each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass

    def state_dict(self):
        """
        Hooks are stateless by default, but can be made checkpointable by
        implementing `state_dict` and `load_state_dict`.
        """
        return {}


class TrainerBase:
    """
    Base class for iterative trainer with hooks.

    The only assumption we made here is: the training runs in a loop.
    A subclass can implement what the loop is.
    We made no assumptions about the existence of dataloader, optimizer, model, etc.

    Attributes:
        iter(int): the current iteration.

        start_iter(int): The iteration to start with.
            By convention the minimum possible value is 0.

        max_iter(int): The iteration to end training.

        storage(EventStorage): An EventStorage that's opened during the course of training.
    """

    def __init__(self) -> None:
        self._hooks: List[HookBase] = []
        self.iter: int = 0
        self.start_iter: int = 0
        self.max_iter: int
        self.storage: EventStorage
        _log_api_usage("trainer." + self.__class__.__name__)

    def register_hooks(self, hooks: List[Optional[HookBase]]) -> None:
        """
        Register hooks to the trainer. The hooks are executed in the order
        they are registered.

        Args:
            hooks (list[Optional[HookBase]]): list of hooks
        """
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            # To avoid circular reference, hooks and trainer cannot own each other.
            # This normally does not matter, but will cause memory leak if the
            # involved objects contain __del__:
            # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
            h.trainer = weakref.proxy(self)
        self._hooks.extend(hooks)

    def train(self, start_iter: int, max_iter: int):
        """
        Args:
            start_iter, max_iter (int): See docs above
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
                # self.iter == max_iter can be used by `after_train` to
                # tell whether the training successfully finished or failed
                # due to exceptions.
                self.iter += 1
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        self.storage.iter = self.iter
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        # Maintain the invariant that storage.iter == trainer.iter
        # for the entire execution of each step
        self.storage.iter = self.iter

        for h in self._hooks:
            h.before_step()

    def after_backward(self):
        for h in self._hooks:
            h.after_backward()

    def after_step(self):
        for h in self._hooks:
            h.after_step()

    def run_step(self):
        raise NotImplementedError

    def state_dict(self):
        ret = {"iteration": self.iter}
        hooks_state = {}
        for h in self._hooks:
            sd = h.state_dict()
            if sd:
                name = type(h).__qualname__
                if name in hooks_state:
                    # TODO handle repetitive stateful hooks
                    continue
                hooks_state[name] = sd
        if hooks_state:
            ret["hooks"] = hooks_state
        return ret

    def load_state_dict(self, state_dict):
        logger = logging.getLogger(__name__)
        self.iter = state_dict["iteration"]
        for key, value in state_dict.get("hooks", {}).items():
            for h in self._hooks:
                try:
                    name = type(h).__qualname__
                except AttributeError:
                    continue
                if name == key:
                    h.load_state_dict(value)
                    break
            else:
                logger.warning(f"Cannot find the hook '{key}', its state_dict is ignored.")


class SimpleTrainer(TrainerBase):
    """
    A simple trainer for the most common type of task:
    single-cost single-optimizer single-data-source iterative optimization,
    optionally using data-parallelism.
    It assumes that every step, you:

    1. Compute the loss with a data from the data_loader.
    2. Compute the gradients with the above loss.
    3. Update the model with the optimizer.

    All other tasks during training (checkpointing, logging, evaluation, LR schedule)
    are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`.

    If you want to do anything fancier than this,
    either subclass TrainerBase and implement your own `run_step`,
    or write your own training loop.
    """

    def __init__(
        self,
        model,
        data_loader,
        optimizer,
        gather_metric_period=1,
        zero_grad_before_forward=False,
        async_write_metrics=False,
    ):
        """
        Args:
            model: a torch Module. Takes a data from data_loader and returns a
                dict of losses.
            data_loader: an iterable. Contains data to be used to call model.
            optimizer: a torch optimizer.
            gather_metric_period: an int. Every gather_metric_period iterations
                the metrics are gathered from all the ranks to rank 0 and logged.
            zero_grad_before_forward: whether to zero the gradients before the forward.
            async_write_metrics: bool. If True, then write metrics asynchronously to improve
                training speed
        """
        super().__init__()

        """
        We set the model to training mode in the trainer.
        However it's valid to train a model that's in eval mode.
        If you want your model (or a submodule of it) to behave
        like evaluation during training, you can overwrite its train() method.
        """
        model.train()

        self.model = model
        self.data_loader = data_loader
        # to access the data loader iterator, call `self._data_loader_iter`
        self._data_loader_iter_obj = None
        self.optimizer = optimizer
        self.gather_metric_period = gather_metric_period
        self.zero_grad_before_forward = zero_grad_before_forward
        self.async_write_metrics = async_write_metrics
        # create a thread pool that can execute non critical logic in run_step asynchronously
        # use only 1 worker so tasks will be executed in order of submitting.
        self.concurrent_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def run_step(self):
        """
        Implement the standard training logic described above.
        """
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
        start = time.perf_counter()
        """
        If you want to do something with the data, you can wrap the dataloader.
        """
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        if self.zero_grad_before_forward:
            """
            If you need to accumulate gradients or do something similar, you can
            wrap the optimizer with your custom `zero_grad()` method.
            """
            self.optimizer.zero_grad()

        """
        If you want to do something with the losses, you can wrap the model.
        """
        loss_dict = self.model(data)
        if isinstance(loss_dict, torch.Tensor):
            losses = loss_dict
            loss_dict = {"total_loss": loss_dict}
        else:
            losses = sum(loss_dict.values())
        if not self.zero_grad_before_forward:
            """
            If you need to accumulate gradients or do something similar, you can
            wrap the optimizer with your custom `zero_grad()` method.
            """
            self.optimizer.zero_grad()
        losses.backward()

        self.after_backward()

        if self.async_write_metrics:
            # write metrics asynchronously
            self.concurrent_executor.submit(
                self._write_metrics, loss_dict, data_time, iter=self.iter
            )
        else:
            self._write_metrics(loss_dict, data_time)

        """
        If you need gradient clipping/scaling or other processing, you can
        wrap the optimizer with your custom `step()` method. But it is
        suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4
        """
        self.optimizer.step()

    @property
    def _data_loader_iter(self):
        # only create the data loader iterator when it is used
        if self._data_loader_iter_obj is None:
            self._data_loader_iter_obj = iter(self.data_loader)
        return self._data_loader_iter_obj

    def reset_data_loader(self, data_loader_builder):
        """
        Delete and replace the current data loader with a new one, which will be created
        by calling `data_loader_builder` (without argument).
        """
        del self.data_loader
        data_loader = data_loader_builder()
        self.data_loader = data_loader
        self._data_loader_iter_obj = None

    def _write_metrics(
        self,
        loss_dict: Mapping[str, torch.Tensor],
        data_time: float,
        prefix: str = "",
        iter: Optional[int] = None,
    ) -> None:
        logger = logging.getLogger(__name__)

        iter = self.iter if iter is None else iter
        if (iter + 1) % self.gather_metric_period == 0:
            try:
                SimpleTrainer.write_metrics(loss_dict, data_time, iter, prefix)
            except Exception:
                logger.exception("Exception in writing metrics: ")
                raise

    @staticmethod
    def write_metrics(
        loss_dict: Mapping[str, torch.Tensor],
        data_time: float,
        cur_iter: int,
        prefix: str = "",
    ) -> None:
        """
        Args:
            loss_dict (dict): dict of scalar losses
            data_time (float): time taken by the dataloader iteration
            prefix (str): prefix for logging keys
        """
        metrics_dict = {k: v.detach().cpu().item() for k, v in loss_dict.items()}
        metrics_dict["data_time"] = data_time

        storage = get_event_storage()
        # Keep track of data time per rank
        storage.put_scalar("rank_data_time", data_time, cur_iter=cur_iter)

        # Gather metrics among all workers for logging
        # This assumes we do DDP-style training, which is currently the only
        # supported method in detectron2.
        all_metrics_dict = comm.gather(metrics_dict)

        if comm.is_main_process():
            # data_time among workers can have high variance. The actual latency
            # caused by data_time is the maximum among workers.
            data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
            storage.put_scalar("data_time", data_time, cur_iter=cur_iter)

            # average the rest metrics
            metrics_dict = {
                k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
            }
            total_losses_reduced = sum(metrics_dict.values())
            if not np.isfinite(total_losses_reduced):
                raise FloatingPointError(
                    f"Loss became infinite or NaN at iteration={cur_iter}!\n"
                    f"loss_dict = {metrics_dict}"
                )

            storage.put_scalar(
                "{}total_loss".format(prefix), total_losses_reduced, cur_iter=cur_iter
            )
            if len(metrics_dict) > 1:
                storage.put_scalars(cur_iter=cur_iter, **metrics_dict)

    def state_dict(self):
        ret = super().state_dict()
        ret["optimizer"] = self.optimizer.state_dict()
        return ret

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.optimizer.load_state_dict(state_dict["optimizer"])

    def after_train(self):
        super().after_train()
        self.concurrent_executor.shutdown(wait=True)


class AMPTrainer(SimpleTrainer):
    """
    Like :class:`SimpleTrainer`, but uses PyTorch's native automatic mixed precision
    in the training loop.
    """

    def __init__(
        self,
        model,
        data_loader,
        optimizer,
        gather_metric_period=1,
        zero_grad_before_forward=False,
        grad_scaler=None,
        precision: torch.dtype = torch.float16,
        log_grad_scaler: bool = False,
        async_write_metrics=False,
    ):
        """
        Args:
            model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward,
                async_write_metrics: same as in :class:`SimpleTrainer`.
            grad_scaler: torch GradScaler to automatically scale gradients.
            precision: torch.dtype as the target precision to cast to in computations
        """
        unsupported = "AMPTrainer does not support single-process multi-device training!"
        if isinstance(model, DistributedDataParallel):
            assert not (model.device_ids and len(model.device_ids) > 1), unsupported
        assert not isinstance(model, DataParallel), unsupported

        super().__init__(
            model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward
        )

        if grad_scaler is None:
            from torch.cuda.amp import GradScaler

            grad_scaler = GradScaler()
        self.grad_scaler = grad_scaler
        self.precision = precision
        self.log_grad_scaler = log_grad_scaler

    def run_step(self):
        """
        Implement the AMP training logic.
        """
        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        start = time.perf_counter()
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        if self.zero_grad_before_forward:
            self.optimizer.zero_grad()
        with autocast(dtype=self.precision):
            loss_dict = self.model(data)
            if isinstance(loss_dict, torch.Tensor):
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())

        if not self.zero_grad_before_forward:
            self.optimizer.zero_grad()

        self.grad_scaler.scale(losses).backward()

        if self.log_grad_scaler:
            storage = get_event_storage()
            storage.put_scalar("[metric]grad_scaler", self.grad_scaler.get_scale())

        self.after_backward()

        if self.async_write_metrics:
            # write metrics asynchronously
            self.concurrent_executor.submit(
                self._write_metrics, loss_dict, data_time, iter=self.iter
            )
        else:
            self._write_metrics(loss_dict, data_time)

        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()

    def state_dict(self):
        ret = super().state_dict()
        ret["grad_scaler"] = self.grad_scaler.state_dict()
        return ret

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
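
For context, a minimal sketch (not part of the diff) of the contract `run_step` assumes: the model maps a batch to a dict of scalar losses and the data loader is an effectively infinite iterable. `ToyModel` and `toy_loader` are illustrative:

# Toy end-to-end run of SimpleTrainer; the model returns a dict of losses,
# which is what run_step above expects.
import torch
from torch import nn

from detectron2.engine import SimpleTrainer


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, batch):
        x, y = batch
        return {"loss_mse": nn.functional.mse_loss(self.linear(x), y)}


def toy_loader():
    # SimpleTrainer pulls batches with next(); keep the iterator infinite.
    while True:
        x = torch.randn(8, 4)
        yield x, x.sum(dim=1, keepdim=True)


model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
SimpleTrainer(model, toy_loader(), optimizer).train(start_iter=0, max_iter=100)

Swapping `SimpleTrainer` for `AMPTrainer` with the same arguments enables mixed precision, provided CUDA is available.
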
CatVTON/detectron2/modeling/__init__.py
ADDED
@@ -0,0 +1,64 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from detectron2.layers import ShapeSpec

from .anchor_generator import build_anchor_generator, ANCHOR_GENERATOR_REGISTRY
from .backbone import (
    BACKBONE_REGISTRY,
    FPN,
    Backbone,
    ResNet,
    ResNetBlockBase,
    build_backbone,
    build_resnet_backbone,
    make_stage,
    ViT,
    SimpleFeaturePyramid,
    get_vit_lr_decay_rate,
    MViT,
    SwinTransformer,
)
from .meta_arch import (
    META_ARCH_REGISTRY,
    SEM_SEG_HEADS_REGISTRY,
    GeneralizedRCNN,
    PanopticFPN,
    ProposalNetwork,
    RetinaNet,
    SemanticSegmentor,
    build_model,
    build_sem_seg_head,
    FCOS,
)
from .postprocessing import detector_postprocess
from .proposal_generator import (
    PROPOSAL_GENERATOR_REGISTRY,
    build_proposal_generator,
    RPN_HEAD_REGISTRY,
    build_rpn_head,
)
from .roi_heads import (
    ROI_BOX_HEAD_REGISTRY,
    ROI_HEADS_REGISTRY,
    ROI_KEYPOINT_HEAD_REGISTRY,
    ROI_MASK_HEAD_REGISTRY,
    ROIHeads,
    StandardROIHeads,
    BaseMaskRCNNHead,
    BaseKeypointRCNNHead,
    FastRCNNOutputLayers,
    build_box_head,
    build_keypoint_head,
    build_mask_head,
    build_roi_heads,
)
from .test_time_augmentation import DatasetMapperTTA, GeneralizedRCNNWithTTA
from .mmdet_wrapper import MMDetBackbone, MMDetDetector

_EXCLUDE = {"ShapeSpec"}
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]


from detectron2.utils.env import fixup_module_metadata

fixup_module_metadata(__name__, globals(), __all__)
del fixup_module_metadata
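
For context, a hedged sketch (not part of the diff) of how the names re-exported above are typically consumed; the config file path is an assumption about the repository layout:

# Sketch: build the meta-architecture named by the config.
from detectron2.config import get_cfg
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.merge_from_file("configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")  # assumed path
model = build_model(cfg)  # dispatches through META_ARCH_REGISTRY via cfg.MODEL.META_ARCHITECTURE
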
CatVTON/detectron2/modeling/anchor_generator.py
ADDED
@@ -0,0 +1,390 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import collections
import math
from typing import List
import torch
from torch import nn

from detectron2.config import configurable
from detectron2.layers import ShapeSpec, move_device_like
from detectron2.structures import Boxes, RotatedBoxes
from detectron2.utils.registry import Registry

ANCHOR_GENERATOR_REGISTRY = Registry("ANCHOR_GENERATOR")
ANCHOR_GENERATOR_REGISTRY.__doc__ = """
Registry for modules that create object detection anchors for feature maps.

The registered object will be called with `obj(cfg, input_shape)`.
"""


class BufferList(nn.Module):
    """
    Similar to nn.ParameterList, but for buffers
    """

    def __init__(self, buffers):
        super().__init__()
        for i, buffer in enumerate(buffers):
            # Use non-persistent buffer so the values are not saved in checkpoint
            self.register_buffer(str(i), buffer, persistent=False)

    def __len__(self):
        return len(self._buffers)

    def __iter__(self):
        return iter(self._buffers.values())


def _create_grid_offsets(
    size: List[int], stride: int, offset: float, target_device_tensor: torch.Tensor
):
    grid_height, grid_width = size
    shifts_x = move_device_like(
        torch.arange(offset * stride, grid_width * stride, step=stride, dtype=torch.float32),
        target_device_tensor,
    )
    shifts_y = move_device_like(
        torch.arange(offset * stride, grid_height * stride, step=stride, dtype=torch.float32),
        target_device_tensor,
    )

    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
    shift_x = shift_x.reshape(-1)
    shift_y = shift_y.reshape(-1)
    return shift_x, shift_y


def _broadcast_params(params, num_features, name):
    """
    If one size (or aspect ratio) is specified and there are multiple feature
    maps, we "broadcast" anchors of that single size (or aspect ratio)
    over all feature maps.

    If params is list[float], or list[list[float]] with len(params) == 1, repeat
    it num_features times.

    Returns:
        list[list[float]]: param for each feature
    """
    assert isinstance(
        params, collections.abc.Sequence
    ), f"{name} in anchor generator has to be a list! Got {params}."
    assert len(params), f"{name} in anchor generator cannot be empty!"
    if not isinstance(params[0], collections.abc.Sequence):  # params is list[float]
        return [params] * num_features
    if len(params) == 1:
        return list(params) * num_features
    assert len(params) == num_features, (
        f"Got {name} of length {len(params)} in anchor generator, "
        f"but the number of input features is {num_features}!"
    )
    return params


@ANCHOR_GENERATOR_REGISTRY.register()
class DefaultAnchorGenerator(nn.Module):
    """
    Compute anchors in the standard ways described in
    "Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks".
    """

    box_dim: torch.jit.Final[int] = 4
    """
    the dimension of each anchor box.
    """

    @configurable
    def __init__(self, *, sizes, aspect_ratios, strides, offset=0.5):
        """
        This interface is experimental.

        Args:
            sizes (list[list[float]] or list[float]):
                If ``sizes`` is list[list[float]], ``sizes[i]`` is the list of anchor sizes
                (i.e. sqrt of anchor area) to use for the i-th feature map.
                If ``sizes`` is list[float], ``sizes`` is used for all feature maps.
                Anchor sizes are given in absolute lengths in units of
                the input image; they do not dynamically scale if the input image size changes.
            aspect_ratios (list[list[float]] or list[float]): list of aspect ratios
                (i.e. height / width) to use for anchors. Same "broadcast" rule for `sizes` applies.
            strides (list[int]): stride of each input feature.
            offset (float): Relative offset between the center of the first anchor and the top-left
                corner of the image. Value has to be in [0, 1).
                Recommend to use 0.5, which means half stride.
        """
        super().__init__()

        self.strides = strides
        self.num_features = len(self.strides)
        sizes = _broadcast_params(sizes, self.num_features, "sizes")
        aspect_ratios = _broadcast_params(aspect_ratios, self.num_features, "aspect_ratios")
        self.cell_anchors = self._calculate_anchors(sizes, aspect_ratios)

        self.offset = offset
        assert 0.0 <= self.offset < 1.0, self.offset

    @classmethod
    def from_config(cls, cfg, input_shape: List[ShapeSpec]):
        return {
            "sizes": cfg.MODEL.ANCHOR_GENERATOR.SIZES,
            "aspect_ratios": cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS,
            "strides": [x.stride for x in input_shape],
            "offset": cfg.MODEL.ANCHOR_GENERATOR.OFFSET,
        }

    def _calculate_anchors(self, sizes, aspect_ratios):
        cell_anchors = [
            self.generate_cell_anchors(s, a).float() for s, a in zip(sizes, aspect_ratios)
        ]
        return BufferList(cell_anchors)

    @property
    @torch.jit.unused
    def num_cell_anchors(self):
        """
        Alias of `num_anchors`.
        """
        return self.num_anchors

    @property
    @torch.jit.unused
    def num_anchors(self):
        """
        Returns:
            list[int]: Each int is the number of anchors at every pixel
                location, on that feature map.
                For example, if at every pixel we use anchors of 3 aspect
                ratios and 5 sizes, the number of anchors is 15.
                (See also ANCHOR_GENERATOR.SIZES and ANCHOR_GENERATOR.ASPECT_RATIOS in config)

                In standard RPN models, `num_anchors` on every feature map is the same.
        """
        return [len(cell_anchors) for cell_anchors in self.cell_anchors]

    def _grid_anchors(self, grid_sizes: List[List[int]]):
        """
        Returns:
            list[Tensor]: #featuremap tensors, each is (#locations x #cell_anchors) x 4
        """
        anchors = []
        # buffers() not supported by torchscript. use named_buffers() instead
        buffers: List[torch.Tensor] = [x[1] for x in self.cell_anchors.named_buffers()]
        for size, stride, base_anchors in zip(grid_sizes, self.strides, buffers):
            shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))

        return anchors

    def generate_cell_anchors(self, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)):
        """
        Generate a tensor storing canonical anchor boxes, which are all anchor
        boxes of different sizes and aspect_ratios centered at (0, 0).
        We can later build the set of anchors for a full feature map by
        shifting and tiling these tensors (see `meth:_grid_anchors`).

        Args:
            sizes (tuple[float]):
            aspect_ratios (tuple[float]]):

        Returns:
            Tensor of shape (len(sizes) * len(aspect_ratios), 4) storing anchor boxes
                in XYXY format.
        """

        # This is different from the anchor generator defined in the original Faster R-CNN
        # code or Detectron. They yield the same AP, however the old version defines cell
        # anchors in a less natural way with a shift relative to the feature grid and
        # quantization that results in slightly different sizes for different aspect ratios.
        # See also https://github.com/facebookresearch/Detectron/issues/227

        anchors = []
        for size in sizes:
            area = size**2.0
            for aspect_ratio in aspect_ratios:
                # s * s = w * h
                # a = h / w
                # ... some algebra ...
                # w = sqrt(s * s / a)
                # h = a * w
                w = math.sqrt(area / aspect_ratio)
                h = aspect_ratio * w
                x0, y0, x1, y1 = -w / 2.0, -h / 2.0, w / 2.0, h / 2.0
                anchors.append([x0, y0, x1, y1])
        return torch.tensor(anchors)

    def forward(self, features: List[torch.Tensor]):
        """
        Args:
            features (list[Tensor]): list of backbone feature maps on which to generate anchors.

        Returns:
            list[Boxes]: a list of Boxes containing all the anchors for each feature map
                (i.e. the cell anchors repeated over all locations in the feature map).
                The number of anchors of each feature map is Hi x Wi x num_cell_anchors,
                where Hi, Wi are resolution of the feature map divided by anchor stride.
        """
        grid_sizes = [feature_map.shape[-2:] for feature_map in features]
        anchors_over_all_feature_maps = self._grid_anchors(grid_sizes)  # pyre-ignore
        return [Boxes(x) for x in anchors_over_all_feature_maps]


@ANCHOR_GENERATOR_REGISTRY.register()
class RotatedAnchorGenerator(nn.Module):
    """
    Compute rotated anchors used by Rotated RPN (RRPN), described in
    "Arbitrary-Oriented Scene Text Detection via Rotation Proposals".
    """

    box_dim: int = 5
    """
    the dimension of each anchor box.
    """

    @configurable
    def __init__(self, *, sizes, aspect_ratios, strides, angles, offset=0.5):
        """
        This interface is experimental.

        Args:
            sizes (list[list[float]] or list[float]):
                If sizes is list[list[float]], sizes[i] is the list of anchor sizes
                (i.e. sqrt of anchor area) to use for the i-th feature map.
                If sizes is list[float], the sizes are used for all feature maps.
                Anchor sizes are given in absolute lengths in units of
                the input image; they do not dynamically scale if the input image size changes.
            aspect_ratios (list[list[float]] or list[float]): list of aspect ratios
                (i.e. height / width) to use for anchors. Same "broadcast" rule for `sizes` applies.
            strides (list[int]): stride of each input feature.
            angles (list[list[float]] or list[float]): list of angles (in degrees CCW)
                to use for anchors. Same "broadcast" rule for `sizes` applies.
            offset (float): Relative offset between the center of the first anchor and the top-left
                corner of the image. Value has to be in [0, 1).
                Recommend to use 0.5, which means half stride.
        """
        super().__init__()

        self.strides = strides
        self.num_features = len(self.strides)
        sizes = _broadcast_params(sizes, self.num_features, "sizes")
        aspect_ratios = _broadcast_params(aspect_ratios, self.num_features, "aspect_ratios")
        angles = _broadcast_params(angles, self.num_features, "angles")
        self.cell_anchors = self._calculate_anchors(sizes, aspect_ratios, angles)

        self.offset = offset
        assert 0.0 <= self.offset < 1.0, self.offset

    @classmethod
    def from_config(cls, cfg, input_shape: List[ShapeSpec]):
        return {
            "sizes": cfg.MODEL.ANCHOR_GENERATOR.SIZES,
            "aspect_ratios": cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS,
            "strides": [x.stride for x in input_shape],
            "offset": cfg.MODEL.ANCHOR_GENERATOR.OFFSET,
            "angles": cfg.MODEL.ANCHOR_GENERATOR.ANGLES,
        }

    def _calculate_anchors(self, sizes, aspect_ratios, angles):
        cell_anchors = [
            self.generate_cell_anchors(size, aspect_ratio, angle).float()
            for size, aspect_ratio, angle in zip(sizes, aspect_ratios, angles)
        ]
        return BufferList(cell_anchors)

    @property
    def num_cell_anchors(self):
        """
        Alias of `num_anchors`.
        """
        return self.num_anchors

    @property
    def num_anchors(self):
        """
        Returns:
            list[int]: Each int is the number of anchors at every pixel
                location, on that feature map.
                For example, if at every pixel we use anchors of 3 aspect
                ratios, 2 sizes and 5 angles, the number of anchors is 30.
                (See also ANCHOR_GENERATOR.SIZES, ANCHOR_GENERATOR.ASPECT_RATIOS
                and ANCHOR_GENERATOR.ANGLES in config)

                In standard RRPN models, `num_anchors` on every feature map is the same.
        """
        return [len(cell_anchors) for cell_anchors in self.cell_anchors]

    def _grid_anchors(self, grid_sizes: List[List[int]]):
        anchors = []
        for size, stride, base_anchors in zip(
            grid_sizes,
            self.strides,
            self.cell_anchors._buffers.values(),
        ):
            shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors)
            zeros = torch.zeros_like(shift_x)
            shifts = torch.stack((shift_x, shift_y, zeros, zeros, zeros), dim=1)

            anchors.append((shifts.view(-1, 1, 5) + base_anchors.view(1, -1, 5)).reshape(-1, 5))

        return anchors

    def generate_cell_anchors(
        self,
        sizes=(32, 64, 128, 256, 512),
        aspect_ratios=(0.5, 1, 2),
        angles=(-90, -60, -30, 0, 30, 60, 90),
    ):
        """
        Generate a tensor storing canonical anchor boxes, which are all anchor
        boxes of different sizes, aspect_ratios, angles centered at (0, 0).
        We can later build the set of anchors for a full feature map by
        shifting and tiling these tensors (see `meth:_grid_anchors`).

        Args:
            sizes (tuple[float]):
            aspect_ratios (tuple[float]]):
            angles (tuple[float]]):

        Returns:
            Tensor of shape (len(sizes) * len(aspect_ratios) * len(angles), 5)
                storing anchor boxes in (x_ctr, y_ctr, w, h, angle) format.
        """
        anchors = []
        for size in sizes:
            area = size**2.0
            for aspect_ratio in aspect_ratios:
                # s * s = w * h
                # a = h / w
                # ... some algebra ...
                # w = sqrt(s * s / a)
                # h = a * w
                w = math.sqrt(area / aspect_ratio)
                h = aspect_ratio * w
                anchors.extend([0, 0, w, h, a] for a in angles)

        return torch.tensor(anchors)

    def forward(self, features):
        """
        Args:
            features (list[Tensor]): list of backbone feature maps on which to generate anchors.

        Returns:
            list[RotatedBoxes]: a list of Boxes containing all the anchors for each feature map
                (i.e. the cell anchors repeated over all locations in the feature map).
                The number of anchors of each feature map is Hi x Wi x num_cell_anchors,
                where Hi, Wi are resolution of the feature map divided by anchor stride.
        """
        grid_sizes = [feature_map.shape[-2:] for feature_map in features]
        anchors_over_all_feature_maps = self._grid_anchors(grid_sizes)
        return [RotatedBoxes(x) for x in anchors_over_all_feature_maps]


def build_anchor_generator(cfg, input_shape):
    """
    Build an anchor generator from `cfg.MODEL.ANCHOR_GENERATOR.NAME`.
    """
    anchor_generator = cfg.MODEL.ANCHOR_GENERATOR.NAME
    return ANCHOR_GENERATOR_REGISTRY.get(anchor_generator)(cfg, input_shape)
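
For context, a quick numeric check (not part of the diff) of the size/aspect-ratio algebra used in `generate_cell_anchors` above: the computed box preserves the requested area (w * h == size**2) and ratio (h / w == aspect_ratio).

import math

size, aspect_ratio = 64, 0.5
w = math.sqrt(size**2 / aspect_ratio)  # w = sqrt(s * s / a)
h = aspect_ratio * w                   # h = a * w
assert abs(w * h - size**2) < 1e-6
assert abs(h / w - aspect_ratio) < 1e-9
# The XYXY cell anchor centered at (0, 0):
print([round(v, 2) for v in (-w / 2, -h / 2, w / 2, h / 2)])
# -> [-45.25, -22.63, 45.25, 22.63]
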
CatVTON/detectron2/modeling/box_regression.py
ADDED
@@ -0,0 +1,369 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import math
from typing import List, Tuple, Union
import torch
from fvcore.nn import giou_loss, smooth_l1_loss
from torch.nn import functional as F

from detectron2.layers import cat, ciou_loss, diou_loss
from detectron2.structures import Boxes

# Value for clamping large dw and dh predictions. The heuristic is that we clamp
# such that dw and dh are no larger than what would transform a 16px box into a
# 1000px box (based on a small anchor, 16px, and a typical image size, 1000px).
_DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16)


__all__ = ["Box2BoxTransform", "Box2BoxTransformRotated", "Box2BoxTransformLinear"]


@torch.jit.script
class Box2BoxTransform:
    """
    The box-to-box transform defined in R-CNN. The transformation is parameterized
    by 4 deltas: (dx, dy, dw, dh). The transformation scales the box's width and height
    by exp(dw), exp(dh) and shifts a box's center by the offset (dx * width, dy * height).
    """

    def __init__(
        self, weights: Tuple[float, float, float, float], scale_clamp: float = _DEFAULT_SCALE_CLAMP
    ):
        """
        Args:
            weights (4-element tuple): Scaling factors that are applied to the
                (dx, dy, dw, dh) deltas. In Fast R-CNN, these were originally set
                such that the deltas have unit variance; now they are treated as
                hyperparameters of the system.
            scale_clamp (float): When predicting deltas, the predicted box scaling
                factors (dw and dh) are clamped such that they are <= scale_clamp.
        """
        self.weights = weights
        self.scale_clamp = scale_clamp

    def get_deltas(self, src_boxes, target_boxes):
        """
        Get box regression transformation deltas (dx, dy, dw, dh) that can be used
        to transform the `src_boxes` into the `target_boxes`. That is, the relation
        ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless
        any delta is too large and is clamped).

        Args:
            src_boxes (Tensor): source boxes, e.g., object proposals
            target_boxes (Tensor): target of the transformation, e.g., ground-truth
                boxes.
        """
        assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
        assert isinstance(target_boxes, torch.Tensor), type(target_boxes)

        src_widths = src_boxes[:, 2] - src_boxes[:, 0]
        src_heights = src_boxes[:, 3] - src_boxes[:, 1]
        src_ctr_x = src_boxes[:, 0] + 0.5 * src_widths
        src_ctr_y = src_boxes[:, 1] + 0.5 * src_heights

        target_widths = target_boxes[:, 2] - target_boxes[:, 0]
        target_heights = target_boxes[:, 3] - target_boxes[:, 1]
        target_ctr_x = target_boxes[:, 0] + 0.5 * target_widths
        target_ctr_y = target_boxes[:, 1] + 0.5 * target_heights

        wx, wy, ww, wh = self.weights
        dx = wx * (target_ctr_x - src_ctr_x) / src_widths
        dy = wy * (target_ctr_y - src_ctr_y) / src_heights
        dw = ww * torch.log(target_widths / src_widths)
        dh = wh * torch.log(target_heights / src_heights)

        deltas = torch.stack((dx, dy, dw, dh), dim=1)
        assert (src_widths > 0).all().item(), "Input boxes to Box2BoxTransform are not valid!"
        return deltas

    def apply_deltas(self, deltas, boxes):
        """
        Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
                deltas[i] represents k potentially different class-specific
                box transformations for the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 4)
        """
        deltas = deltas.float()  # ensure fp32 for decoding precision
        boxes = boxes.to(deltas.dtype)

        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.weights
        dx = deltas[:, 0::4] / wx
        dy = deltas[:, 1::4] / wy
        dw = deltas[:, 2::4] / ww
        dh = deltas[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        x1 = pred_ctr_x - 0.5 * pred_w
        y1 = pred_ctr_y - 0.5 * pred_h
        x2 = pred_ctr_x + 0.5 * pred_w
        y2 = pred_ctr_y + 0.5 * pred_h
        pred_boxes = torch.stack((x1, y1, x2, y2), dim=-1)
        return pred_boxes.reshape(deltas.shape)


@torch.jit.script
class Box2BoxTransformRotated:
    """
    The box-to-box transform defined in Rotated R-CNN. The transformation is parameterized
    by 5 deltas: (dx, dy, dw, dh, da). The transformation scales the box's width and height
    by exp(dw), exp(dh), shifts a box's center by the offset (dx * width, dy * height),
    and rotates a box's angle by da (radians).
    Note: angles of deltas are in radians while angles of boxes are in degrees.
    """

    def __init__(
        self,
        weights: Tuple[float, float, float, float, float],
        scale_clamp: float = _DEFAULT_SCALE_CLAMP,
    ):
        """
        Args:
            weights (5-element tuple): Scaling factors that are applied to the
                (dx, dy, dw, dh, da) deltas. These are treated as
                hyperparameters of the system.
            scale_clamp (float): When predicting deltas, the predicted box scaling
                factors (dw and dh) are clamped such that they are <= scale_clamp.
        """
        self.weights = weights
        self.scale_clamp = scale_clamp

    def get_deltas(self, src_boxes, target_boxes):
        """
        Get box regression transformation deltas (dx, dy, dw, dh, da) that can be used
        to transform the `src_boxes` into the `target_boxes`. That is, the relation
        ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless
        any delta is too large and is clamped).

        Args:
            src_boxes (Tensor): Nx5 source boxes, e.g., object proposals
            target_boxes (Tensor): Nx5 target of the transformation, e.g., ground-truth
                boxes.
        """
        assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
        assert isinstance(target_boxes, torch.Tensor), type(target_boxes)

        src_ctr_x, src_ctr_y, src_widths, src_heights, src_angles = torch.unbind(src_boxes, dim=1)

        target_ctr_x, target_ctr_y, target_widths, target_heights, target_angles = torch.unbind(
            target_boxes, dim=1
        )

        wx, wy, ww, wh, wa = self.weights
        dx = wx * (target_ctr_x - src_ctr_x) / src_widths
        dy = wy * (target_ctr_y - src_ctr_y) / src_heights
        dw = ww * torch.log(target_widths / src_widths)
        dh = wh * torch.log(target_heights / src_heights)
        # Angles of deltas are in radians while angles of boxes are in degrees.
        # The conversion to radians serves as a way to normalize the values.
        da = target_angles - src_angles
        da = (da + 180.0) % 360.0 - 180.0  # make it in [-180, 180)
        da *= wa * math.pi / 180.0

        deltas = torch.stack((dx, dy, dw, dh, da), dim=1)
        assert (
            (src_widths > 0).all().item()
        ), "Input boxes to Box2BoxTransformRotated are not valid!"
        return deltas

    def apply_deltas(self, deltas, boxes):
        """
        Apply transformation `deltas` (dx, dy, dw, dh, da) to `boxes`.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*5).
                deltas[i] represents box transformation for the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 5)
        """
        assert deltas.shape[1] % 5 == 0 and boxes.shape[1] == 5

        boxes = boxes.to(deltas.dtype).unsqueeze(2)

        ctr_x = boxes[:, 0]
        ctr_y = boxes[:, 1]
        widths = boxes[:, 2]
        heights = boxes[:, 3]
        angles = boxes[:, 4]

        wx, wy, ww, wh, wa = self.weights

        dx = deltas[:, 0::5] / wx
        dy = deltas[:, 1::5] / wy
        dw = deltas[:, 2::5] / ww
        dh = deltas[:, 3::5] / wh
        da = deltas[:, 4::5] / wa

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)

        pred_boxes = torch.zeros_like(deltas)
        pred_boxes[:, 0::5] = dx * widths + ctr_x  # x_ctr
        pred_boxes[:, 1::5] = dy * heights + ctr_y  # y_ctr
        pred_boxes[:, 2::5] = torch.exp(dw) * widths  # width
        pred_boxes[:, 3::5] = torch.exp(dh) * heights  # height

        # Following the original RRPN implementation,
        # angles of deltas are in radians while angles of boxes are in degrees.
        pred_angle = da * 180.0 / math.pi + angles
        pred_angle = (pred_angle + 180.0) % 360.0 - 180.0  # make it in [-180, 180)

        pred_boxes[:, 4::5] = pred_angle

        return pred_boxes


class Box2BoxTransformLinear:
    """
    The linear box-to-box transform defined in FCOS. The transformation is parameterized
    by the distance from the center of (square) src box to 4 edges of the target box.
    """

    def __init__(self, normalize_by_size=True):
        """
        Args:
            normalize_by_size: normalize deltas by the size of src (anchor) boxes.
        """
        self.normalize_by_size = normalize_by_size

    def get_deltas(self, src_boxes, target_boxes):
        """
        Get box regression transformation deltas (dx1, dy1, dx2, dy2) that can be used
        to transform the `src_boxes` into the `target_boxes`. That is, the relation
        ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true.
        The center of src must be inside target boxes.

        Args:
            src_boxes (Tensor): square source boxes, e.g., anchors
            target_boxes (Tensor): target of the transformation, e.g., ground-truth
                boxes.
        """
        assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
        assert isinstance(target_boxes, torch.Tensor), type(target_boxes)

        src_ctr_x = 0.5 * (src_boxes[:, 0] + src_boxes[:, 2])
        src_ctr_y = 0.5 * (src_boxes[:, 1] + src_boxes[:, 3])

        target_l = src_ctr_x - target_boxes[:, 0]
        target_t = src_ctr_y - target_boxes[:, 1]
        target_r = target_boxes[:, 2] - src_ctr_x
        target_b = target_boxes[:, 3] - src_ctr_y

        deltas = torch.stack((target_l, target_t, target_r, target_b), dim=1)
        if self.normalize_by_size:
            stride_w = src_boxes[:, 2] - src_boxes[:, 0]
            stride_h = src_boxes[:, 3] - src_boxes[:, 1]
            strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1)
            deltas = deltas / strides

        return deltas

    def apply_deltas(self, deltas, boxes):
        """
        Apply transformation `deltas` (dx1, dy1, dx2, dy2) to `boxes`.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
                deltas[i] represents k potentially different class-specific
                box transformations for the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 4)
        """
        # Ensure the output is a valid box. See Sec 2.1 of https://arxiv.org/abs/2006.09214
        deltas = F.relu(deltas)
        boxes = boxes.to(deltas.dtype)

        ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2])
        ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3])
        if self.normalize_by_size:
            stride_w = boxes[:, 2] - boxes[:, 0]
            stride_h = boxes[:, 3] - boxes[:, 1]
            strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1)
            deltas = deltas * strides

        l = deltas[:, 0::4]
        t = deltas[:, 1::4]
        r = deltas[:, 2::4]
        b = deltas[:, 3::4]

        pred_boxes = torch.zeros_like(deltas)
        pred_boxes[:, 0::4] = ctr_x[:, None] - l  # x1
        pred_boxes[:, 1::4] = ctr_y[:, None] - t  # y1
        pred_boxes[:, 2::4] = ctr_x[:, None] + r  # x2
        pred_boxes[:, 3::4] = ctr_y[:, None] + b  # y2
        return pred_boxes


def _dense_box_regression_loss(
    anchors: List[Union[Boxes, torch.Tensor]],
    box2box_transform: Box2BoxTransform,
    pred_anchor_deltas: List[torch.Tensor],
    gt_boxes: List[torch.Tensor],
    fg_mask: torch.Tensor,
    box_reg_loss_type="smooth_l1",
    smooth_l1_beta=0.0,
):
    """
    Compute loss for dense multi-level box regression.
    Loss is accumulated over ``fg_mask``.

    Args:
        anchors: #lvl anchor boxes, each is (HixWixA, 4)
        pred_anchor_deltas: #lvl predictions, each is (N, HixWixA, 4)
        gt_boxes: N ground truth boxes, each has shape (R, 4) (R = sum(Hi * Wi * A))
        fg_mask: the foreground boolean mask of shape (N, R) to compute loss on
        box_reg_loss_type (str): Loss type to use. Supported losses: "smooth_l1", "giou",
            "diou", "ciou".
        smooth_l1_beta (float): beta parameter for the smooth L1 regression loss. Default to
            use L1 loss. Only used when `box_reg_loss_type` is "smooth_l1"
    """
    if isinstance(anchors[0], Boxes):
        anchors = type(anchors[0]).cat(anchors).tensor  # (R, 4)
    else:
        anchors = cat(anchors)
    if box_reg_loss_type == "smooth_l1":
        gt_anchor_deltas = [box2box_transform.get_deltas(anchors, k) for k in gt_boxes]
        gt_anchor_deltas = torch.stack(gt_anchor_deltas)  # (N, R, 4)
        loss_box_reg = smooth_l1_loss(
            cat(pred_anchor_deltas, dim=1)[fg_mask],
            gt_anchor_deltas[fg_mask],
            beta=smooth_l1_beta,
            reduction="sum",
        )
    elif box_reg_loss_type == "giou":
        pred_boxes = [
            box2box_transform.apply_deltas(k, anchors) for k in cat(pred_anchor_deltas, dim=1)
        ]
        loss_box_reg = giou_loss(
            torch.stack(pred_boxes)[fg_mask], torch.stack(gt_boxes)[fg_mask], reduction="sum"
        )
    elif box_reg_loss_type == "diou":
        pred_boxes = [
            box2box_transform.apply_deltas(k, anchors) for k in cat(pred_anchor_deltas, dim=1)
        ]
        loss_box_reg = diou_loss(
            torch.stack(pred_boxes)[fg_mask], torch.stack(gt_boxes)[fg_mask], reduction="sum"
        )
    elif box_reg_loss_type == "ciou":
        pred_boxes = [
            box2box_transform.apply_deltas(k, anchors) for k in cat(pred_anchor_deltas, dim=1)
        ]
        loss_box_reg = ciou_loss(
            torch.stack(pred_boxes)[fg_mask], torch.stack(gt_boxes)[fg_mask], reduction="sum"
        )
    else:
        raise ValueError(f"Invalid dense box regression loss type '{box_reg_loss_type}'")
    return loss_box_reg
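
As a sanity check on the encode/decode pair above, a minimal round-trip sketch (the box values are illustrative):

import torch
from detectron2.modeling.box_regression import Box2BoxTransform

transform = Box2BoxTransform(weights=(1.0, 1.0, 1.0, 1.0))
src = torch.tensor([[0.0, 0.0, 10.0, 10.0]])   # proposal, XYXY format
tgt = torch.tensor([[1.0, 2.0, 11.0, 14.0]])   # ground-truth box, XYXY format
deltas = transform.get_deltas(src, tgt)        # (1, 4) regression targets
decoded = transform.apply_deltas(deltas, src)  # reconstructs tgt up to fp error
assert torch.allclose(decoded, tgt, atol=1e-4)
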
CatVTON/detectron2/modeling/matcher.py
ADDED
@@ -0,0 +1,127 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import List
import torch

from detectron2.layers import nonzero_tuple


# TODO: the name is too general
class Matcher:
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be matched to zero or more predicted elements.

    The matching is determined by the MxN match_quality_matrix, which characterizes
    how well each (ground-truth, prediction)-pair matches. For example,
    if the elements are boxes, this matrix may contain box intersection-over-union
    overlap values.

    The matcher returns (a) a vector of length N containing the index of the
    ground-truth element m in [0, M) that matches to prediction n in [0, N).
    (b) a vector of length N containing the labels for each prediction.
    """

    def __init__(
        self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False
    ):
        """
        Args:
            thresholds (list): a list of thresholds used to stratify predictions
                into levels.
            labels (list): a list of values to label predictions belonging at
                each level. A label can be one of {-1, 0, 1} signifying
                {ignore, negative class, positive class}, respectively.
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions with maximum match quality lower than high_threshold.
                See set_low_quality_matches_ for more details.

            For example,
                thresholds = [0.3, 0.5]
                labels = [0, -1, 1]
                All predictions with iou < 0.3 will be marked with 0 and
                thus will be considered as false positives while training.
                All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
                thus will be ignored.
                All predictions with 0.5 <= iou will be marked with 1 and
                thus will be considered as true positives.
        """
        # Add -inf and +inf to first and last position in thresholds
        thresholds = thresholds[:]
        assert thresholds[0] > 0
        thresholds.insert(0, -float("inf"))
        thresholds.append(float("inf"))
        # Currently torchscript does not support all + generator
        assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])])
        assert all([l in [-1, 0, 1] for l in labels])
        assert len(labels) == len(thresholds) - 1
        self.thresholds = thresholds
        self.labels = labels
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted
                elements. All elements must be >= 0 (due to the use of `torch.nonzero`
                for selecting indices in :meth:`set_low_quality_matches_`).

        Returns:
            matches (Tensor[int64]): a vector of length N, where matches[i] is a matched
                ground-truth index in [0, M)
            match_labels (Tensor[int8]): a vector of length N, where pred_labels[i] indicates
                whether a prediction is a true or false positive or ignored
        """
        assert match_quality_matrix.dim() == 2
        if match_quality_matrix.numel() == 0:
            default_matches = match_quality_matrix.new_full(
                (match_quality_matrix.size(1),), 0, dtype=torch.int64
            )
            # When no gt boxes exist, we define IOU = 0 and therefore set labels
            # to `self.labels[0]`, which usually defaults to background class 0.
            # To choose to ignore instead, can make labels=[-1,0,-1,1] + set appropriate thresholds
            default_match_labels = match_quality_matrix.new_full(
                (match_quality_matrix.size(1),), self.labels[0], dtype=torch.int8
            )
            return default_matches, default_match_labels

        assert torch.all(match_quality_matrix >= 0)

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)

        match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)

        for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
            low_high = (matched_vals >= low) & (matched_vals < high)
            match_labels[low_high] = l

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(match_labels, match_quality_matrix)

        return matches, match_labels

    def set_low_quality_matches_(self, match_labels, match_quality_matrix):
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth G find the set of predictions that have
        maximum overlap with it (including ties); for each prediction in that set, if
        it is unmatched, then match it to the ground-truth G.

        This function implements the RPN assignment case (i) in Sec. 3.1.2 of
        :paper:`Faster R-CNN`.
        """
        # For each gt, find the prediction with which it has highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        # Find the highest quality match available, even if it is low, including ties.
        # Note that the match qualities must be positive due to the use of
        # `torch.nonzero`.
        _, pred_inds_with_highest_quality = nonzero_tuple(
            match_quality_matrix == highest_quality_foreach_gt[:, None]
        )
        # If an anchor was labeled positive only due to a low-quality match
        # with gt_A, but it has larger overlap with gt_B, its matched index will still be gt_B.
        # This follows the implementation in Detectron, and is found to have no significant impact.
        match_labels[pred_inds_with_highest_quality] = 1
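
A minimal sketch of the thresholds/labels stratification on a hand-made 2x4 IoU matrix (all numbers illustrative):

import torch
from detectron2.modeling.matcher import Matcher

matcher = Matcher(thresholds=[0.3, 0.5], labels=[0, -1, 1])
iou = torch.tensor([[0.9, 0.1, 0.4, 0.0],    # gt 0 vs 4 predictions
                    [0.2, 0.6, 0.1, 0.0]])   # gt 1 vs 4 predictions
matches, match_labels = matcher(iou)
# matches      -> [0, 1, 0, 0]   best gt index per prediction
# match_labels -> [1, 1, -1, 0]  positive, positive, ignored, negative
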
CatVTON/detectron2/modeling/poolers.py
ADDED
@@ -0,0 +1,263 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import math
from typing import List, Optional
import torch
from torch import nn
from torchvision.ops import RoIPool

from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor
from detectron2.structures import Boxes
from detectron2.utils.tracing import assert_fx_safe, is_fx_tracing

"""
To export ROIPooler to torchscript, in this file, variables that should be annotated with
`Union[List[Boxes], List[RotatedBoxes]]` are only annotated with `List[Boxes]`.

TODO: Correct these annotations when torchscript supports `Union`.
https://github.com/pytorch/pytorch/issues/41412
"""

__all__ = ["ROIPooler"]


def assign_boxes_to_levels(
    box_lists: List[Boxes],
    min_level: int,
    max_level: int,
    canonical_box_size: int,
    canonical_level: int,
):
    """
    Map each box in `box_lists` to a feature map level index and return the assignment
    vector.

    Args:
        box_lists (list[Boxes] | list[RotatedBoxes]): A list of N Boxes or N RotatedBoxes,
            where N is the number of images in the batch.
        min_level (int): Smallest feature map level index. The input is considered index 0,
            the output of stage 1 is index 1, and so on.
        max_level (int): Largest feature map level index.
        canonical_box_size (int): A canonical box size in pixels (sqrt(box area)).
        canonical_level (int): The feature map level index on which a canonically-sized box
            should be placed.

    Returns:
        A tensor of length M, where M is the total number of boxes aggregated over all
        N batch images. The memory layout corresponds to the concatenation of boxes
        from all images. Each element is the feature map index, as an offset from
        `self.min_level`, for the corresponding box (so value i means the box is at
        `self.min_level + i`).
    """
    box_sizes = torch.sqrt(cat([boxes.area() for boxes in box_lists]))
    # Eqn.(1) in FPN paper
    level_assignments = torch.floor(
        canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8)
    )
    # clamp level to (min, max), in case the box size is too large or too small
    # for the available feature maps
    level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
    return level_assignments.to(torch.int64) - min_level


# script the module to avoid hardcoded device type
@torch.jit.script_if_tracing
def _convert_boxes_to_pooler_format(boxes: torch.Tensor, sizes: torch.Tensor) -> torch.Tensor:
    sizes = sizes.to(device=boxes.device)
    indices = torch.repeat_interleave(
        torch.arange(len(sizes), dtype=boxes.dtype, device=boxes.device), sizes
    )
    return cat([indices[:, None], boxes], dim=1)


def convert_boxes_to_pooler_format(box_lists: List[Boxes]):
    """
    Convert all boxes in `box_lists` to the low-level format used by ROI pooling ops
    (see description under Returns).

    Args:
        box_lists (list[Boxes] | list[RotatedBoxes]):
            A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch.

    Returns:
        When input is list[Boxes]:
            A tensor of shape (M, 5), where M is the total number of boxes aggregated over all
            N batch images.
            The 5 columns are (batch index, x0, y0, x1, y1), where batch index
            is the index in [0, N) identifying which batch image the box with corners at
            (x0, y0, x1, y1) comes from.
        When input is list[RotatedBoxes]:
            A tensor of shape (M, 6), where M is the total number of boxes aggregated over all
            N batch images.
            The 6 columns are (batch index, x_ctr, y_ctr, width, height, angle_degrees),
            where batch index is the index in [0, N) identifying which batch image the
            rotated box (x_ctr, y_ctr, width, height, angle_degrees) comes from.
    """
    boxes = torch.cat([x.tensor for x in box_lists], dim=0)
    # __len__ returns Tensor in tracing.
    sizes = shapes_to_tensor([x.__len__() for x in box_lists])
    return _convert_boxes_to_pooler_format(boxes, sizes)


@torch.jit.script_if_tracing
def _create_zeros(
    batch_target: Optional[torch.Tensor],
    channels: int,
    height: int,
    width: int,
    like_tensor: torch.Tensor,
) -> torch.Tensor:
    batches = batch_target.shape[0] if batch_target is not None else 0
    sizes = (batches, channels, height, width)
    return torch.zeros(sizes, dtype=like_tensor.dtype, device=like_tensor.device)


class ROIPooler(nn.Module):
    """
    Region of interest feature map pooler that supports pooling from one or more
    feature maps.
    """

    def __init__(
        self,
        output_size,
        scales,
        sampling_ratio,
        pooler_type,
        canonical_box_size=224,
        canonical_level=4,
    ):
        """
        Args:
            output_size (int, tuple[int] or list[int]): output size of the pooled region,
                e.g., 14 x 14. If tuple or list is given, the length must be 2.
            scales (list[float]): The scale for each low-level pooling op relative to
                the input image. For a feature map with stride s relative to the input
                image, scale is defined as 1/s. The stride must be a power of 2.
                When there are multiple scales, they must form a pyramid, i.e. they must be
                a monotonically decreasing geometric sequence with a factor of 1/2.
            sampling_ratio (int): The `sampling_ratio` parameter for the ROIAlign op.
            pooler_type (string): Name of the type of pooling operation that should be applied.
                For instance, "ROIPool" or "ROIAlignV2".
            canonical_box_size (int): A canonical box size in pixels (sqrt(box area)). The default
                is heuristically defined as 224 pixels in the FPN paper (based on ImageNet
                pre-training).
            canonical_level (int): The feature map level index on which a canonically-sized box
                should be placed. The default is defined as level 4 (stride=16) in the FPN paper,
                i.e., a box of size 224x224 will be placed on the feature with stride=16.
                The box placement for all boxes will be determined from their sizes w.r.t
                canonical_box_size. For example, a box whose area is 4x that of a canonical box
                should be used to pool features from feature level ``canonical_level+1``.

                Note that the actual input feature maps given to this module may not have
                sufficiently many levels for the input boxes. If the boxes are too large or too
                small for the input feature maps, the closest level will be used.
        """
        super().__init__()

        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        assert len(output_size) == 2
        assert isinstance(output_size[0], int) and isinstance(output_size[1], int)
        self.output_size = output_size

        if pooler_type == "ROIAlign":
            self.level_poolers = nn.ModuleList(
                ROIAlign(
                    output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=False
                )
                for scale in scales
            )
        elif pooler_type == "ROIAlignV2":
            self.level_poolers = nn.ModuleList(
                ROIAlign(
                    output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=True
                )
                for scale in scales
            )
        elif pooler_type == "ROIPool":
            self.level_poolers = nn.ModuleList(
                RoIPool(output_size, spatial_scale=scale) for scale in scales
            )
        elif pooler_type == "ROIAlignRotated":
            self.level_poolers = nn.ModuleList(
                ROIAlignRotated(output_size, spatial_scale=scale, sampling_ratio=sampling_ratio)
                for scale in scales
            )
        else:
            raise ValueError("Unknown pooler type: {}".format(pooler_type))

        # Map scale (defined as 1 / stride) to its feature map level under the
        # assumption that stride is a power of 2.
        min_level = -(math.log2(scales[0]))
        max_level = -(math.log2(scales[-1]))
        assert math.isclose(min_level, int(min_level)) and math.isclose(
            max_level, int(max_level)
        ), "Featuremap stride is not power of 2!"
        self.min_level = int(min_level)
        self.max_level = int(max_level)
        assert (
            len(scales) == self.max_level - self.min_level + 1
        ), "[ROIPooler] Sizes of input featuremaps do not form a pyramid!"
        assert 0 <= self.min_level and self.min_level <= self.max_level
        self.canonical_level = canonical_level
        assert canonical_box_size > 0
        self.canonical_box_size = canonical_box_size

    def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
        """
        Args:
            x (list[Tensor]): A list of feature maps of NCHW shape, with scales matching those
                used to construct this module.
            box_lists (list[Boxes] | list[RotatedBoxes]):
                A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch.
                The box coordinates are defined on the original image and
                will be scaled by the `scales` argument of :class:`ROIPooler`.

        Returns:
            Tensor:
                A tensor of shape (M, C, output_size, output_size) where M is the total number of
                boxes aggregated over all N batch images and C is the number of channels in `x`.
        """
        num_level_assignments = len(self.level_poolers)

        if not is_fx_tracing():
            torch._assert(
                isinstance(x, list) and isinstance(box_lists, list),
                "Arguments to pooler must be lists",
            )
        assert_fx_safe(
            len(x) == num_level_assignments,
            "unequal value, num_level_assignments={}, but x is list of {} Tensors".format(
                num_level_assignments, len(x)
            ),
        )
        assert_fx_safe(
            len(box_lists) == x[0].size(0),
            "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format(
                x[0].size(0), len(box_lists)
            ),
        )
        if len(box_lists) == 0:
            return _create_zeros(None, x[0].shape[1], *self.output_size, x[0])

        pooler_fmt_boxes = convert_boxes_to_pooler_format(box_lists)

        if num_level_assignments == 1:
            return self.level_poolers[0](x[0], pooler_fmt_boxes)

        level_assignments = assign_boxes_to_levels(
            box_lists, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level
        )

        num_channels = x[0].shape[1]
        output_size = self.output_size[0]

        output = _create_zeros(pooler_fmt_boxes, num_channels, output_size, output_size, x[0])

        for level, pooler in enumerate(self.level_poolers):
            inds = nonzero_tuple(level_assignments == level)[0]
            pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
            # Use index_put_ instead of advanced indexing, to avoid pytorch/issues/49852
            output.index_put_((inds,), pooler(x[level], pooler_fmt_boxes_level))

        return output
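
A minimal forward-pass sketch over a two-level pyramid; the shapes, strides, and box below are illustrative assumptions, not values from this repository:

import torch
from detectron2.modeling.poolers import ROIPooler
from detectron2.structures import Boxes

pooler = ROIPooler(
    output_size=7,
    scales=(1.0 / 8, 1.0 / 16),   # strides 8 and 16 -> levels 3 and 4
    sampling_ratio=0,
    pooler_type="ROIAlignV2",
)
features = [torch.randn(1, 256, 64, 64), torch.randn(1, 256, 32, 32)]
boxes = [Boxes(torch.tensor([[10.0, 10.0, 80.0, 90.0]]))]  # one image, one box
out = pooler(features, boxes)  # -> (1, 256, 7, 7); level chosen by box size
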
CatVTON/detectron2/projects/README.md
ADDED
@@ -0,0 +1,2 @@
Projects live in the [`projects` directory](../../projects) under the root of this repository, but not here.
CatVTON/detectron2/projects/__init__.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import importlib.abc
import importlib.util
from pathlib import Path

__all__ = []

_PROJECTS = {
    "point_rend": "PointRend",
    "deeplab": "DeepLab",
    "panoptic_deeplab": "Panoptic-DeepLab",
}
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent / "projects"

if _PROJECT_ROOT.is_dir():
    # This is true only for in-place installation (pip install -e, setup.py develop),
    # where setup(package_dir=) does not work: https://github.com/pypa/setuptools/issues/230

    class _D2ProjectsFinder(importlib.abc.MetaPathFinder):
        def find_spec(self, name, path, target=None):
            if not name.startswith("detectron2.projects."):
                return
            project_name = name.split(".")[-1]
            project_dir = _PROJECTS.get(project_name)
            if not project_dir:
                return
            target_file = _PROJECT_ROOT / f"{project_dir}/{project_name}/__init__.py"
            if not target_file.is_file():
                return
            return importlib.util.spec_from_file_location(name, target_file)

    import sys

    sys.meta_path.append(_D2ProjectsFinder())
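
With an editable/source install where the top-level `projects/` directory is present, the meta path finder above makes project packages importable under the `detectron2.projects` namespace; for example, assuming the PointRend project is checked out:

import detectron2.projects.point_rend  # resolved by _D2ProjectsFinder to projects/PointRend/point_rend/__init__.py
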
CatVTON/detectron2/solver/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .build import build_lr_scheduler, build_optimizer, get_default_optimizer_params
from .lr_scheduler import (
    LRMultiplier,
    LRScheduler,
    WarmupCosineLR,
    WarmupMultiStepLR,
    WarmupParamScheduler,
)

__all__ = [k for k in globals().keys() if not k.startswith("_")]
CatVTON/detectron2/solver/build.py
ADDED
@@ -0,0 +1,323 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import copy
|
| 3 |
+
import itertools
|
| 4 |
+
import logging
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
|
| 8 |
+
import torch
|
| 9 |
+
from fvcore.common.param_scheduler import (
|
| 10 |
+
CosineParamScheduler,
|
| 11 |
+
MultiStepParamScheduler,
|
| 12 |
+
StepWithFixedGammaParamScheduler,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from detectron2.config import CfgNode
|
| 16 |
+
from detectron2.utils.env import TORCH_VERSION
|
| 17 |
+
|
| 18 |
+
from .lr_scheduler import LRMultiplier, LRScheduler, WarmupParamScheduler
|
| 19 |
+
|
| 20 |
+
_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
|
| 21 |
+
_GradientClipper = Callable[[_GradientClipperInput], None]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class GradientClipType(Enum):
|
| 25 |
+
VALUE = "value"
|
| 26 |
+
NORM = "norm"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
|
| 30 |
+
"""
|
| 31 |
+
Creates gradient clipping closure to clip by value or by norm,
|
| 32 |
+
according to the provided config.
|
| 33 |
+
"""
|
| 34 |
+
cfg = copy.deepcopy(cfg)
|
| 35 |
+
|
| 36 |
+
def clip_grad_norm(p: _GradientClipperInput):
|
| 37 |
+
torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE)
|
| 38 |
+
|
| 39 |
+
def clip_grad_value(p: _GradientClipperInput):
|
| 40 |
+
torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE)
|
| 41 |
+
|
| 42 |
+
_GRADIENT_CLIP_TYPE_TO_CLIPPER = {
|
| 43 |
+
GradientClipType.VALUE: clip_grad_value,
|
| 44 |
+
GradientClipType.NORM: clip_grad_norm,
|
| 45 |
+
}
|
| 46 |
+
return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _generate_optimizer_class_with_gradient_clipping(
|
| 50 |
+
optimizer: Type[torch.optim.Optimizer],
|
| 51 |
+
*,
|
| 52 |
+
per_param_clipper: Optional[_GradientClipper] = None,
|
| 53 |
+
global_clipper: Optional[_GradientClipper] = None,
|
| 54 |
+
) -> Type[torch.optim.Optimizer]:
|
| 55 |
+
"""
|
| 56 |
+
Dynamically creates a new type that inherits the type of a given instance
|
| 57 |
+
and overrides the `step` method to add gradient clipping
|
| 58 |
+
"""
|
| 59 |
+
assert (
|
| 60 |
+
per_param_clipper is None or global_clipper is None
|
| 61 |
+
), "Not allowed to use both per-parameter clipping and global clipping"
|
| 62 |
+
|
| 63 |
+
def optimizer_wgc_step(self, closure=None):
|
| 64 |
+
if per_param_clipper is not None:
|
| 65 |
+
for group in self.param_groups:
|
| 66 |
+
for p in group["params"]:
|
| 67 |
+
per_param_clipper(p)
|
| 68 |
+
else:
|
| 69 |
+
# global clipper for future use with detr
|
| 70 |
+
# (https://github.com/facebookresearch/detr/pull/287)
|
| 71 |
+
all_params = itertools.chain(*[g["params"] for g in self.param_groups])
|
| 72 |
+
global_clipper(all_params)
|
| 73 |
+
super(type(self), self).step(closure)
|
| 74 |
+
|
| 75 |
+
OptimizerWithGradientClip = type(
|
| 76 |
+
optimizer.__name__ + "WithGradientClip",
|
| 77 |
+
(optimizer,),
|
| 78 |
+
{"step": optimizer_wgc_step},
|
| 79 |
+
)
|
| 80 |
+
return OptimizerWithGradientClip
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def maybe_add_gradient_clipping(
|
| 84 |
+
cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
|
| 85 |
+
) -> Type[torch.optim.Optimizer]:
|
| 86 |
+
"""
|
| 87 |
+
If gradient clipping is enabled through config options, wraps the existing
|
| 88 |
+
optimizer type to become a new dynamically created class OptimizerWithGradientClip
|
| 89 |
+
that inherits the given optimizer and overrides the `step` method to
|
| 90 |
+
include gradient clipping.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
cfg: CfgNode, configuration options
|
| 94 |
+
optimizer: type. A subclass of torch.optim.Optimizer
|
| 95 |
+
|
| 96 |
+
Return:
|
| 97 |
+
type: either the input `optimizer` (if gradient clipping is disabled), or
|
| 98 |
+
a subclass of it with gradient clipping included in the `step` method.
|
| 99 |
+
"""
|
| 100 |
+
if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
|
| 101 |
+
return optimizer
|
| 102 |
+
if isinstance(optimizer, torch.optim.Optimizer):
|
| 103 |
+
optimizer_type = type(optimizer)
|
| 104 |
+
else:
|
| 105 |
+
assert issubclass(optimizer, torch.optim.Optimizer), optimizer
|
| 106 |
+
optimizer_type = optimizer
|
| 107 |
+
|
| 108 |
+
grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
|
| 109 |
+
OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
|
| 110 |
+
optimizer_type, per_param_clipper=grad_clipper
|
| 111 |
+
)
|
| 112 |
+
if isinstance(optimizer, torch.optim.Optimizer):
|
| 113 |
+
optimizer.__class__ = OptimizerWithGradientClip # a bit hacky, not recommended
|
| 114 |
+
return optimizer
|
| 115 |
+
else:
|
| 116 |
+
return OptimizerWithGradientClip
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
|
| 120 |
+
"""
|
| 121 |
+
Build an optimizer from config.
|
| 122 |
+
"""
|
| 123 |
+
params = get_default_optimizer_params(
|
| 124 |
+
model,
|
| 125 |
+
base_lr=cfg.SOLVER.BASE_LR,
|
| 126 |
+
weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
|
| 127 |
+
bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
|
| 128 |
+
weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
|
| 129 |
+
)
|
| 130 |
+
sgd_args = {
|
| 131 |
+
"params": params,
|
| 132 |
+
"lr": cfg.SOLVER.BASE_LR,
|
| 133 |
+
"momentum": cfg.SOLVER.MOMENTUM,
|
| 134 |
+
"nesterov": cfg.SOLVER.NESTEROV,
|
| 135 |
+
"weight_decay": cfg.SOLVER.WEIGHT_DECAY,
|
| 136 |
+
}
|
| 137 |
+
if TORCH_VERSION >= (1, 12):
|
| 138 |
+
sgd_args["foreach"] = True
|
| 139 |
+
return maybe_add_gradient_clipping(cfg, torch.optim.SGD(**sgd_args))
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def get_default_optimizer_params(
|
| 143 |
+
model: torch.nn.Module,
|
| 144 |
+
base_lr: Optional[float] = None,
|
| 145 |
+
weight_decay: Optional[float] = None,
|
| 146 |
+
weight_decay_norm: Optional[float] = None,
|
| 147 |
+
bias_lr_factor: Optional[float] = 1.0,
|
| 148 |
+
weight_decay_bias: Optional[float] = None,
|
| 149 |
+
lr_factor_func: Optional[Callable] = None,
|
| 150 |
+
overrides: Optional[Dict[str, Dict[str, float]]] = None,
|
| 151 |
+
) -> List[Dict[str, Any]]:
|
| 152 |
+
"""
|
| 153 |
+
Get default param list for optimizer, with support for a few types of
|
| 154 |
+
overrides. If no overrides needed, this is equivalent to `model.parameters()`.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
|
| 158 |
+
weight_decay: weight decay for every group by default. Can be omitted to use the one
|
| 159 |
+
in optimizer.
|
| 160 |
+
weight_decay_norm: override weight decay for params in normalization layers
|
| 161 |
+
bias_lr_factor: multiplier of lr for bias parameters.
|
| 162 |
+
weight_decay_bias: override weight decay for bias parameters.
|
| 163 |
+
lr_factor_func: function to calculate lr decay rate by mapping the parameter names to
|
| 164 |
+
corresponding lr decay rate. Note that setting this option requires
|
| 165 |
+
also setting ``base_lr``.
|
| 166 |
+
overrides: if not `None`, provides values for optimizer hyperparameters
|
| 167 |
+
(LR, weight decay) for module parameters with a given name; e.g.
|
| 168 |
+
``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
|
| 169 |
+
weight decay values for all module parameters named `embedding`.
|
| 170 |
+
|
| 171 |
+
For common detection models, ``weight_decay_norm`` is the only option
|
| 172 |
+
needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
|
| 173 |
+
from Detectron1 that are not found useful.
|
| 174 |
+
|
| 175 |
+
Example:
|
| 176 |
+
::
|
| 177 |
+
torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
|
| 178 |
+
lr=0.01, weight_decay=1e-4, momentum=0.9)
|
| 179 |
+
"""
|
| 180 |
+
if overrides is None:
|
| 181 |
+
overrides = {}
|
| 182 |
+
defaults = {}
|
| 183 |
+
if base_lr is not None:
|
| 184 |
+
defaults["lr"] = base_lr
|
| 185 |
+
if weight_decay is not None:
|
| 186 |
+
defaults["weight_decay"] = weight_decay
|
| 187 |
+
bias_overrides = {}
|
| 188 |
+
if bias_lr_factor is not None and bias_lr_factor != 1.0:
|
| 189 |
+
# NOTE: unlike Detectron v1, we now by default make bias hyperparameters
|
| 190 |
+
# exactly the same as regular weights.
|
| 191 |
+
if base_lr is None:
|
| 192 |
+
raise ValueError("bias_lr_factor requires base_lr")
|
| 193 |
+
bias_overrides["lr"] = base_lr * bias_lr_factor
|
| 194 |
+
if weight_decay_bias is not None:
|
| 195 |
+
bias_overrides["weight_decay"] = weight_decay_bias
|
| 196 |
+
if len(bias_overrides):
|
| 197 |
+
if "bias" in overrides:
|
| 198 |
+
raise ValueError("Conflicting overrides for 'bias'")
|
| 199 |
+
overrides["bias"] = bias_overrides
|
| 200 |
+
if lr_factor_func is not None:
|
| 201 |
+
if base_lr is None:
|
| 202 |
+
raise ValueError("lr_factor_func requires base_lr")
|
| 203 |
+
norm_module_types = (
|
| 204 |
+
torch.nn.BatchNorm1d,
|
| 205 |
+
torch.nn.BatchNorm2d,
|
| 206 |
+
torch.nn.BatchNorm3d,
|
| 207 |
+
torch.nn.SyncBatchNorm,
|
| 208 |
+
# NaiveSyncBatchNorm inherits from BatchNorm2d
|
| 209 |
+
torch.nn.GroupNorm,
|
| 210 |
+
torch.nn.InstanceNorm1d,
|
| 211 |
+
torch.nn.InstanceNorm2d,
|
| 212 |
+
torch.nn.InstanceNorm3d,
|
| 213 |
+
torch.nn.LayerNorm,
|
| 214 |
+
torch.nn.LocalResponseNorm,
|
| 215 |
+
)
|
| 216 |
+
params: List[Dict[str, Any]] = []
|
| 217 |
+
memo: Set[torch.nn.parameter.Parameter] = set()
|
| 218 |
+
for module_name, module in model.named_modules():
|
| 219 |
+
for module_param_name, value in module.named_parameters(recurse=False):
|
| 220 |
+
if not value.requires_grad:
|
| 221 |
+
continue
|
| 222 |
+
# Avoid duplicating parameters
|
| 223 |
+
if value in memo:
|
| 224 |
+
continue
|
| 225 |
+
memo.add(value)
|
| 226 |
+
|
| 227 |
+
hyperparams = copy.copy(defaults)
|
| 228 |
+
if isinstance(module, norm_module_types) and weight_decay_norm is not None:
|
| 229 |
+
hyperparams["weight_decay"] = weight_decay_norm
|
| 230 |
+
if lr_factor_func is not None:
|
| 231 |
+
hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}")
|
| 232 |
+
|
| 233 |
+
hyperparams.update(overrides.get(module_param_name, {}))
|
| 234 |
+
params.append({"params": [value], **hyperparams})
|
| 235 |
+
return reduce_param_groups(params)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 239 |
+
# Transform parameter groups into per-parameter structure.
|
| 240 |
+
# Later items in `params` can overwrite parameters set in previous items.
|
| 241 |
+
ret = defaultdict(dict)
|
| 242 |
+
for item in params:
|
| 243 |
+
assert "params" in item
|
| 244 |
+
cur_params = {x: y for x, y in item.items() if x != "params" and x != "param_names"}
|
| 245 |
+
if "param_names" in item:
|
| 246 |
+
for param_name, param in zip(item["param_names"], item["params"]):
|
| 247 |
+
ret[param].update({"param_names": [param_name], "params": [param], **cur_params})
|
| 248 |
+
else:
|
| 249 |
+
for param in item["params"]:
|
| 250 |
+
ret[param].update({"params": [param], **cur_params})
|
| 251 |
+
return list(ret.values())
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 255 |
+
# Reorganize the parameter groups and merge duplicated groups.
|
| 256 |
+
# The number of parameter groups needs to be as small as possible in order
|
| 257 |
+
# to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
|
| 258 |
+
# of using a parameter_group per single parameter, we reorganize the
|
| 259 |
+
# parameter groups and merge duplicated groups. This approach speeds
|
| 260 |
+
# up multi-tensor optimizer significantly.
|
| 261 |
+
params = _expand_param_groups(params)
|
| 262 |
+
groups = defaultdict(list) # re-group all parameter groups by their hyperparams
|
| 263 |
+
for item in params:
|
| 264 |
+
cur_params = tuple((x, y) for x, y in item.items() if x != "params" and x != "param_names")
|
| 265 |
+
groups[cur_params].append({"params": item["params"]})
|
| 266 |
+
if "param_names" in item:
|
| 267 |
+
groups[cur_params][-1]["param_names"] = item["param_names"]
|
| 268 |
+
|
| 269 |
+
ret = []
|
| 270 |
+
for param_keys, param_values in groups.items():
|
| 271 |
+
cur = {kv[0]: kv[1] for kv in param_keys}
|
| 272 |
+
cur["params"] = list(
|
| 273 |
+
itertools.chain.from_iterable([params["params"] for params in param_values])
|
| 274 |
+
)
|
| 275 |
+
if len(param_values) > 0 and "param_names" in param_values[0]:
|
| 276 |
+
cur["param_names"] = list(
|
| 277 |
+
itertools.chain.from_iterable([params["param_names"] for params in param_values])
|
| 278 |
+
)
|
| 279 |
+
ret.append(cur)
|
| 280 |
+
return ret
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def build_lr_scheduler(cfg: CfgNode, optimizer: torch.optim.Optimizer) -> LRScheduler:
    """
    Build a LR scheduler from config.
    """
    name = cfg.SOLVER.LR_SCHEDULER_NAME

    if name == "WarmupMultiStepLR":
        steps = [x for x in cfg.SOLVER.STEPS if x <= cfg.SOLVER.MAX_ITER]
        if len(steps) != len(cfg.SOLVER.STEPS):
            logger = logging.getLogger(__name__)
            logger.warning(
                "SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. "
                "These values will be ignored."
            )
        sched = MultiStepParamScheduler(
            values=[cfg.SOLVER.GAMMA**k for k in range(len(steps) + 1)],
            milestones=steps,
            num_updates=cfg.SOLVER.MAX_ITER,
        )
    elif name == "WarmupCosineLR":
        end_value = cfg.SOLVER.BASE_LR_END / cfg.SOLVER.BASE_LR
        assert end_value >= 0.0 and end_value <= 1.0, end_value
        sched = CosineParamScheduler(1, end_value)
    elif name == "WarmupStepWithFixedGammaLR":
        sched = StepWithFixedGammaParamScheduler(
            base_value=1.0,
            gamma=cfg.SOLVER.GAMMA,
            num_decays=cfg.SOLVER.NUM_DECAYS,
            num_updates=cfg.SOLVER.MAX_ITER,
        )
    else:
        raise ValueError("Unknown LR scheduler: {}".format(name))

    sched = WarmupParamScheduler(
        sched,
        cfg.SOLVER.WARMUP_FACTOR,
        min(cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, 1.0),
        cfg.SOLVER.WARMUP_METHOD,
        cfg.SOLVER.RESCALE_INTERVAL,
    )
    return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)

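A hedged usage sketch of `build_lr_scheduler` (not part of the diff): the SOLVER keys below exist in detectron2's default config, but the model and the concrete values are placeholders.

# Hypothetical usage sketch.
import torch
from detectron2.config import get_cfg
from detectron2.solver import build_lr_scheduler

cfg = get_cfg()
cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupCosineLR"
cfg.SOLVER.BASE_LR = 0.01
cfg.SOLVER.MAX_ITER = 1000
cfg.SOLVER.WARMUP_ITERS = 100

model = torch.nn.Linear(4, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=cfg.SOLVER.BASE_LR)
scheduler = build_lr_scheduler(cfg, optimizer)

for _ in range(cfg.SOLVER.MAX_ITER):
    optimizer.step()   # real training step elided
    scheduler.step()   # advance the schedule by one iteration
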
CatVTON/detectron2/solver/lr_scheduler.py
ADDED
@@ -0,0 +1,247 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import math
from bisect import bisect_right
from typing import List
import torch
from fvcore.common.param_scheduler import (
    CompositeParamScheduler,
    ConstantParamScheduler,
    LinearParamScheduler,
    ParamScheduler,
)

try:
    from torch.optim.lr_scheduler import LRScheduler
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

logger = logging.getLogger(__name__)

class WarmupParamScheduler(CompositeParamScheduler):
    """
    Add an initial warmup stage to another scheduler.
    """

    def __init__(
        self,
        scheduler: ParamScheduler,
        warmup_factor: float,
        warmup_length: float,
        warmup_method: str = "linear",
        rescale_interval: bool = False,
    ):
        """
        Args:
            scheduler: warmup will be added at the beginning of this scheduler
            warmup_factor: the factor w.r.t the initial value of ``scheduler``, e.g. 0.001
            warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire
                training, e.g. 0.01
            warmup_method: one of "linear" or "constant"
            rescale_interval: whether we will rescale the interval of the scheduler after
                warmup
        """
        # the value to reach when warmup ends
        end_value = scheduler(0.0) if rescale_interval else scheduler(warmup_length)
        start_value = warmup_factor * scheduler(0.0)
        if warmup_method == "constant":
            warmup = ConstantParamScheduler(start_value)
        elif warmup_method == "linear":
            warmup = LinearParamScheduler(start_value, end_value)
        else:
            raise ValueError("Unknown warmup method: {}".format(warmup_method))
        super().__init__(
            [warmup, scheduler],
            interval_scaling=["rescaled", "rescaled" if rescale_interval else "fixed"],
            lengths=[warmup_length, 1 - warmup_length],
        )

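A quick sketch (not part of the diff) of the shape this produces, using fvcore's MultiStepParamScheduler as the wrapped schedule; the numbers are illustrative:

# Hypothetical illustration: linear warmup over the first 10% of training.
from fvcore.common.param_scheduler import MultiStepParamScheduler

base = MultiStepParamScheduler([1.0, 0.1], milestones=[800], num_updates=1000)
sched = WarmupParamScheduler(base, warmup_factor=0.001, warmup_length=0.1)
sched(0.0)   # ~0.001: warmup starts at warmup_factor * base(0.0)
sched(0.05)  # ~0.5005: halfway through the linear ramp toward base(0.1) == 1.0
sched(0.5)   # 1.0: base schedule, before the milestone at update 800
sched(0.9)   # 0.1: base schedule, after the milestone
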
class LRMultiplier(LRScheduler):
    """
    A LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the
    learning rate of each param in the optimizer.
    Every step, the learning rate of each parameter becomes its initial value
    multiplied by the output of the given :class:`ParamScheduler`.

    The absolute learning rate value of each parameter can be different.
    This scheduler can be used as long as the relative scale among them does
    not change during training.

    Examples:
    ::
        LRMultiplier(
            opt,
            WarmupParamScheduler(
                MultiStepParamScheduler(
                    [1, 0.1, 0.01],
                    milestones=[60000, 80000],
                    num_updates=90000,
                ), 0.001, 100 / 90000
            ),
            max_iter=90000
        )
    """

    # NOTES: in the most general case, every LR can use its own scheduler.
    # Supporting this requires interaction with the optimizer when its parameter
    # group is initialized. For example, classyvision implements its own optimizer
    # that allows different schedulers for every parameter group.
    # To avoid this complexity, we use this class to support the most common cases
    # where the relative scale among all LRs stays unchanged during training. In this
    # case we only need a total of one scheduler that defines the relative LR multiplier.

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        multiplier: ParamScheduler,
        max_iter: int,
        last_iter: int = -1,
    ):
        """
        Args:
            optimizer, last_iter: See ``torch.optim.lr_scheduler.LRScheduler``.
                ``last_iter`` is the same as ``last_epoch``.
            multiplier: a fvcore ParamScheduler that defines the multiplier on
                every LR of the optimizer
            max_iter: the total number of training iterations
        """
        if not isinstance(multiplier, ParamScheduler):
            raise ValueError(
                "LRMultiplier(multiplier=) must be an instance of fvcore "
                f"ParamScheduler. Got {multiplier} instead."
            )
        self._multiplier = multiplier
        self._max_iter = max_iter
        super().__init__(optimizer, last_epoch=last_iter)

    def state_dict(self):
        # fvcore schedulers are stateless. Only keep pytorch scheduler states
        return {"base_lrs": self.base_lrs, "last_epoch": self.last_epoch}

    def get_lr(self) -> List[float]:
        multiplier = self._multiplier(self.last_epoch / self._max_iter)
        return [base_lr * multiplier for base_lr in self.base_lrs]

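Unlike PyTorch's epoch-based schedulers, this multiplier is meant to be stepped once per iteration. A minimal loop sketch (not part of the diff; `optimizer`, `lr_multiplier`, `max_iter`, and `compute_loss` are placeholders):

# Hypothetical sketch: one scheduler step per training iteration.
for iteration in range(max_iter):
    loss = compute_loss()       # placeholder for the forward pass
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    lr_multiplier.step()        # recomputes lr = base_lr * multiplier(t / max_iter)
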
"""
|
| 129 |
+
Content below is no longer needed!
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
# NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes
|
| 133 |
+
# only on epoch boundaries. We typically use iteration based schedules instead.
|
| 134 |
+
# As a result, "epoch" (e.g., as in self.last_epoch) should be understood to mean
|
| 135 |
+
# "iteration" instead.
|
| 136 |
+
|
| 137 |
+
# FIXME: ideally this would be achieved with a CombinedLRScheduler, separating
|
| 138 |
+
# MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it.
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class WarmupMultiStepLR(LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        milestones: List[int],
        gamma: float = 0.1,
        warmup_factor: float = 0.001,
        warmup_iters: int = 1000,
        warmup_method: str = "linear",
        last_epoch: int = -1,
    ):
        logger.warning(
            "WarmupMultiStepLR is deprecated! Use LRMultiplier with fvcore ParamScheduler instead!"
        )
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        return [
            base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()

class WarmupCosineLR(LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        max_iters: int,
        warmup_factor: float = 0.001,
        warmup_iters: int = 1000,
        warmup_method: str = "linear",
        last_epoch: int = -1,
    ):
        logger.warning(
            "WarmupCosineLR is deprecated! Use LRMultiplier with fvcore ParamScheduler instead!"
        )
        self.max_iters = max_iters
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        # Different definitions of half-cosine with warmup are possible. For
        # simplicity we multiply the standard half-cosine schedule by the warmup
        # factor. An alternative is to start the period of the cosine at warmup_iters
        # instead of at 0. In the case that warmup_iters << max_iters the two are
        # very close to each other.
        return [
            base_lr
            * warmup_factor
            * 0.5
            * (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters))
            for base_lr in self.base_lrs
        ]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()

def _get_warmup_factor_at_iter(
    method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:
    """
    Return the learning rate warmup factor at a specific iteration.
    See :paper:`ImageNet in 1h` for more details.

    Args:
        method (str): warmup method; either "constant" or "linear".
        iter (int): iteration at which to calculate the warmup factor.
        warmup_iters (int): the number of warmup iterations.
        warmup_factor (float): the base warmup factor (the meaning changes according
            to the method used).

    Returns:
        float: the effective warmup factor at the given iteration.
    """
    if iter >= warmup_iters:
        return 1.0

    if method == "constant":
        return warmup_factor
    elif method == "linear":
        alpha = iter / warmup_iters
        return warmup_factor * (1 - alpha) + alpha
    else:
        raise ValueError("Unknown warmup method: {}".format(method))
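For concreteness (not part of the diff), the linear branch interpolates from warmup_factor to 1.0 over warmup_iters iterations:

# Worked example of the linear warmup factor.
_get_warmup_factor_at_iter("linear", 0, 1000, 0.001)     # 0.001
_get_warmup_factor_at_iter("linear", 500, 1000, 0.001)   # 0.001 * 0.5 + 0.5 = 0.5005
_get_warmup_factor_at_iter("linear", 1000, 1000, 0.001)  # 1.0 (warmup finished)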