b3h-young123 committed
Commit fb30010 · verified · 1 Parent(s): e5ae10c

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. CatVTON/densepose/data/datasets/__init__.py +7 -0
  2. CatVTON/densepose/data/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  3. CatVTON/densepose/data/datasets/__pycache__/builtin.cpython-39.pyc +0 -0
  4. CatVTON/densepose/data/datasets/__pycache__/chimpnsee.cpython-39.pyc +0 -0
  5. CatVTON/densepose/data/datasets/__pycache__/coco.cpython-39.pyc +0 -0
  6. CatVTON/densepose/data/datasets/__pycache__/dataset_type.cpython-39.pyc +0 -0
  7. CatVTON/densepose/data/datasets/__pycache__/lvis.cpython-39.pyc +0 -0
  8. CatVTON/densepose/data/datasets/builtin.py +18 -0
  9. CatVTON/densepose/data/datasets/chimpnsee.py +31 -0
  10. CatVTON/densepose/data/datasets/coco.py +434 -0
  11. CatVTON/densepose/data/datasets/dataset_type.py +13 -0
  12. CatVTON/densepose/data/datasets/lvis.py +259 -0
  13. CatVTON/densepose/data/samplers/__pycache__/__init__.cpython-39.pyc +0 -0
  14. CatVTON/densepose/data/samplers/__pycache__/densepose_base.cpython-39.pyc +0 -0
  15. CatVTON/densepose/data/samplers/__pycache__/densepose_confidence_based.cpython-39.pyc +0 -0
  16. CatVTON/densepose/data/samplers/__pycache__/densepose_cse_base.cpython-39.pyc +0 -0
  17. CatVTON/densepose/data/samplers/__pycache__/densepose_cse_confidence_based.cpython-39.pyc +0 -0
  18. CatVTON/densepose/data/samplers/__pycache__/densepose_cse_uniform.cpython-39.pyc +0 -0
  19. CatVTON/densepose/data/samplers/__pycache__/densepose_uniform.cpython-39.pyc +0 -0
  20. CatVTON/densepose/data/samplers/__pycache__/mask_from_densepose.cpython-39.pyc +0 -0
  21. CatVTON/densepose/data/samplers/__pycache__/prediction_to_gt.cpython-39.pyc +0 -0
  22. CatVTON/densepose/data/samplers/densepose_base.py +205 -0
  23. CatVTON/densepose/data/samplers/densepose_confidence_based.py +110 -0
  24. CatVTON/densepose/data/samplers/densepose_cse_uniform.py +14 -0
  25. CatVTON/densepose/data/samplers/mask_from_densepose.py +30 -0
  26. CatVTON/densepose/data/samplers/prediction_to_gt.py +100 -0
  27. CatVTON/densepose/data/transform/__init__.py +5 -0
  28. CatVTON/densepose/data/transform/__pycache__/__init__.cpython-39.pyc +0 -0
  29. CatVTON/densepose/data/transform/__pycache__/image.cpython-39.pyc +0 -0
  30. CatVTON/densepose/data/transform/image.py +41 -0
  31. CatVTON/detectron2/__init__.py +10 -0
  32. CatVTON/detectron2/checkpoint/__init__.py +10 -0
  33. CatVTON/detectron2/checkpoint/c2_model_loading.py +406 -0
  34. CatVTON/detectron2/checkpoint/catalog.py +115 -0
  35. CatVTON/detectron2/checkpoint/detection_checkpoint.py +143 -0
  36. CatVTON/detectron2/engine/__init__.py +19 -0
  37. CatVTON/detectron2/engine/defaults.py +719 -0
  38. CatVTON/detectron2/engine/hooks.py +690 -0
  39. CatVTON/detectron2/engine/launch.py +123 -0
  40. CatVTON/detectron2/engine/train_loop.py +530 -0
  41. CatVTON/detectron2/modeling/__init__.py +64 -0
  42. CatVTON/detectron2/modeling/anchor_generator.py +390 -0
  43. CatVTON/detectron2/modeling/box_regression.py +369 -0
  44. CatVTON/detectron2/modeling/matcher.py +127 -0
  45. CatVTON/detectron2/modeling/poolers.py +263 -0
  46. CatVTON/detectron2/projects/README.md +2 -0
  47. CatVTON/detectron2/projects/__init__.py +34 -0
  48. CatVTON/detectron2/solver/__init__.py +11 -0
  49. CatVTON/detectron2/solver/build.py +323 -0
  50. CatVTON/detectron2/solver/lr_scheduler.py +247 -0
CatVTON/densepose/data/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from . import builtin  # ensure the builtin datasets are registered
+
+__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
CatVTON/densepose/data/datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (384 Bytes).
CatVTON/densepose/data/datasets/__pycache__/builtin.cpython-39.pyc ADDED
Binary file (575 Bytes).
CatVTON/densepose/data/datasets/__pycache__/chimpnsee.cpython-39.pyc ADDED
Binary file (1.03 kB).
CatVTON/densepose/data/datasets/__pycache__/coco.cpython-39.pyc ADDED
Binary file (11.7 kB).
CatVTON/densepose/data/datasets/__pycache__/dataset_type.cpython-39.pyc ADDED
Binary file (499 Bytes).
CatVTON/densepose/data/datasets/__pycache__/lvis.cpython-39.pyc ADDED
Binary file (7.83 kB).
CatVTON/densepose/data/datasets/builtin.py ADDED
@@ -0,0 +1,18 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+from .chimpnsee import register_dataset as register_chimpnsee_dataset
+from .coco import BASE_DATASETS as BASE_COCO_DATASETS
+from .coco import DATASETS as COCO_DATASETS
+from .coco import register_datasets as register_coco_datasets
+from .lvis import DATASETS as LVIS_DATASETS
+from .lvis import register_datasets as register_lvis_datasets
+
+DEFAULT_DATASETS_ROOT = "datasets"
+
+
+register_coco_datasets(COCO_DATASETS, DEFAULT_DATASETS_ROOT)
+register_coco_datasets(BASE_COCO_DATASETS, DEFAULT_DATASETS_ROOT)
+register_lvis_datasets(LVIS_DATASETS, DEFAULT_DATASETS_ROOT)
+
+register_chimpnsee_dataset(DEFAULT_DATASETS_ROOT)  # pyre-ignore[19]
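
A minimal usage sketch (illustrative, not part of this diff) of how the registrations above are typically consumed downstream, assuming the densepose package in this repo is importable; all dataset names follow the diff above:

    # Importing the datasets package runs builtin.py as a side effect,
    # registering every COCO/LVIS/ChimpnSee DensePose dataset in this commit.
    from detectron2.data import DatasetCatalog, MetadataCatalog
    import densepose.data.datasets  # noqa: F401  (import for side effects)

    # Registration is lazy: annotations are only loaded when requested.
    records = DatasetCatalog.get("densepose_coco_2014_train")
    meta = MetadataCatalog.get("densepose_coco_2014_train")
    print(meta.json_file, meta.image_root)
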
CatVTON/densepose/data/datasets/chimpnsee.py ADDED
@@ -0,0 +1,31 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from typing import Optional
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+
+from ..utils import maybe_prepend_base_path
+from .dataset_type import DatasetType
+
+CHIMPNSEE_DATASET_NAME = "chimpnsee"
+
+
+def register_dataset(datasets_root: Optional[str] = None) -> None:
+    def empty_load_callback():
+        pass
+
+    video_list_fpath = maybe_prepend_base_path(
+        datasets_root,
+        "chimpnsee/cdna.eva.mpg.de/video_list.txt",
+    )
+    video_base_path = maybe_prepend_base_path(datasets_root, "chimpnsee/cdna.eva.mpg.de")
+
+    DatasetCatalog.register(CHIMPNSEE_DATASET_NAME, empty_load_callback)
+    MetadataCatalog.get(CHIMPNSEE_DATASET_NAME).set(
+        dataset_type=DatasetType.VIDEO_LIST,
+        video_list_fpath=video_list_fpath,
+        video_base_path=video_base_path,
+        category="chimpanzee",
+    )
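
Since `register_dataset` installs only an empty load callback, consumers of this dataset read everything from the metadata catalog. A small illustrative sketch (assuming registration has already run, as in builtin.py above):

    from detectron2.data import MetadataCatalog

    meta = MetadataCatalog.get("chimpnsee")
    # dataset_type is DatasetType.VIDEO_LIST; frames are drawn from the listed videos
    print(meta.video_list_fpath)  # e.g. datasets/chimpnsee/cdna.eva.mpg.de/video_list.txt
    print(meta.category)          # "chimpanzee"
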
CatVTON/densepose/data/datasets/coco.py ADDED
@@ -0,0 +1,434 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+import contextlib
+import io
+import logging
+import os
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Dict, Iterable, List, Optional
+from fvcore.common.timer import Timer
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+from detectron2.utils.file_io import PathManager
+
+from ..utils import maybe_prepend_base_path
+
+DENSEPOSE_MASK_KEY = "dp_masks"
+DENSEPOSE_IUV_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_I", "dp_U", "dp_V"]
+DENSEPOSE_CSE_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_vertex", "ref_model"]
+DENSEPOSE_ALL_POSSIBLE_KEYS = set(
+    DENSEPOSE_IUV_KEYS_WITHOUT_MASK + DENSEPOSE_CSE_KEYS_WITHOUT_MASK + [DENSEPOSE_MASK_KEY]
+)
+DENSEPOSE_METADATA_URL_PREFIX = "https://dl.fbaipublicfiles.com/densepose/data/"
+
+
+@dataclass
+class CocoDatasetInfo:
+    name: str
+    images_root: str
+    annotations_fpath: str
+
+
+DATASETS = [
+    CocoDatasetInfo(
+        name="densepose_coco_2014_train",
+        images_root="coco/train2014",
+        annotations_fpath="coco/annotations/densepose_train2014.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_minival",
+        images_root="coco/val2014",
+        annotations_fpath="coco/annotations/densepose_minival2014.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_minival_100",
+        images_root="coco/val2014",
+        annotations_fpath="coco/annotations/densepose_minival2014_100.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_valminusminival",
+        images_root="coco/val2014",
+        annotations_fpath="coco/annotations/densepose_valminusminival2014.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_train_cse",
+        images_root="coco/train2014",
+        annotations_fpath="coco_cse/densepose_train2014_cse.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_minival_cse",
+        images_root="coco/val2014",
+        annotations_fpath="coco_cse/densepose_minival2014_cse.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_minival_100_cse",
+        images_root="coco/val2014",
+        annotations_fpath="coco_cse/densepose_minival2014_100_cse.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_coco_2014_valminusminival_cse",
+        images_root="coco/val2014",
+        annotations_fpath="coco_cse/densepose_valminusminival2014_cse.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_chimps",
+        images_root="densepose_chimps/images",
+        annotations_fpath="densepose_chimps/densepose_chimps_densepose.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_chimps_cse_train",
+        images_root="densepose_chimps/images",
+        annotations_fpath="densepose_chimps/densepose_chimps_cse_train.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_chimps_cse_val",
+        images_root="densepose_chimps/images",
+        annotations_fpath="densepose_chimps/densepose_chimps_cse_val.json",
+    ),
+    CocoDatasetInfo(
+        name="posetrack2017_train",
+        images_root="posetrack2017/posetrack_data_2017",
+        annotations_fpath="posetrack2017/densepose_posetrack_train2017.json",
+    ),
+    CocoDatasetInfo(
+        name="posetrack2017_val",
+        images_root="posetrack2017/posetrack_data_2017",
+        annotations_fpath="posetrack2017/densepose_posetrack_val2017.json",
+    ),
+    CocoDatasetInfo(
+        name="lvis_v05_train",
+        images_root="coco/train2017",
+        annotations_fpath="lvis/lvis_v0.5_plus_dp_train.json",
+    ),
+    CocoDatasetInfo(
+        name="lvis_v05_val",
+        images_root="coco/val2017",
+        annotations_fpath="lvis/lvis_v0.5_plus_dp_val.json",
+    ),
+]
+
+
+BASE_DATASETS = [
+    CocoDatasetInfo(
+        name="base_coco_2017_train",
+        images_root="coco/train2017",
+        annotations_fpath="coco/annotations/instances_train2017.json",
+    ),
+    CocoDatasetInfo(
+        name="base_coco_2017_val",
+        images_root="coco/val2017",
+        annotations_fpath="coco/annotations/instances_val2017.json",
+    ),
+    CocoDatasetInfo(
+        name="base_coco_2017_val_100",
+        images_root="coco/val2017",
+        annotations_fpath="coco/annotations/instances_val2017_100.json",
+    ),
+]
+
+
+def get_metadata(base_path: Optional[str]) -> Dict[str, Any]:
+    """
+    Returns metadata associated with COCO DensePose datasets
+
+    Args:
+        base_path: Optional[str]
+            Base path used to load metadata from
+
+    Returns:
+        Dict[str, Any]
+            Metadata in the form of a dictionary
+    """
+    meta = {
+        "densepose_transform_src": maybe_prepend_base_path(base_path, "UV_symmetry_transforms.mat"),
+        "densepose_smpl_subdiv": maybe_prepend_base_path(base_path, "SMPL_subdiv.mat"),
+        "densepose_smpl_subdiv_transform": maybe_prepend_base_path(
+            base_path,
+            "SMPL_SUBDIV_TRANSFORM.mat",
+        ),
+    }
+    return meta
+
+
+def _load_coco_annotations(json_file: str):
+    """
+    Load COCO annotations from a JSON file
+
+    Args:
+        json_file: str
+            Path to the file to load annotations from
+    Returns:
+        Instance of `pycocotools.coco.COCO` that provides access to annotations
+        data
+    """
+    from pycocotools.coco import COCO
+
+    logger = logging.getLogger(__name__)
+    timer = Timer()
+    with contextlib.redirect_stdout(io.StringIO()):
+        coco_api = COCO(json_file)
+    if timer.seconds() > 1:
+        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
+    return coco_api
+
+
+def _add_categories_metadata(dataset_name: str, categories: List[Dict[str, Any]]):
+    meta = MetadataCatalog.get(dataset_name)
+    meta.categories = {c["id"]: c["name"] for c in categories}
+    logger = logging.getLogger(__name__)
+    logger.info("Dataset {} categories: {}".format(dataset_name, meta.categories))
+
+
+def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]):
+    if "minival" in json_file:
+        # Skip validation on COCO2014 valminusminival and minival annotations
+        # The ratio of buggy annotations there is tiny and does not affect accuracy
+        # Therefore we explicitly white-list them
+        return
+    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
+    assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
+        json_file
+    )
+
+
+def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
+    if "bbox" not in ann_dict:
+        return
+    obj["bbox"] = ann_dict["bbox"]
+    obj["bbox_mode"] = BoxMode.XYWH_ABS
+
+
+def _maybe_add_segm(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
+    if "segmentation" not in ann_dict:
+        return
+    segm = ann_dict["segmentation"]
+    if not isinstance(segm, dict):
+        # filter out invalid polygons (< 3 points)
+        segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
+        if len(segm) == 0:
+            return
+    obj["segmentation"] = segm
+
+
+def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
+    if "keypoints" not in ann_dict:
+        return
+    keypts = ann_dict["keypoints"]  # list[int]
+    for idx, v in enumerate(keypts):
+        if idx % 3 != 2:
+            # COCO's segmentation coordinates are floating points in [0, H or W],
+            # but keypoint coordinates are integers in [0, H-1 or W-1]
+            # Therefore we assume the coordinates are "pixel indices" and
+            # add 0.5 to convert to floating point coordinates.
+            keypts[idx] = v + 0.5
+    obj["keypoints"] = keypts
+
+
+def _maybe_add_densepose(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
+    for key in DENSEPOSE_ALL_POSSIBLE_KEYS:
+        if key in ann_dict:
+            obj[key] = ann_dict[key]
+
+
+def _combine_images_with_annotations(
+    dataset_name: str,
+    image_root: str,
+    img_datas: Iterable[Dict[str, Any]],
+    ann_datas: Iterable[Iterable[Dict[str, Any]]],
+):
+
+    ann_keys = ["iscrowd", "category_id"]
+    dataset_dicts = []
+    contains_video_frame_info = False
+
+    for img_dict, ann_dicts in zip(img_datas, ann_datas):
+        record = {}
+        record["file_name"] = os.path.join(image_root, img_dict["file_name"])
+        record["height"] = img_dict["height"]
+        record["width"] = img_dict["width"]
+        record["image_id"] = img_dict["id"]
+        record["dataset"] = dataset_name
+        if "frame_id" in img_dict:
+            record["frame_id"] = img_dict["frame_id"]
+            record["video_id"] = img_dict.get("vid_id", None)
+            contains_video_frame_info = True
+        objs = []
+        for ann_dict in ann_dicts:
+            assert ann_dict["image_id"] == record["image_id"]
+            assert ann_dict.get("ignore", 0) == 0
+            obj = {key: ann_dict[key] for key in ann_keys if key in ann_dict}
+            _maybe_add_bbox(obj, ann_dict)
+            _maybe_add_segm(obj, ann_dict)
+            _maybe_add_keypoints(obj, ann_dict)
+            _maybe_add_densepose(obj, ann_dict)
+            objs.append(obj)
+        record["annotations"] = objs
+        dataset_dicts.append(record)
+    if contains_video_frame_info:
+        create_video_frame_mapping(dataset_name, dataset_dicts)
+    return dataset_dicts
+
+
+def get_contiguous_id_to_category_id_map(metadata):
+    cat_id_2_cont_id = metadata.thing_dataset_id_to_contiguous_id
+    cont_id_2_cat_id = {}
+    for cat_id, cont_id in cat_id_2_cont_id.items():
+        if cont_id in cont_id_2_cat_id:
+            continue
+        cont_id_2_cat_id[cont_id] = cat_id
+    return cont_id_2_cat_id
+
+
+def maybe_filter_categories_cocoapi(dataset_name, coco_api):
+    meta = MetadataCatalog.get(dataset_name)
+    cont_id_2_cat_id = get_contiguous_id_to_category_id_map(meta)
+    cat_id_2_cont_id = meta.thing_dataset_id_to_contiguous_id
+    # filter categories
+    cats = []
+    for cat in coco_api.dataset["categories"]:
+        cat_id = cat["id"]
+        if cat_id not in cat_id_2_cont_id:
+            continue
+        cont_id = cat_id_2_cont_id[cat_id]
+        if (cont_id in cont_id_2_cat_id) and (cont_id_2_cat_id[cont_id] == cat_id):
+            cats.append(cat)
+    coco_api.dataset["categories"] = cats
+    # filter annotations, if multiple categories are mapped to a single
+    # contiguous ID, use only one category ID and map all annotations to that category ID
+    anns = []
+    for ann in coco_api.dataset["annotations"]:
+        cat_id = ann["category_id"]
+        if cat_id not in cat_id_2_cont_id:
+            continue
+        cont_id = cat_id_2_cont_id[cat_id]
+        ann["category_id"] = cont_id_2_cat_id[cont_id]
+        anns.append(ann)
+    coco_api.dataset["annotations"] = anns
+    # recreate index
+    coco_api.createIndex()
+
+
+def maybe_filter_and_map_categories_cocoapi(dataset_name, coco_api):
+    meta = MetadataCatalog.get(dataset_name)
+    category_id_map = meta.thing_dataset_id_to_contiguous_id
+    # map categories
+    cats = []
+    for cat in coco_api.dataset["categories"]:
+        cat_id = cat["id"]
+        if cat_id not in category_id_map:
+            continue
+        cat["id"] = category_id_map[cat_id]
+        cats.append(cat)
+    coco_api.dataset["categories"] = cats
+    # map annotation categories
+    anns = []
+    for ann in coco_api.dataset["annotations"]:
+        cat_id = ann["category_id"]
+        if cat_id not in category_id_map:
+            continue
+        ann["category_id"] = category_id_map[cat_id]
+        anns.append(ann)
+    coco_api.dataset["annotations"] = anns
+    # recreate index
+    coco_api.createIndex()
+
+
+def create_video_frame_mapping(dataset_name, dataset_dicts):
+    mapping = defaultdict(dict)
+    for d in dataset_dicts:
+        video_id = d.get("video_id")
+        if video_id is None:
+            continue
+        mapping[video_id].update({d["frame_id"]: d["file_name"]})
+    MetadataCatalog.get(dataset_name).set(video_frame_mapping=mapping)
+
+
+def load_coco_json(annotations_json_file: str, image_root: str, dataset_name: str):
+    """
+    Loads a JSON file with annotations in COCO instances format.
+    Replaces `detectron2.data.datasets.coco.load_coco_json` to handle metadata
+    in a more flexible way. Postpones category mapping to a later stage to be
+    able to combine several datasets with different (but coherent) sets of
+    categories.
+
+    Args:
+
+    annotations_json_file: str
+        Path to the JSON file with annotations in COCO instances format.
+    image_root: str
+        directory that contains all the images
+    dataset_name: str
+        the name that identifies a dataset, e.g. "densepose_coco_2014_train"
+    extra_annotation_keys: Optional[List[str]]
+        If provided, these keys are used to extract additional data from
+        the annotations.
+    """
+    coco_api = _load_coco_annotations(PathManager.get_local_path(annotations_json_file))
+    _add_categories_metadata(dataset_name, coco_api.loadCats(coco_api.getCatIds()))
+    # sort indices for reproducible results
+    img_ids = sorted(coco_api.imgs.keys())
+    # imgs is a list of dicts, each looks something like:
+    # {'license': 4,
+    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
+    #  'file_name': 'COCO_val2014_000000001268.jpg',
+    #  'height': 427,
+    #  'width': 640,
+    #  'date_captured': '2013-11-17 05:57:24',
+    #  'id': 1268}
+    imgs = coco_api.loadImgs(img_ids)
+    logger = logging.getLogger(__name__)
+    logger.info("Loaded {} images in COCO format from {}".format(len(imgs), annotations_json_file))
+    # anns is a list[list[dict]], where each dict is an annotation
+    # record for an object. The inner list enumerates the objects in an image
+    # and the outer list enumerates over images.
+    anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
+    _verify_annotations_have_unique_ids(annotations_json_file, anns)
+    dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
+    return dataset_records
+
+
+def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None):
+    """
+    Registers provided COCO DensePose dataset
+
+    Args:
+        dataset_data: CocoDatasetInfo
+            Dataset data
+        datasets_root: Optional[str]
+            Datasets root folder (default: None)
+    """
+    annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath)
+    images_root = maybe_prepend_base_path(datasets_root, dataset_data.images_root)
+
+    def load_annotations():
+        return load_coco_json(
+            annotations_json_file=annotations_fpath,
+            image_root=images_root,
+            dataset_name=dataset_data.name,
+        )
+
+    DatasetCatalog.register(dataset_data.name, load_annotations)
+    MetadataCatalog.get(dataset_data.name).set(
+        json_file=annotations_fpath,
+        image_root=images_root,
+        **get_metadata(DENSEPOSE_METADATA_URL_PREFIX)
+    )
+
+
+def register_datasets(
+    datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None
+):
+    """
+    Registers provided COCO DensePose datasets
+
+    Args:
+        datasets_data: Iterable[CocoDatasetInfo]
+            An iterable of dataset datas
+        datasets_root: Optional[str]
+            Datasets root folder (default: None)
+    """
+    for dataset_data in datasets_data:
+        register_dataset(dataset_data, datasets_root)
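
A short sketch of registering an additional COCO-style DensePose dataset with the helpers above; the dataset name and paths here are hypothetical:

    from densepose.data.datasets.coco import CocoDatasetInfo, register_dataset

    my_dataset = CocoDatasetInfo(
        name="my_densepose_val",                       # hypothetical name
        images_root="my_data/images",                  # hypothetical image folder
        annotations_fpath="my_data/annotations.json",  # hypothetical annotation file
    )
    register_dataset(my_dataset, datasets_root="datasets")
    # DatasetCatalog.get("my_densepose_val") would now invoke load_coco_json lazily.
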
CatVTON/densepose/data/datasets/dataset_type.py ADDED
@@ -0,0 +1,13 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from enum import Enum
+
+
+class DatasetType(Enum):
+    """
+    Dataset type, mostly used for datasets that contain data to bootstrap models on
+    """
+
+    VIDEO_LIST = "video_list"
CatVTON/densepose/data/datasets/lvis.py ADDED
@@ -0,0 +1,259 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+import logging
+import os
+from typing import Any, Dict, Iterable, List, Optional
+from fvcore.common.timer import Timer
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets.lvis import get_lvis_instances_meta
+from detectron2.structures import BoxMode
+from detectron2.utils.file_io import PathManager
+
+from ..utils import maybe_prepend_base_path
+from .coco import (
+    DENSEPOSE_ALL_POSSIBLE_KEYS,
+    DENSEPOSE_METADATA_URL_PREFIX,
+    CocoDatasetInfo,
+    get_metadata,
+)
+
+DATASETS = [
+    CocoDatasetInfo(
+        name="densepose_lvis_v1_ds1_train_v1",
+        images_root="coco_",
+        annotations_fpath="lvis/densepose_lvis_v1_ds1_train_v1.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_lvis_v1_ds1_val_v1",
+        images_root="coco_",
+        annotations_fpath="lvis/densepose_lvis_v1_ds1_val_v1.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_lvis_v1_ds2_train_v1",
+        images_root="coco_",
+        annotations_fpath="lvis/densepose_lvis_v1_ds2_train_v1.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_lvis_v1_ds2_val_v1",
+        images_root="coco_",
+        annotations_fpath="lvis/densepose_lvis_v1_ds2_val_v1.json",
+    ),
+    CocoDatasetInfo(
+        name="densepose_lvis_v1_ds1_val_animals_100",
+        images_root="coco_",
+        annotations_fpath="lvis/densepose_lvis_v1_val_animals_100_v2.json",
+    ),
+]
+
+
+def _load_lvis_annotations(json_file: str):
+    """
+    Load LVIS annotations from a JSON file
+
+    Args:
+        json_file: str
+            Path to the file to load annotations from
+    Returns:
+        Instance of `lvis.LVIS` that provides access to annotations
+        data
+    """
+    from lvis import LVIS
+
+    json_file = PathManager.get_local_path(json_file)
+    logger = logging.getLogger(__name__)
+    timer = Timer()
+    lvis_api = LVIS(json_file)
+    if timer.seconds() > 1:
+        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
+    return lvis_api
+
+
+def _add_categories_metadata(dataset_name: str) -> None:
+    metadict = get_lvis_instances_meta(dataset_name)
+    categories = metadict["thing_classes"]
+    metadata = MetadataCatalog.get(dataset_name)
+    metadata.categories = {i + 1: categories[i] for i in range(len(categories))}
+    logger = logging.getLogger(__name__)
+    logger.info(f"Dataset {dataset_name} has {len(categories)} categories")
+
+
+def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]) -> None:
+    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
+    assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
+        json_file
+    )
+
+
+def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
+    if "bbox" not in ann_dict:
+        return
+    obj["bbox"] = ann_dict["bbox"]
+    obj["bbox_mode"] = BoxMode.XYWH_ABS
+
+
+def _maybe_add_segm(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
+    if "segmentation" not in ann_dict:
+        return
+    segm = ann_dict["segmentation"]
+    if not isinstance(segm, dict):
+        # filter out invalid polygons (< 3 points)
+        segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
+        if len(segm) == 0:
+            return
+    obj["segmentation"] = segm
+
+
+def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
+    if "keypoints" not in ann_dict:
+        return
+    keypts = ann_dict["keypoints"]  # list[int]
+    for idx, v in enumerate(keypts):
+        if idx % 3 != 2:
+            # COCO's segmentation coordinates are floating points in [0, H or W],
+            # but keypoint coordinates are integers in [0, H-1 or W-1]
+            # Therefore we assume the coordinates are "pixel indices" and
+            # add 0.5 to convert to floating point coordinates.
+            keypts[idx] = v + 0.5
+    obj["keypoints"] = keypts
+
+
+def _maybe_add_densepose(obj: Dict[str, Any], ann_dict: Dict[str, Any]) -> None:
+    for key in DENSEPOSE_ALL_POSSIBLE_KEYS:
+        if key in ann_dict:
+            obj[key] = ann_dict[key]
+
+
+def _combine_images_with_annotations(
+    dataset_name: str,
+    image_root: str,
+    img_datas: Iterable[Dict[str, Any]],
+    ann_datas: Iterable[Iterable[Dict[str, Any]]],
+):
+
+    dataset_dicts = []
+
+    def get_file_name(img_root, img_dict):
+        # Determine the path including the split folder ("train2017", "val2017", "test2017")
+        # from the coco_url field. Example:
+        #   'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg'
+        split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
+        return os.path.join(img_root + split_folder, file_name)
+
+    for img_dict, ann_dicts in zip(img_datas, ann_datas):
+        record = {}
+        record["file_name"] = get_file_name(image_root, img_dict)
+        record["height"] = img_dict["height"]
+        record["width"] = img_dict["width"]
+        record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", [])
+        record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
+        record["image_id"] = img_dict["id"]
+        record["dataset"] = dataset_name
+
+        objs = []
+        for ann_dict in ann_dicts:
+            assert ann_dict["image_id"] == record["image_id"]
+            obj = {}
+            _maybe_add_bbox(obj, ann_dict)
+            obj["iscrowd"] = ann_dict.get("iscrowd", 0)
+            obj["category_id"] = ann_dict["category_id"]
+            _maybe_add_segm(obj, ann_dict)
+            _maybe_add_keypoints(obj, ann_dict)
+            _maybe_add_densepose(obj, ann_dict)
+            objs.append(obj)
+        record["annotations"] = objs
+        dataset_dicts.append(record)
+    return dataset_dicts
+
+
+def load_lvis_json(annotations_json_file: str, image_root: str, dataset_name: str):
+    """
+    Loads a JSON file with annotations in LVIS instances format.
+    Replaces `detectron2.data.datasets.lvis.load_lvis_json` to handle metadata
+    in a more flexible way. Postpones category mapping to a later stage to be
+    able to combine several datasets with different (but coherent) sets of
+    categories.
+
+    Args:
+
+    annotations_json_file: str
+        Path to the JSON file with annotations in LVIS instances format.
+    image_root: str
+        directory that contains all the images
+    dataset_name: str
+        the name that identifies a dataset, e.g. "densepose_lvis_v1_ds1_train_v1"
+    extra_annotation_keys: Optional[List[str]]
+        If provided, these keys are used to extract additional data from
+        the annotations.
+    """
+    lvis_api = _load_lvis_annotations(PathManager.get_local_path(annotations_json_file))
+
+    _add_categories_metadata(dataset_name)
+
+    # sort indices for reproducible results
+    img_ids = sorted(lvis_api.imgs.keys())
+    # imgs is a list of dicts, each looks something like:
+    # {'license': 4,
+    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
+    #  'file_name': 'COCO_val2014_000000001268.jpg',
+    #  'height': 427,
+    #  'width': 640,
+    #  'date_captured': '2013-11-17 05:57:24',
+    #  'id': 1268}
+    imgs = lvis_api.load_imgs(img_ids)
+    logger = logging.getLogger(__name__)
+    logger.info("Loaded {} images in LVIS format from {}".format(len(imgs), annotations_json_file))
+    # anns is a list[list[dict]], where each dict is an annotation
+    # record for an object. The inner list enumerates the objects in an image
+    # and the outer list enumerates over images.
+    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
+
+    _verify_annotations_have_unique_ids(annotations_json_file, anns)
+    dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
+    return dataset_records
+
+
+def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None) -> None:
+    """
+    Registers provided LVIS DensePose dataset
+
+    Args:
+        dataset_data: CocoDatasetInfo
+            Dataset data
+        datasets_root: Optional[str]
+            Datasets root folder (default: None)
+    """
+    annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath)
+    images_root = maybe_prepend_base_path(datasets_root, dataset_data.images_root)
+
+    def load_annotations():
+        return load_lvis_json(
+            annotations_json_file=annotations_fpath,
+            image_root=images_root,
+            dataset_name=dataset_data.name,
+        )
+
+    DatasetCatalog.register(dataset_data.name, load_annotations)
+    MetadataCatalog.get(dataset_data.name).set(
+        json_file=annotations_fpath,
+        image_root=images_root,
+        evaluator_type="lvis",
+        **get_metadata(DENSEPOSE_METADATA_URL_PREFIX),
+    )
+
+
+def register_datasets(
+    datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None
+) -> None:
+    """
+    Registers provided LVIS DensePose datasets
+
+    Args:
+        datasets_data: Iterable[CocoDatasetInfo]
+            An iterable of dataset datas
+        datasets_root: Optional[str]
+            Datasets root folder (default: None)
+    """
+    for dataset_data in datasets_data:
+        register_dataset(dataset_data, datasets_root)
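
The `images_root="coco_"` prefix in the DATASETS entries is explained by `get_file_name` above: the split folder is recovered from each image's `coco_url` and appended to that prefix. A worked illustration with a made-up image record:

    import os

    img_dict = {"coco_url": "http://images.cocodataset.org/train2017/000000155379.jpg"}
    split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
    print(os.path.join("coco_" + split_folder, file_name))
    # -> coco_train2017/000000155379.jpg
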
CatVTON/densepose/data/samplers/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (615 Bytes).
CatVTON/densepose/data/samplers/__pycache__/densepose_base.cpython-39.pyc ADDED
Binary file (6.17 kB).
CatVTON/densepose/data/samplers/__pycache__/densepose_confidence_based.cpython-39.pyc ADDED
Binary file (4.55 kB).
CatVTON/densepose/data/samplers/__pycache__/densepose_cse_base.cpython-39.pyc ADDED
Binary file (5.01 kB).
CatVTON/densepose/data/samplers/__pycache__/densepose_cse_confidence_based.cpython-39.pyc ADDED
Binary file (4.92 kB).
CatVTON/densepose/data/samplers/__pycache__/densepose_cse_uniform.cpython-39.pyc ADDED
Binary file (539 Bytes).
CatVTON/densepose/data/samplers/__pycache__/densepose_uniform.cpython-39.pyc ADDED
Binary file (1.8 kB).
CatVTON/densepose/data/samplers/__pycache__/mask_from_densepose.cpython-39.pyc ADDED
Binary file (1.3 kB).
CatVTON/densepose/data/samplers/__pycache__/prediction_to_gt.cpython-39.pyc ADDED
Binary file (3.42 kB).
CatVTON/densepose/data/samplers/densepose_base.py ADDED
@@ -0,0 +1,205 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from typing import Any, Dict, List, Tuple
+import torch
+from torch.nn import functional as F
+
+from detectron2.structures import BoxMode, Instances
+
+from densepose.converters import ToChartResultConverter
+from densepose.converters.base import IntTupleBox, make_int_box
+from densepose.structures import DensePoseDataRelative, DensePoseList
+
+
+class DensePoseBaseSampler:
+    """
+    Base DensePose sampler to produce DensePose data from DensePose predictions.
+    Samples for each class are drawn according to some distribution over all pixels estimated
+    to belong to that class.
+    """
+
+    def __init__(self, count_per_class: int = 8):
+        """
+        Constructor
+
+        Args:
+            count_per_class (int): the sampler produces at most `count_per_class`
+                samples for each category
+        """
+        self.count_per_class = count_per_class
+
+    def __call__(self, instances: Instances) -> DensePoseList:
+        """
+        Convert DensePose predictions (an instance of `DensePoseChartPredictorOutput`)
+        into DensePose annotations data (an instance of `DensePoseList`)
+        """
+        boxes_xyxy_abs = instances.pred_boxes.tensor.clone().cpu()
+        boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+        dp_datas = []
+        for i in range(len(boxes_xywh_abs)):
+            annotation_i = self._sample(instances[i], make_int_box(boxes_xywh_abs[i]))
+            annotation_i[DensePoseDataRelative.S_KEY] = self._resample_mask(  # pyre-ignore[6]
+                instances[i].pred_densepose
+            )
+            dp_datas.append(DensePoseDataRelative(annotation_i))
+        # create densepose annotations on CPU
+        dp_list = DensePoseList(dp_datas, boxes_xyxy_abs, instances.image_size)
+        return dp_list
+
+    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
+        """
+        Sample DensePoseDataRelative from estimation results
+        """
+        labels, dp_result = self._produce_labels_and_results(instance)
+        annotation = {
+            DensePoseDataRelative.X_KEY: [],
+            DensePoseDataRelative.Y_KEY: [],
+            DensePoseDataRelative.U_KEY: [],
+            DensePoseDataRelative.V_KEY: [],
+            DensePoseDataRelative.I_KEY: [],
+        }
+        n, h, w = dp_result.shape
+        for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
+            # indices - tuple of 3 1D tensors of size k
+            # 0: index along the first dimension N
+            # 1: index along H dimension
+            # 2: index along W dimension
+            indices = torch.nonzero(labels.expand(n, h, w) == part_id, as_tuple=True)
+            # values - an array of size [n, k]
+            # n: number of channels (U, V, confidences)
+            # k: number of points labeled with part_id
+            values = dp_result[indices].view(n, -1)
+            k = values.shape[1]
+            count = min(self.count_per_class, k)
+            if count <= 0:
+                continue
+            index_sample = self._produce_index_sample(values, count)
+            sampled_values = values[:, index_sample]
+            sampled_y = indices[1][index_sample] + 0.5
+            sampled_x = indices[2][index_sample] + 0.5
+            # prepare / normalize data
+            x = (sampled_x / w * 256.0).cpu().tolist()
+            y = (sampled_y / h * 256.0).cpu().tolist()
+            u = sampled_values[0].clamp(0, 1).cpu().tolist()
+            v = sampled_values[1].clamp(0, 1).cpu().tolist()
+            fine_segm_labels = [part_id] * count
+            # extend annotations
+            annotation[DensePoseDataRelative.X_KEY].extend(x)
+            annotation[DensePoseDataRelative.Y_KEY].extend(y)
+            annotation[DensePoseDataRelative.U_KEY].extend(u)
+            annotation[DensePoseDataRelative.V_KEY].extend(v)
+            annotation[DensePoseDataRelative.I_KEY].extend(fine_segm_labels)
+        return annotation
+
+    def _produce_index_sample(self, values: torch.Tensor, count: int):
+        """
+        Abstract method to produce a sample of indices to select data
+        To be implemented in descendants
+
+        Args:
+            values (torch.Tensor): an array of size [n, k] that contains
+                estimated values (U, V, confidences);
+                n: number of channels (U, V, confidences)
+                k: number of points labeled with part_id
+            count (int): number of samples to produce, should be positive and <= k
+
+        Return:
+            list(int): indices of values (along axis 1) selected as a sample
+        """
+        raise NotImplementedError
+
+    def _produce_labels_and_results(self, instance: Instances) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Method to get labels and DensePose results from an instance
+
+        Args:
+            instance (Instances): an instance of `DensePoseChartPredictorOutput`
+
+        Return:
+            labels (torch.Tensor): shape [H, W], DensePose segmentation labels
+            dp_result (torch.Tensor): shape [2, H, W], stacked DensePose results u and v
+        """
+        converter = ToChartResultConverter
+        chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
+        labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
+        return labels, dp_result
+
+    def _resample_mask(self, output: Any) -> torch.Tensor:
+        """
+        Convert DensePose predictor output to segmentation annotation - tensors of size
+        (256, 256) and type `int64`.
+
+        Args:
+            output: DensePose predictor output with the following attributes:
+             - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
+               segmentation scores
+             - fine_segm: tensor of size [N, C, H, W] with unnormalized fine
+               segmentation scores
+        Return:
+            Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
+            where S = DensePoseDataRelative.MASK_SIZE
+        """
+        sz = DensePoseDataRelative.MASK_SIZE
+        S = (
+            F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
+            .argmax(dim=1)
+            .long()
+        )
+        I = (
+            (
+                F.interpolate(
+                    output.fine_segm,
+                    (sz, sz),
+                    mode="bilinear",
+                    align_corners=False,
+                ).argmax(dim=1)
+                * (S > 0).long()
+            )
+            .squeeze()
+            .cpu()
+        )
+        # Map fine segmentation results to coarse segmentation ground truth
+        # TODO: extract this into separate classes
+        # coarse segmentation: 1 = Torso, 2 = Right Hand, 3 = Left Hand,
+        # 4 = Left Foot, 5 = Right Foot, 6 = Upper Leg Right, 7 = Upper Leg Left,
+        # 8 = Lower Leg Right, 9 = Lower Leg Left, 10 = Upper Arm Left,
+        # 11 = Upper Arm Right, 12 = Lower Arm Left, 13 = Lower Arm Right,
+        # 14 = Head
+        # fine segmentation: 1, 2 = Torso, 3 = Right Hand, 4 = Left Hand,
+        # 5 = Left Foot, 6 = Right Foot, 7, 9 = Upper Leg Right,
+        # 8, 10 = Upper Leg Left, 11, 13 = Lower Leg Right,
+        # 12, 14 = Lower Leg Left, 15, 17 = Upper Arm Left,
+        # 16, 18 = Upper Arm Right, 19, 21 = Lower Arm Left,
+        # 20, 22 = Lower Arm Right, 23, 24 = Head
+        FINE_TO_COARSE_SEGMENTATION = {
+            1: 1,
+            2: 1,
+            3: 2,
+            4: 3,
+            5: 4,
+            6: 5,
+            7: 6,
+            8: 7,
+            9: 6,
+            10: 7,
+            11: 8,
+            12: 9,
+            13: 8,
+            14: 9,
+            15: 10,
+            16: 11,
+            17: 10,
+            18: 11,
+            19: 12,
+            20: 13,
+            21: 12,
+            22: 13,
+            23: 14,
+            24: 14,
+        }
+        mask = torch.zeros((sz, sz), dtype=torch.int64, device=torch.device("cpu"))
+        for i in range(DensePoseDataRelative.N_PART_LABELS):
+            mask[I == i + 1] = FINE_TO_COARSE_SEGMENTATION[i + 1]
+        return mask
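
`_produce_index_sample` is abstract; the commit also ships compiled concrete samplers (note the `densepose_uniform` pycache entry below). A minimal sketch, assuming uniform selection without replacement is acceptable, of what such a subclass might look like:

    import random
    import torch

    class UniformIndexSampler(DensePoseBaseSampler):  # hypothetical subclass name
        def _produce_index_sample(self, values: torch.Tensor, count: int):
            # values: [n, k]; the caller guarantees 0 < count <= k
            k = values.shape[1]
            return random.sample(range(k), count)
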
CatVTON/densepose/data/samplers/densepose_confidence_based.py ADDED
@@ -0,0 +1,110 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+import random
+from typing import Optional, Tuple
+import torch
+
+from densepose.converters import ToChartResultConverterWithConfidences
+
+from .densepose_base import DensePoseBaseSampler
+
+
+class DensePoseConfidenceBasedSampler(DensePoseBaseSampler):
+    """
+    Samples DensePose data from DensePose predictions.
+    Samples for each class are drawn using confidence value estimates.
+    """
+
+    def __init__(
+        self,
+        confidence_channel: str,
+        count_per_class: int = 8,
+        search_count_multiplier: Optional[float] = None,
+        search_proportion: Optional[float] = None,
+    ):
+        """
+        Constructor
+
+        Args:
+            confidence_channel (str): confidence channel to use for sampling;
+                possible values:
+                "sigma_2": confidences for UV values
+                "fine_segm_confidence": confidences for fine segmentation
+                "coarse_segm_confidence": confidences for coarse segmentation
+                (default: "sigma_2")
+            count_per_class (int): the sampler produces at most `count_per_class`
+                samples for each category (default: 8)
+            search_count_multiplier (float or None): if not None, the total number
+                of the most confident estimates of a given class to consider is
+                defined as `min(search_count_multiplier * count_per_class, N)`,
+                where `N` is the total number of estimates of the class; cannot be
+                specified together with `search_proportion` (default: None)
+            search_proportion (float or None): if not None, the total number of
+                the most confident estimates of a given class to consider is
+                defined as `min(max(search_proportion * N, count_per_class), N)`,
+                where `N` is the total number of estimates of the class; cannot be
+                specified together with `search_count_multiplier` (default: None)
+        """
+        super().__init__(count_per_class)
+        self.confidence_channel = confidence_channel
+        self.search_count_multiplier = search_count_multiplier
+        self.search_proportion = search_proportion
+        assert (search_count_multiplier is None) or (search_proportion is None), (
+            f"Cannot specify both search_count_multiplier (={search_count_multiplier}) "
+            f"and search_proportion (={search_proportion})"
+        )
+
+    def _produce_index_sample(self, values: torch.Tensor, count: int):
+        """
+        Produce a sample of indices to select data based on confidences
+
+        Args:
+            values (torch.Tensor): an array of size [n, k] that contains
+                estimated values (U, V, confidences);
+                n: number of channels (U, V, confidences)
+                k: number of points labeled with part_id
+            count (int): number of samples to produce, should be positive and <= k
+
+        Return:
+            list(int): indices of values (along axis 1) selected as a sample
+        """
+        k = values.shape[1]
+        if k == count:
+            index_sample = list(range(k))
+        else:
+            # take the best count * search_count_multiplier pixels,
+            # sample from them uniformly
+            # (here best = smallest variance)
+            _, sorted_confidence_indices = torch.sort(values[2])
+            if self.search_count_multiplier is not None:
+                search_count = min(int(count * self.search_count_multiplier), k)
+            elif self.search_proportion is not None:
+                search_count = min(max(int(k * self.search_proportion), count), k)
+            else:
+                search_count = min(count, k)
+            sample_from_top = random.sample(range(search_count), count)
+            index_sample = sorted_confidence_indices[:search_count][sample_from_top]
+        return index_sample
+
+    def _produce_labels_and_results(self, instance) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Method to get labels and DensePose results from an instance, with confidences
+
+        Args:
+            instance (Instances): an instance of `DensePoseChartPredictorOutputWithConfidences`
+
+        Return:
+            labels (torch.Tensor): shape [H, W], DensePose segmentation labels
+            dp_result (torch.Tensor): shape [3, H, W], DensePose results u and v
+                stacked with the confidence channel
+        """
+        converter = ToChartResultConverterWithConfidences
+        chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
+        labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
+        dp_result = torch.cat(
+            (dp_result, getattr(chart_result, self.confidence_channel)[None].cpu())
+        )
+
+        return labels, dp_result
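
A worked illustration of the search window computed in `_produce_index_sample`, with made-up numbers (k = 1000 candidate pixels, count = 8 samples per class):

    k, count = 1000, 8
    print(min(int(count * 4.0), k))          # multiplier mode (x4): window of 32
    print(min(max(int(k * 0.1), count), k))  # proportion mode (10%): window of 100
    # The 8 samples are then drawn uniformly from that many lowest-variance points.
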
CatVTON/densepose/data/samplers/densepose_cse_uniform.py ADDED
@@ -0,0 +1,14 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from .densepose_cse_base import DensePoseCSEBaseSampler
+from .densepose_uniform import DensePoseUniformSampler
+
+
+class DensePoseCSEUniformSampler(DensePoseCSEBaseSampler, DensePoseUniformSampler):
+    """
+    Uniform Sampler for CSE
+    """
+
+    pass
CatVTON/densepose/data/samplers/mask_from_densepose.py ADDED
@@ -0,0 +1,30 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from detectron2.structures import BitMasks, Instances
+
+from densepose.converters import ToMaskConverter
+
+
+class MaskFromDensePoseSampler:
+    """
+    Produce mask GT from DensePose predictions
+    This sampler simply converts DensePose predictions to BitMasks
+    that contain a bool tensor of the size of the input image
+    """
+
+    def __call__(self, instances: Instances) -> BitMasks:
+        """
+        Converts predicted data from `instances` into the GT mask data
+
+        Args:
+            instances (Instances): predicted results, expected to have `pred_densepose` field
+
+        Returns:
+            Boolean Tensor of the size of the input image that has non-zero
+            values at pixels that are estimated to belong to the detected object
+        """
+        return ToMaskConverter.convert(
+            instances.pred_densepose, instances.pred_boxes, instances.image_size
+        )
CatVTON/densepose/data/samplers/prediction_to_gt.py ADDED
@@ -0,0 +1,100 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional
+
+from detectron2.structures import Instances
+
+ModelOutput = Dict[str, Any]
+SampledData = Dict[str, Any]
+
+
+@dataclass
+class _Sampler:
+    """
+    Sampler registry entry that contains:
+     - src (str): source field to sample from (deleted after sampling)
+     - dst (Optional[str]): destination field to sample to, if not None
+     - func (Optional[Callable: Any -> Any]): function that performs sampling,
+       if None, reference copy is performed
+    """
+
+    src: str
+    dst: Optional[str]
+    func: Optional[Callable[[Any], Any]]
+
+
+class PredictionToGroundTruthSampler:
+    """
+    Sampler implementation that converts predictions to GT using registered
+    samplers for different fields of `Instances`.
+    """
+
+    def __init__(self, dataset_name: str = ""):
+        self.dataset_name = dataset_name
+        self._samplers = {}
+        self.register_sampler("pred_boxes", "gt_boxes", None)
+        self.register_sampler("pred_classes", "gt_classes", None)
+        # delete scores
+        self.register_sampler("scores")
+
+    def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
+        """
+        Transform model output into ground truth data through sampling
+
+        Args:
+            model_output (List[ModelOutput]): model output
+        Returns:
+            List[SampledData]: sampled data
+        """
+        for model_output_i in model_output:
+            instances: Instances = model_output_i["instances"]
+            # transform data in each field
+            for _, sampler in self._samplers.items():
+                if not instances.has(sampler.src) or sampler.dst is None:
+                    continue
+                if sampler.func is None:
+                    instances.set(sampler.dst, instances.get(sampler.src))
+                else:
+                    instances.set(sampler.dst, sampler.func(instances))
+            # delete model output data that was transformed
+            for _, sampler in self._samplers.items():
+                if sampler.src != sampler.dst and instances.has(sampler.src):
+                    instances.remove(sampler.src)
+            model_output_i["dataset"] = self.dataset_name
+        return model_output
+
+    def register_sampler(
+        self,
+        prediction_attr: str,
+        gt_attr: Optional[str] = None,
+        func: Optional[Callable[[Any], Any]] = None,
+    ):
+        """
+        Register sampler for a field
+
+        Args:
+            prediction_attr (str): field to replace with a sampled value
+            gt_attr (Optional[str]): field to store the sampled value to, if not None
+            func (Optional[Callable: Any -> Any]): sampler function
+        """
+        self._samplers[(prediction_attr, gt_attr)] = _Sampler(
+            src=prediction_attr, dst=gt_attr, func=func
+        )
+
+    def remove_sampler(
+        self,
+        prediction_attr: str,
+        gt_attr: Optional[str] = None,
+    ):
+        """
+        Remove sampler for a field
+
+        Args:
+            prediction_attr (str): field to replace with a sampled value
+            gt_attr (Optional[str]): field to store the sampled value to, if not None
+        """
+        assert (prediction_attr, gt_attr) in self._samplers
+        del self._samplers[(prediction_attr, gt_attr)]
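
A sketch of how this sampler might be wired up with the samplers from this commit; the import path and the pairing of `pred_densepose` with `gt_masks` are assumptions here, not confirmed by this diff:

    from densepose.data.samplers import MaskFromDensePoseSampler  # path assumed

    sampler = PredictionToGroundTruthSampler(dataset_name="densepose_coco_2014_train")
    # pred_boxes -> gt_boxes and pred_classes -> gt_classes are registered by default;
    # scores are registered for deletion. Add a DensePose-to-mask conversion:
    sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
    # model_output: list of {"instances": Instances, ...} dicts from inference
    sampled = sampler(model_output)
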
CatVTON/densepose/data/transform/__init__.py ADDED
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from .image import ImageResizeTransform
CatVTON/densepose/data/transform/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (219 Bytes).
CatVTON/densepose/data/transform/__pycache__/image.cpython-39.pyc ADDED
Binary file (1.68 kB).
CatVTON/densepose/data/transform/image.py ADDED
@@ -0,0 +1,41 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+import torch
+
+
+class ImageResizeTransform:
+    """
+    Transform that resizes images loaded from a dataset
+    (BGR data in NCHW channel order, typically uint8) to a format ready to be
+    consumed by DensePose training (BGR float32 data in NCHW channel order)
+    """
+
+    def __init__(self, min_size: int = 800, max_size: int = 1333):
+        self.min_size = min_size
+        self.max_size = max_size
+
+    def __call__(self, images: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            images (torch.Tensor): tensor of size [N, 3, H, W] that contains
+                BGR data (typically in uint8)
+        Returns:
+            images (torch.Tensor): tensor of size [N, 3, H1, W1] where
+                H1 and W1 are chosen to respect the specified min and max sizes
+                and preserve the original aspect ratio, the data channels
+                follow BGR order and the data type is `torch.float32`
+        """
+        # resize with min size
+        images = images.float()
+        min_size = min(images.shape[-2:])
+        max_size = max(images.shape[-2:])
+        scale = min(self.min_size / min_size, self.max_size / max_size)
+        images = torch.nn.functional.interpolate(
+            images,
+            scale_factor=scale,
+            mode="bilinear",
+            align_corners=False,
+        )
+        return images
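
A quick usage sketch for the transform above, with a made-up batch of 480x640 BGR frames:

    import torch

    transform = ImageResizeTransform(min_size=800, max_size=1333)
    frames = torch.randint(0, 256, (4, 3, 480, 640), dtype=torch.uint8)
    resized = transform(frames)   # converted to float32 inside the transform
    print(resized.shape)          # roughly (4, 3, 800, 1066); aspect ratio preserved
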
CatVTON/detectron2/__init__.py ADDED
@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+from .utils.env import setup_environment
+
+setup_environment()
+
+
+# This line will be programmatically read/written by setup.py.
+# Leave it at the bottom of this file and don't touch it.
+__version__ = "0.6"
CatVTON/detectron2/checkpoint/__init__.py ADDED
@@ -0,0 +1,10 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+# File:
+
+
+from . import catalog as _UNUSED  # register the handler
+from .detection_checkpoint import DetectionCheckpointer
+from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
+
+__all__ = ["Checkpointer", "PeriodicCheckpointer", "DetectionCheckpointer"]
CatVTON/detectron2/checkpoint/c2_model_loading.py ADDED
@@ -0,0 +1,406 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import copy
+ import logging
+ import re
+ from typing import Dict, List
+ import torch
+
+
+ def convert_basic_c2_names(original_keys):
+     """
+     Apply some basic name conversion to names in C2 weights.
+     It only deals with typical backbone models.
+
+     Args:
+         original_keys (list[str]):
+     Returns:
+         list[str]: The same number of strings matching those in original_keys.
+     """
+     layer_keys = copy.deepcopy(original_keys)
+     layer_keys = [
+         {"pred_b": "linear_b", "pred_w": "linear_w"}.get(k, k) for k in layer_keys
+     ]  # some hard-coded mappings
+
+     layer_keys = [k.replace("_", ".") for k in layer_keys]
+     layer_keys = [re.sub("\\.b$", ".bias", k) for k in layer_keys]
+     layer_keys = [re.sub("\\.w$", ".weight", k) for k in layer_keys]
+     # Unify both bn and gn names to "norm"
+     layer_keys = [re.sub("bn\\.s$", "norm.weight", k) for k in layer_keys]
+     layer_keys = [re.sub("bn\\.bias$", "norm.bias", k) for k in layer_keys]
+     layer_keys = [re.sub("bn\\.rm", "norm.running_mean", k) for k in layer_keys]
+     layer_keys = [re.sub("bn\\.running.mean$", "norm.running_mean", k) for k in layer_keys]
+     layer_keys = [re.sub("bn\\.riv$", "norm.running_var", k) for k in layer_keys]
+     layer_keys = [re.sub("bn\\.running.var$", "norm.running_var", k) for k in layer_keys]
+     layer_keys = [re.sub("bn\\.gamma$", "norm.weight", k) for k in layer_keys]
+     layer_keys = [re.sub("bn\\.beta$", "norm.bias", k) for k in layer_keys]
+     layer_keys = [re.sub("gn\\.s$", "norm.weight", k) for k in layer_keys]
+     layer_keys = [re.sub("gn\\.bias$", "norm.bias", k) for k in layer_keys]
+
+     # stem
+     layer_keys = [re.sub("^res\\.conv1\\.norm\\.", "conv1.norm.", k) for k in layer_keys]
+     # to avoid mis-matching with "conv1" in other components (e.g. detection head)
+     layer_keys = [re.sub("^conv1\\.", "stem.conv1.", k) for k in layer_keys]
+
+     # layer1-4 is used by torchvision, however we follow the C2 naming strategy (res2-5)
+     # layer_keys = [re.sub("^res2.", "layer1.", k) for k in layer_keys]
+     # layer_keys = [re.sub("^res3.", "layer2.", k) for k in layer_keys]
+     # layer_keys = [re.sub("^res4.", "layer3.", k) for k in layer_keys]
+     # layer_keys = [re.sub("^res5.", "layer4.", k) for k in layer_keys]
+
+     # blocks
+     layer_keys = [k.replace(".branch1.", ".shortcut.") for k in layer_keys]
+     layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
+     layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
+     layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
+
+     # DensePose substitutions
+     layer_keys = [re.sub("^body.conv.fcn", "body_conv_fcn", k) for k in layer_keys]
+     layer_keys = [k.replace("AnnIndex.lowres", "ann_index_lowres") for k in layer_keys]
+     layer_keys = [k.replace("Index.UV.lowres", "index_uv_lowres") for k in layer_keys]
+     layer_keys = [k.replace("U.lowres", "u_lowres") for k in layer_keys]
+     layer_keys = [k.replace("V.lowres", "v_lowres") for k in layer_keys]
+     return layer_keys
+
+
+ def convert_c2_detectron_names(weights):
+     """
+     Map Caffe2 Detectron weight names to Detectron2 names.
+
+     Args:
+         weights (dict): name -> tensor
+
+     Returns:
+         dict: detectron2 names -> tensor
+         dict: detectron2 names -> C2 names
+     """
+     logger = logging.getLogger(__name__)
+     logger.info("Renaming Caffe2 weights ......")
+     original_keys = sorted(weights.keys())
+     layer_keys = copy.deepcopy(original_keys)
+
+     layer_keys = convert_basic_c2_names(layer_keys)
+
+     # --------------------------------------------------------------------------
+     # RPN hidden representation conv
+     # --------------------------------------------------------------------------
+     # FPN case
+     # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
+     # shared for all other levels, hence the appearance of "fpn2"
+     layer_keys = [
+         k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
+     ]
+     # Non-FPN case
+     layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]
+
+     # --------------------------------------------------------------------------
+     # RPN box transformation conv
+     # --------------------------------------------------------------------------
+     # FPN case (see note above about "fpn2")
+     layer_keys = [
+         k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
+         for k in layer_keys
+     ]
+     layer_keys = [
+         k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
+         for k in layer_keys
+     ]
+     # Non-FPN case
+     layer_keys = [
+         k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys
+     ]
+     layer_keys = [
+         k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
+         for k in layer_keys
+     ]
+
+     # --------------------------------------------------------------------------
+     # Fast R-CNN box head
+     # --------------------------------------------------------------------------
+     layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
+     layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
+     layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
+     layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
+     # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
+     layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]
+
+     # --------------------------------------------------------------------------
+     # FPN lateral and output convolutions
+     # --------------------------------------------------------------------------
+     def fpn_map(name):
+         """
+         Look for keys with the following patterns:
+         1) Starts with "fpn.inner."
+            Example: "fpn.inner.res2.2.sum.lateral.weight"
+            Meaning: These are lateral pathway convolutions
+         2) Starts with "fpn.res"
+            Example: "fpn.res2.2.sum.weight"
+            Meaning: These are FPN output convolutions
+         """
+         splits = name.split(".")
+         norm = ".norm" if "norm" in splits else ""
+         if name.startswith("fpn.inner."):
+             # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
+             stage = int(splits[2][len("res") :])
+             return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
+         elif name.startswith("fpn.res"):
+             # splits example: ['fpn', 'res2', '2', 'sum', 'weight']
+             stage = int(splits[1][len("res") :])
+             return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
+         return name
+
+     layer_keys = [fpn_map(k) for k in layer_keys]
+
+     # --------------------------------------------------------------------------
+     # Mask R-CNN mask head
+     # --------------------------------------------------------------------------
+     # roi_heads.StandardROIHeads case
+     layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
+     layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
+     layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
+     # roi_heads.Res5ROIHeads case
+     layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]
+
+     # --------------------------------------------------------------------------
+     # Keypoint R-CNN head
+     # --------------------------------------------------------------------------
+     # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
+     layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
+     layer_keys = [
+         k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys
+     ]
+     layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]
+
+     # --------------------------------------------------------------------------
+     # Done with replacements
+     # --------------------------------------------------------------------------
+     assert len(set(layer_keys)) == len(layer_keys)
+     assert len(original_keys) == len(layer_keys)
+
+     new_weights = {}
+     new_keys_to_original_keys = {}
+     for orig, renamed in zip(original_keys, layer_keys):
+         new_keys_to_original_keys[renamed] = orig
+         if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
+             # remove the meaningless prediction weight for background class
+             new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
+             new_weights[renamed] = weights[orig][new_start_idx:]
+             logger.info(
+                 "Remove prediction weight for background class in {}. The shape changes from "
+                 "{} to {}.".format(
+                     renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
+                 )
+             )
+         elif renamed.startswith("cls_score."):
+             # move weights of bg class from original index 0 to last index
+             logger.info(
+                 "Move classification weights for background class in {} from index 0 to "
+                 "index {}.".format(renamed, weights[orig].shape[0] - 1)
+             )
+             new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
+         else:
+             new_weights[renamed] = weights[orig]
+
+     return new_weights, new_keys_to_original_keys
+
+
+ # Note that the current matching is not symmetric.
+ # It assumes model_state_dict will have longer names.
+ def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
+     """
+     Match names between the two state dicts, and return a new ckpt_state_dict with names
+     converted to match model_state_dict with heuristics. The returned dict can be later
+     loaded with the fvcore checkpointer.
+     If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
+     model and will be renamed first.
+
+     Strategy: suppose the models we create have prefixes appended to each of
+     their keys, for example due to an extra level of nesting that the original
+     pre-trained weights from ImageNet won't contain. For example, model.state_dict()
+     might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
+     res2.conv1.weight. We thus want to match both parameters together.
+     For that, we look for each model weight, look among all loaded keys if there is one
+     that is a suffix of the current weight name, and use it if that's the case.
+     If multiple matches exist, take the one with the longest matching name.
+     For example, for the same model as before, the pretrained
+     weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
+     we want to match backbone[0].body.conv1.weight to conv1.weight, and
+     backbone[0].body.res2.conv1.weight to res2.conv1.weight.
+     """
+     model_keys = sorted(model_state_dict.keys())
+     if c2_conversion:
+         ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict)
+         # original_keys: the name in the original dict (before renaming)
+     else:
+         original_keys = {x: x for x in ckpt_state_dict.keys()}
+     ckpt_keys = sorted(ckpt_state_dict.keys())
+
+     def match(a, b):
+         # Matched ckpt_key should be a complete (i.e. preceded by '.') suffix.
+         # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
+         # but matches whatever_conv1 or mesh_head.whatever_conv1.
+         return a == b or a.endswith("." + b)
+
+     # get a matrix of string matches, where each (i, j) entry corresponds to the size of the
+     # ckpt_key string, if it matches
+     match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
+     match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
+     # use the matched one with longest size in case of multiple matches
+     max_match_size, idxs = match_matrix.max(1)
+     # remove indices that correspond to no-match
+     idxs[max_match_size == 0] = -1
+
+     logger = logging.getLogger(__name__)
+     # matched_pairs (matched checkpoint key --> matched model key)
+     matched_keys = {}
+     result_state_dict = {}
+     for idx_model, idx_ckpt in enumerate(idxs.tolist()):
+         if idx_ckpt == -1:
+             continue
+         key_model = model_keys[idx_model]
+         key_ckpt = ckpt_keys[idx_ckpt]
+         value_ckpt = ckpt_state_dict[key_ckpt]
+         shape_in_model = model_state_dict[key_model].shape
+
+         if shape_in_model != value_ckpt.shape:
+             logger.warning(
+                 "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
+                     key_ckpt, value_ckpt.shape, key_model, shape_in_model
+                 )
+             )
+             logger.warning(
+                 "{} will not be loaded. Please double check and see if this is desired.".format(
+                     key_ckpt
+                 )
+             )
+             continue
+
+         assert key_model not in result_state_dict
+         result_state_dict[key_model] = value_ckpt
+         if key_ckpt in matched_keys:  # already added to matched_keys
+             logger.error(
+                 "Ambiguity found for {} in checkpoint! "
+                 "It matches at least two keys in the model ({} and {}).".format(
+                     key_ckpt, key_model, matched_keys[key_ckpt]
+                 )
+             )
+             raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")
+
+         matched_keys[key_ckpt] = key_model
+
+     # logging:
+     matched_model_keys = sorted(matched_keys.values())
+     if len(matched_model_keys) == 0:
+         logger.warning("No weights in checkpoint matched with model.")
+         return ckpt_state_dict
+     common_prefix = _longest_common_prefix(matched_model_keys)
+     rev_matched_keys = {v: k for k, v in matched_keys.items()}
+     original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}
+
+     model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
+     table = []
+     memo = set()
+     for key_model in matched_model_keys:
+         if key_model in memo:
+             continue
+         if key_model in model_key_groups:
+             group = model_key_groups[key_model]
+             memo |= set(group)
+             shapes = [tuple(model_state_dict[k].shape) for k in group]
+             table.append(
+                 (
+                     _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
+                     _group_str([original_keys[k] for k in group]),
+                     " ".join([str(x).replace(" ", "") for x in shapes]),
+                 )
+             )
+         else:
+             key_checkpoint = original_keys[key_model]
+             shape = str(tuple(model_state_dict[key_model].shape))
+             table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
+     submodule_str = common_prefix[:-1] if common_prefix else "model"
+     logger.info(
+         f"Following weights matched with submodule {submodule_str} - Total num: {len(table)}"
+     )
+
+     unmatched_ckpt_keys = [k for k in ckpt_keys if k not in set(matched_keys.keys())]
+     for k in unmatched_ckpt_keys:
+         result_state_dict[k] = ckpt_state_dict[k]
+     return result_state_dict
+
+
+ def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
+     """
+     Params in the same submodule are grouped together.
+
+     Args:
+         keys: names of all parameters
+         original_names: mapping from parameter name to their name in the checkpoint
+
+     Returns:
+         dict[name -> all other names in the same group]
+     """
+
+     def _submodule_name(key):
+         pos = key.rfind(".")
+         if pos < 0:
+             return None
+         prefix = key[: pos + 1]
+         return prefix
+
+     all_submodules = [_submodule_name(k) for k in keys]
+     all_submodules = [x for x in all_submodules if x]
+     all_submodules = sorted(all_submodules, key=len)
+
+     ret = {}
+     for prefix in all_submodules:
+         group = [k for k in keys if k.startswith(prefix)]
+         if len(group) <= 1:
+             continue
+         original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group])
+         if len(original_name_lcp) == 0:
+             # don't group weights if original names don't share prefix
+             continue
+
+         for k in group:
+             if k in ret:
+                 continue
+             ret[k] = group
+     return ret
+
+
+ def _longest_common_prefix(names: List[str]) -> str:
+     """
+     ["abc.zfg", "abc.zef"] -> "abc."
+     """
+     names = [n.split(".") for n in names]
+     m1, m2 = min(names), max(names)
+     ret = [a for a, b in zip(m1, m2) if a == b]
+     ret = ".".join(ret) + "." if len(ret) else ""
+     return ret
+
+
+ def _longest_common_prefix_str(names: List[str]) -> str:
+     m1, m2 = min(names), max(names)
+     lcp = []
+     for a, b in zip(m1, m2):
+         if a == b:
+             lcp.append(a)
+         else:
+             break
+     lcp = "".join(lcp)
+     return lcp
+
+
+ def _group_str(names: List[str]) -> str:
+     """
+     Turn "common1", "common2", "common3" into "common{1,2,3}"
+     """
+     lcp = _longest_common_prefix_str(names)
+     rest = [x[len(lcp) :] for x in names]
+     rest = "{" + ",".join(rest) + "}"
+     ret = lcp + rest
+
+     # add some simplification for BN specifically
+     ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*")
+     ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*")
+     return ret
CatVTON/detectron2/checkpoint/catalog.py ADDED
@@ -0,0 +1,115 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import logging
+
+ from detectron2.utils.file_io import PathHandler, PathManager
+
+
+ class ModelCatalog:
+     """
+     Store mappings from names to third-party models.
+     """
+
+     S3_C2_DETECTRON_PREFIX = "https://dl.fbaipublicfiles.com/detectron"
+
+     # MSRA models have STRIDE_IN_1X1=True. False otherwise.
+     # NOTE: all BN models here have fused BN into an affine layer.
+     # As a result, you should only load them to a model with "FrozenBN".
+     # Loading them to a model with regular BN or SyncBN is wrong.
+     # Even when loaded to FrozenBN, it is still different from affine by an epsilon,
+     # which should be negligible for training.
+     # NOTE: all models here use PIXEL_STD=[1,1,1]
+     # NOTE: Most of the BN models here are no longer used. We use the
+     # re-converted pre-trained models under detectron2 model zoo instead.
+     C2_IMAGENET_MODELS = {
+         "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl",
+         "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl",
+         "FAIR/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl",
+         "FAIR/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl",
+         "FAIR/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl",
+         "FAIR/X-101-64x4d": "ImageNetPretrained/FBResNeXt/X-101-64x4d.pkl",
+         "FAIR/X-152-32x8d-IN5k": "ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl",
+     }
+
+     C2_DETECTRON_PATH_FORMAT = (
+         "{prefix}/{url}/output/train/{dataset}/{type}/model_final.pkl"  # noqa B950
+     )
+
+     C2_DATASET_COCO = "coco_2014_train%3Acoco_2014_valminusminival"
+     C2_DATASET_COCO_KEYPOINTS = "keypoints_coco_2014_train%3Akeypoints_coco_2014_valminusminival"
+
+     # format: {model_name} -> part of the url
+     C2_DETECTRON_MODELS = {
+         "35857197/e2e_faster_rcnn_R-50-C4_1x": "35857197/12_2017_baselines/e2e_faster_rcnn_R-50-C4_1x.yaml.01_33_49.iAX0mXvW",  # noqa B950
+         "35857345/e2e_faster_rcnn_R-50-FPN_1x": "35857345/12_2017_baselines/e2e_faster_rcnn_R-50-FPN_1x.yaml.01_36_30.cUF7QR7I",  # noqa B950
+         "35857890/e2e_faster_rcnn_R-101-FPN_1x": "35857890/12_2017_baselines/e2e_faster_rcnn_R-101-FPN_1x.yaml.01_38_50.sNxI7sX7",  # noqa B950
+         "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "36761737/12_2017_baselines/e2e_faster_rcnn_X-101-32x8d-FPN_1x.yaml.06_31_39.5MIHi1fZ",  # noqa B950
+         "35858791/e2e_mask_rcnn_R-50-C4_1x": "35858791/12_2017_baselines/e2e_mask_rcnn_R-50-C4_1x.yaml.01_45_57.ZgkA7hPB",  # noqa B950
+         "35858933/e2e_mask_rcnn_R-50-FPN_1x": "35858933/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml.01_48_14.DzEQe4wC",  # noqa B950
+         "35861795/e2e_mask_rcnn_R-101-FPN_1x": "35861795/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_1x.yaml.02_31_37.KqyEK4tT",  # noqa B950
+         "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "36761843/12_2017_baselines/e2e_mask_rcnn_X-101-32x8d-FPN_1x.yaml.06_35_59.RZotkLKI",  # noqa B950
+         "48616381/e2e_mask_rcnn_R-50-FPN_2x_gn": "GN/48616381/04_2018_gn_baselines/e2e_mask_rcnn_R-50-FPN_2x_gn_0416.13_23_38.bTlTI97Q",  # noqa B950
+         "37697547/e2e_keypoint_rcnn_R-50-FPN_1x": "37697547/12_2017_baselines/e2e_keypoint_rcnn_R-50-FPN_1x.yaml.08_42_54.kdzV35ao",  # noqa B950
+         "35998355/rpn_R-50-C4_1x": "35998355/12_2017_baselines/rpn_R-50-C4_1x.yaml.08_00_43.njH5oD9L",  # noqa B950
+         "35998814/rpn_R-50-FPN_1x": "35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179",  # noqa B950
+         "36225147/fast_R-50-FPN_1x": "36225147/12_2017_baselines/fast_rcnn_R-50-FPN_1x.yaml.08_39_09.L3obSdQ2",  # noqa B950
+     }
+
+     @staticmethod
+     def get(name):
+         if name.startswith("Caffe2Detectron/COCO"):
+             return ModelCatalog._get_c2_detectron_baseline(name)
+         if name.startswith("ImageNetPretrained/"):
+             return ModelCatalog._get_c2_imagenet_pretrained(name)
+         raise RuntimeError("model not present in the catalog: {}".format(name))
+
+     @staticmethod
+     def _get_c2_imagenet_pretrained(name):
+         prefix = ModelCatalog.S3_C2_DETECTRON_PREFIX
+         name = name[len("ImageNetPretrained/") :]
+         name = ModelCatalog.C2_IMAGENET_MODELS[name]
+         url = "/".join([prefix, name])
+         return url
+
+     @staticmethod
+     def _get_c2_detectron_baseline(name):
+         name = name[len("Caffe2Detectron/COCO/") :]
+         url = ModelCatalog.C2_DETECTRON_MODELS[name]
+         if "keypoint_rcnn" in name:
+             dataset = ModelCatalog.C2_DATASET_COCO_KEYPOINTS
+         else:
+             dataset = ModelCatalog.C2_DATASET_COCO
+
+         if "35998355/rpn_R-50-C4_1x" in name:
+             # this one model is somehow different from the others ...
+             type = "rpn"
+         else:
+             type = "generalized_rcnn"
+
+         # Detectron C2 models are stored in the structure defined in `C2_DETECTRON_PATH_FORMAT`.
+         url = ModelCatalog.C2_DETECTRON_PATH_FORMAT.format(
+             prefix=ModelCatalog.S3_C2_DETECTRON_PREFIX, url=url, type=type, dataset=dataset
+         )
+         return url
+
+
+ class ModelCatalogHandler(PathHandler):
+     """
+     Resolve URLs like catalog://.
+     """
+
+     PREFIX = "catalog://"
+
+     def _get_supported_prefixes(self):
+         return [self.PREFIX]
+
+     def _get_local_path(self, path, **kwargs):
+         logger = logging.getLogger(__name__)
+         catalog_path = ModelCatalog.get(path[len(self.PREFIX) :])
+         logger.info("Catalog entry {} points to {}".format(path, catalog_path))
+         return PathManager.get_local_path(catalog_path, **kwargs)
+
+     def _open(self, path, mode="r", **kwargs):
+         return PathManager.open(self._get_local_path(path), mode, **kwargs)
+
+
+ PathManager.register_handler(ModelCatalogHandler())
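
A sketch of resolving a catalog:// name through the handler registered above; no network access happens until the file is actually fetched:

from detectron2.checkpoint.catalog import ModelCatalog

url = ModelCatalog.get("ImageNetPretrained/MSRA/R-50")
print(url)  # https://dl.fbaipublicfiles.com/detectron/ImageNetPretrained/MSRA/R-50.pkl

# Equivalently, PathManager can now resolve "catalog://ImageNetPretrained/MSRA/R-50",
# because ModelCatalogHandler was registered at import time.
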
CatVTON/detectron2/checkpoint/detection_checkpoint.py ADDED
@@ -0,0 +1,143 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import logging
+ import os
+ import pickle
+ from urllib.parse import parse_qs, urlparse
+ import torch
+ from fvcore.common.checkpoint import Checkpointer
+ from torch.nn.parallel import DistributedDataParallel
+
+ import detectron2.utils.comm as comm
+ from detectron2.utils.file_io import PathManager
+
+ from .c2_model_loading import align_and_update_state_dicts
+
+
+ class DetectionCheckpointer(Checkpointer):
+     """
+     Same as :class:`Checkpointer`, but is able to:
+     1. handle models in detectron & detectron2 model zoo, and apply conversions for legacy models.
+     2. correctly load checkpoints that are only available on the master worker
+     """
+
+     def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
+         is_main_process = comm.is_main_process()
+         super().__init__(
+             model,
+             save_dir,
+             save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
+             **checkpointables,
+         )
+         self.path_manager = PathManager
+         self._parsed_url_during_load = None
+
+     def load(self, path, *args, **kwargs):
+         assert self._parsed_url_during_load is None
+         need_sync = False
+         logger = logging.getLogger(__name__)
+         logger.info("[DetectionCheckpointer] Loading from {} ...".format(path))
+
+         if path and isinstance(self.model, DistributedDataParallel):
+             path = self.path_manager.get_local_path(path)
+             has_file = os.path.isfile(path)
+             all_has_file = comm.all_gather(has_file)
+             if not all_has_file[0]:
+                 raise OSError(f"File {path} not found on main worker.")
+             if not all(all_has_file):
+                 logger.warning(
+                     f"Not all workers can read checkpoint {path}. "
+                     "Training may fail to fully resume."
+                 )
+                 # TODO: broadcast the checkpoint file contents from main
+                 # worker, and load from it instead.
+                 need_sync = True
+             if not has_file:
+                 path = None  # don't load if not readable
+
+         if path:
+             parsed_url = urlparse(path)
+             self._parsed_url_during_load = parsed_url
+             path = parsed_url._replace(query="").geturl()  # remove query from filename
+             path = self.path_manager.get_local_path(path)
+         ret = super().load(path, *args, **kwargs)
+
+         if need_sync:
+             logger.info("Broadcasting model states from main worker ...")
+             self.model._sync_params_and_buffers()
+         self._parsed_url_during_load = None  # reset to None
+         return ret
+
+     def _load_file(self, filename):
+         if filename.endswith(".pkl"):
+             with PathManager.open(filename, "rb") as f:
+                 data = pickle.load(f, encoding="latin1")
+             if "model" in data and "__author__" in data:
+                 # file is in Detectron2 model zoo format
+                 self.logger.info("Reading a file from '{}'".format(data["__author__"]))
+                 return data
+             else:
+                 # assume file is from Caffe2 / Detectron1 model zoo
+                 if "blobs" in data:
+                     # Detection models have "blobs", but ImageNet models don't
+                     data = data["blobs"]
+                 data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
+                 return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
+         elif filename.endswith(".pyth"):
+             # assume file is from pycls; no one else seems to use the ".pyth" extension
+             with PathManager.open(filename, "rb") as f:
+                 data = torch.load(f)
+             assert (
+                 "model_state" in data
+             ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'."
+             model_state = {
+                 k: v
+                 for k, v in data["model_state"].items()
+                 if not k.endswith("num_batches_tracked")
+             }
+             return {"model": model_state, "__author__": "pycls", "matching_heuristics": True}
+
+         loaded = self._torch_load(filename)
+         if "model" not in loaded:
+             loaded = {"model": loaded}
+         assert self._parsed_url_during_load is not None, "`_load_file` must be called inside `load`"
+         parsed_url = self._parsed_url_during_load
+         queries = parse_qs(parsed_url.query)
+         if queries.pop("matching_heuristics", "False") == ["True"]:
+             loaded["matching_heuristics"] = True
+         if len(queries) > 0:
+             raise ValueError(
+                 f"Unsupported query remaining: {queries}, original filename: {parsed_url.geturl()}"
+             )
+         return loaded
+
+     def _torch_load(self, f):
+         return super()._load_file(f)
+
+     def _load_model(self, checkpoint):
+         if checkpoint.get("matching_heuristics", False):
+             self._convert_ndarray_to_tensor(checkpoint["model"])
+             # convert weights by name-matching heuristics
+             checkpoint["model"] = align_and_update_state_dicts(
+                 self.model.state_dict(),
+                 checkpoint["model"],
+                 c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
+             )
+         # for non-caffe2 models, use standard ways to load it
+         incompatible = super()._load_model(checkpoint)
+
+         model_buffers = dict(self.model.named_buffers(recurse=False))
+         for k in ["pixel_mean", "pixel_std"]:
+             # Ignore missing key message about pixel_mean/std.
+             # Though they may be missing in old checkpoints, they will be correctly
+             # initialized from config anyway.
+             if k in model_buffers:
+                 try:
+                     incompatible.missing_keys.remove(k)
+                 except ValueError:
+                     pass
+         for k in incompatible.unexpected_keys[:]:
+             # Ignore unexpected keys about cell anchors. They exist in old checkpoints
+             # but now they are non-persistent buffers and will not be in new checkpoints.
+             if "anchor_generator.cell_anchors" in k:
+                 incompatible.unexpected_keys.remove(k)
+         return incompatible
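
A sketch of the two loading paths handled above. `cfg` is assumed to be a valid detectron2 CfgNode prepared by the caller, and the .pth path is a placeholder:

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling import build_model

model = build_model(cfg)  # `cfg` is assumed, not built here
# .pkl files take the Caffe2/Detectron branch and enable matching heuristics:
DetectionCheckpointer(model).load("catalog://ImageNetPretrained/MSRA/R-50")
# for a plain torch checkpoint, heuristics can be requested via the query string:
DetectionCheckpointer(model).load("/path/to/weights.pth?matching_heuristics=True")
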
CatVTON/detectron2/engine/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ from .launch import *
+ from .train_loop import *
+
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
+
+
+ # prefer to let hooks and defaults live in separate namespaces (therefore not in __all__)
+ # but still make them available here
+ from .hooks import *
+ from .defaults import (
+     create_ddp_model,
+     default_argument_parser,
+     default_setup,
+     default_writers,
+     DefaultPredictor,
+     DefaultTrainer,
+ )
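
The re-exports above make the common entry points importable in one line; a minimal launcher sketch (the `main` body is a placeholder to be filled by the caller):

from detectron2.engine import default_argument_parser, launch

def main(args):
    pass  # build cfg / trainer here

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
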
CatVTON/detectron2/engine/defaults.py ADDED
@@ -0,0 +1,719 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ """
5
+ This file contains components with some default boilerplate logic user may need
6
+ in training / testing. They will not work for everyone, but many users may find them useful.
7
+
8
+ The behavior of functions/classes in this file is subject to change,
9
+ since they are meant to represent the "common default behavior" people need in their projects.
10
+ """
11
+
12
+ import argparse
13
+ import logging
14
+ import os
15
+ import sys
16
+ import weakref
17
+ from collections import OrderedDict
18
+ from typing import Optional
19
+ import torch
20
+ from fvcore.nn.precise_bn import get_bn_modules
21
+ from omegaconf import OmegaConf
22
+ from torch.nn.parallel import DistributedDataParallel
23
+
24
+ import detectron2.data.transforms as T
25
+ from detectron2.checkpoint import DetectionCheckpointer
26
+ from detectron2.config import CfgNode, LazyConfig
27
+ from detectron2.data import (
28
+ MetadataCatalog,
29
+ build_detection_test_loader,
30
+ build_detection_train_loader,
31
+ )
32
+ from detectron2.evaluation import (
33
+ DatasetEvaluator,
34
+ inference_on_dataset,
35
+ print_csv_format,
36
+ verify_results,
37
+ )
38
+ from detectron2.modeling import build_model
39
+ from detectron2.solver import build_lr_scheduler, build_optimizer
40
+ from detectron2.utils import comm
41
+ from detectron2.utils.collect_env import collect_env_info
42
+ from detectron2.utils.env import seed_all_rng
43
+ from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
44
+ from detectron2.utils.file_io import PathManager
45
+ from detectron2.utils.logger import setup_logger
46
+
47
+ from . import hooks
48
+ from .train_loop import AMPTrainer, SimpleTrainer, TrainerBase
49
+
50
+ __all__ = [
51
+ "create_ddp_model",
52
+ "default_argument_parser",
53
+ "default_setup",
54
+ "default_writers",
55
+ "DefaultPredictor",
56
+ "DefaultTrainer",
57
+ ]
58
+
59
+
60
+ def create_ddp_model(model, *, fp16_compression=False, **kwargs):
61
+ """
62
+ Create a DistributedDataParallel model if there are >1 processes.
63
+
64
+ Args:
65
+ model: a torch.nn.Module
66
+ fp16_compression: add fp16 compression hooks to the ddp object.
67
+ See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
68
+ kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
69
+ """ # noqa
70
+ if comm.get_world_size() == 1:
71
+ return model
72
+ if "device_ids" not in kwargs:
73
+ kwargs["device_ids"] = [comm.get_local_rank()]
74
+ ddp = DistributedDataParallel(model, **kwargs)
75
+ if fp16_compression:
76
+ from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
77
+
78
+ ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
79
+ return ddp
80
+
81
+
82
+ def default_argument_parser(epilog=None):
83
+ """
84
+ Create a parser with some common arguments used by detectron2 users.
85
+
86
+ Args:
87
+ epilog (str): epilog passed to ArgumentParser describing the usage.
88
+
89
+ Returns:
90
+ argparse.ArgumentParser:
91
+ """
92
+ parser = argparse.ArgumentParser(
93
+ epilog=epilog
94
+ or f"""
95
+ Examples:
96
+
97
+ Run on single machine:
98
+ $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
99
+
100
+ Change some config options:
101
+ $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
102
+
103
+ Run on multiple machines:
104
+ (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
105
+ (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
106
+ """,
107
+ formatter_class=argparse.RawDescriptionHelpFormatter,
108
+ )
109
+ parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
110
+ parser.add_argument(
111
+ "--resume",
112
+ action="store_true",
113
+ help="Whether to attempt to resume from the checkpoint directory. "
114
+ "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
115
+ )
116
+ parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
117
+ parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
118
+ parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
119
+ parser.add_argument(
120
+ "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
121
+ )
122
+
123
+ # PyTorch still may leave orphan processes in multi-gpu training.
124
+ # Therefore we use a deterministic way to obtain port,
125
+ # so that users are aware of orphan processes by seeing the port occupied.
126
+ port = 2**15 + 2**14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2**14
127
+ parser.add_argument(
128
+ "--dist-url",
129
+ default="tcp://127.0.0.1:{}".format(port),
130
+ help="initialization URL for pytorch distributed backend. See "
131
+ "https://pytorch.org/docs/stable/distributed.html for details.",
132
+ )
133
+ parser.add_argument(
134
+ "opts",
135
+ help="""
136
+ Modify config options at the end of the command. For Yacs configs, use
137
+ space-separated "PATH.KEY VALUE" pairs.
138
+ For python-based LazyConfig, use "path.key=value".
139
+ """.strip(),
140
+ default=None,
141
+ nargs=argparse.REMAINDER,
142
+ )
143
+ return parser
144
+
145
+
146
+ def _try_get_key(cfg, *keys, default=None):
147
+ """
148
+ Try select keys from cfg until the first key that exists. Otherwise return default.
149
+ """
150
+ if isinstance(cfg, CfgNode):
151
+ cfg = OmegaConf.create(cfg.dump())
152
+ for k in keys:
153
+ none = object()
154
+ p = OmegaConf.select(cfg, k, default=none)
155
+ if p is not none:
156
+ return p
157
+ return default
158
+
159
+
160
+ def _highlight(code, filename):
161
+ try:
162
+ import pygments
163
+ except ImportError:
164
+ return code
165
+
166
+ from pygments.lexers import Python3Lexer, YamlLexer
167
+ from pygments.formatters import Terminal256Formatter
168
+
169
+ lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
170
+ code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
171
+ return code
172
+
173
+
174
+ def default_setup(cfg, args):
175
+ """
176
+ Perform some basic common setups at the beginning of a job, including:
177
+
178
+ 1. Set up the detectron2 logger
179
+ 2. Log basic information about environment, cmdline arguments, and config
180
+ 3. Backup the config to the output directory
181
+
182
+ Args:
183
+ cfg (CfgNode or omegaconf.DictConfig): the full config to be used
184
+ args (argparse.NameSpace): the command line arguments to be logged
185
+ """
186
+ output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
187
+ if comm.is_main_process() and output_dir:
188
+ PathManager.mkdirs(output_dir)
189
+
190
+ rank = comm.get_rank()
191
+ setup_logger(output_dir, distributed_rank=rank, name="fvcore")
192
+ logger = setup_logger(output_dir, distributed_rank=rank)
193
+
194
+ logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
195
+ logger.info("Environment info:\n" + collect_env_info())
196
+
197
+ logger.info("Command line arguments: " + str(args))
198
+ if hasattr(args, "config_file") and args.config_file != "":
199
+ logger.info(
200
+ "Contents of args.config_file={}:\n{}".format(
201
+ args.config_file,
202
+ _highlight(PathManager.open(args.config_file, "r").read(), args.config_file),
203
+ )
204
+ )
205
+
206
+ if comm.is_main_process() and output_dir:
207
+ # Note: some of our scripts may expect the existence of
208
+ # config.yaml in output directory
209
+ path = os.path.join(output_dir, "config.yaml")
210
+ if isinstance(cfg, CfgNode):
211
+ logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
212
+ with PathManager.open(path, "w") as f:
213
+ f.write(cfg.dump())
214
+ else:
215
+ LazyConfig.save(cfg, path)
216
+ logger.info("Full config saved to {}".format(path))
217
+
218
+ # make sure each worker has a different, yet deterministic seed if specified
219
+ seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
220
+ seed_all_rng(None if seed < 0 else seed + rank)
221
+
222
+ # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
223
+ # typical validation set.
224
+ if not (hasattr(args, "eval_only") and args.eval_only):
225
+ torch.backends.cudnn.benchmark = _try_get_key(
226
+ cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
227
+ )
228
+
229
+
230
+ def default_writers(output_dir: str, max_iter: Optional[int] = None):
231
+ """
232
+ Build a list of :class:`EventWriter` to be used.
233
+ It now consists of a :class:`CommonMetricPrinter`,
234
+ :class:`TensorboardXWriter` and :class:`JSONWriter`.
235
+
236
+ Args:
237
+ output_dir: directory to store JSON metrics and tensorboard events
238
+ max_iter: the total number of iterations
239
+
240
+ Returns:
241
+ list[EventWriter]: a list of :class:`EventWriter` objects.
242
+ """
243
+ PathManager.mkdirs(output_dir)
244
+ return [
245
+ # It may not always print what you want to see, since it prints "common" metrics only.
246
+ CommonMetricPrinter(max_iter),
247
+ JSONWriter(os.path.join(output_dir, "metrics.json")),
248
+ TensorboardXWriter(output_dir),
249
+ ]
250
+
251
+
252
+ class DefaultPredictor:
253
+ """
254
+ Create a simple end-to-end predictor with the given config that runs on
255
+ single device for a single input image.
256
+
257
+ Compared to using the model directly, this class does the following additions:
258
+
259
+ 1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
260
+ 2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
261
+ 3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
262
+ 4. Take one input image and produce a single output, instead of a batch.
263
+
264
+ This is meant for simple demo purposes, so it does the above steps automatically.
265
+ This is not meant for benchmarks or running complicated inference logic.
266
+ If you'd like to do anything more complicated, please refer to its source code as
267
+ examples to build and use the model manually.
268
+
269
+ Attributes:
270
+ metadata (Metadata): the metadata of the underlying dataset, obtained from
271
+ cfg.DATASETS.TEST.
272
+
273
+ Examples:
274
+ ::
275
+ pred = DefaultPredictor(cfg)
276
+ inputs = cv2.imread("input.jpg")
277
+ outputs = pred(inputs)
278
+ """
279
+
280
+ def __init__(self, cfg):
281
+ self.cfg = cfg.clone() # cfg can be modified by model
282
+ self.model = build_model(self.cfg)
283
+ self.model.eval()
284
+ if len(cfg.DATASETS.TEST):
285
+ self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
286
+
287
+ checkpointer = DetectionCheckpointer(self.model)
288
+ checkpointer.load(cfg.MODEL.WEIGHTS)
289
+
290
+ self.aug = T.ResizeShortestEdge(
291
+ [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
292
+ )
293
+
294
+ self.input_format = cfg.INPUT.FORMAT
295
+ assert self.input_format in ["RGB", "BGR"], self.input_format
296
+
297
+ def __call__(self, original_image):
298
+ """
299
+ Args:
300
+ original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
301
+
302
+ Returns:
303
+ predictions (dict):
304
+ the output of the model for one image only.
305
+ See :doc:`/tutorials/models` for details about the format.
306
+ """
307
+ with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
308
+ # Apply pre-processing to image.
309
+ if self.input_format == "RGB":
310
+ # whether the model expects BGR inputs or RGB
311
+ original_image = original_image[:, :, ::-1]
312
+ height, width = original_image.shape[:2]
313
+ image = self.aug.get_transform(original_image).apply_image(original_image)
314
+ image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
315
+ image.to(self.cfg.MODEL.DEVICE)
316
+
317
+ inputs = {"image": image, "height": height, "width": width}
318
+
319
+ predictions = self.model([inputs])[0]
320
+ return predictions
321
+
322
+
323
+ class DefaultTrainer(TrainerBase):
324
+ """
325
+ A trainer with default training logic. It does the following:
326
+
327
+ 1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
328
+ defined by the given config. Create a LR scheduler defined by the config.
329
+ 2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
330
+ `resume_or_load` is called.
331
+ 3. Register a few common hooks defined by the config.
332
+
333
+ It is created to simplify the **standard model training workflow** and reduce code boilerplate
334
+ for users who only need the standard training workflow, with standard features.
335
+ It means this class makes *many assumptions* about your training logic that
336
+ may easily become invalid in a new research. In fact, any assumptions beyond those made in the
337
+ :class:`SimpleTrainer` are too much for research.
338
+
339
+ The code of this class has been annotated about restrictive assumptions it makes.
340
+ When they do not work for you, you're encouraged to:
341
+
342
+ 1. Overwrite methods of this class, OR:
343
+ 2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
344
+ nothing else. You can then add your own hooks if needed. OR:
345
+ 3. Write your own training loop similar to `tools/plain_train_net.py`.
346
+
347
+ See the :doc:`/tutorials/training` tutorials for more details.
348
+
349
+ Note that the behavior of this class, like other functions/classes in
350
+ this file, is not stable, since it is meant to represent the "common default behavior".
351
+ It is only guaranteed to work well with the standard models and training workflow in detectron2.
352
+ To obtain more stable behavior, write your own training logic with other public APIs.
353
+
354
+ Examples:
355
+ ::
356
+ trainer = DefaultTrainer(cfg)
357
+ trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
358
+ trainer.train()
359
+
360
+ Attributes:
361
+ scheduler:
362
+ checkpointer (DetectionCheckpointer):
363
+ cfg (CfgNode):
364
+ """
365
+
366
+ def __init__(self, cfg):
367
+ """
368
+ Args:
369
+ cfg (CfgNode):
370
+ """
371
+ super().__init__()
372
+ logger = logging.getLogger("detectron2")
373
+ if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
374
+ setup_logger()
375
+ cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
376
+
377
+ # Assume these objects must be constructed in this order.
378
+ model = self.build_model(cfg)
379
+ optimizer = self.build_optimizer(cfg, model)
380
+ data_loader = self.build_train_loader(cfg)
381
+
382
+ model = create_ddp_model(model, broadcast_buffers=False)
383
+ self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
384
+ model, data_loader, optimizer
385
+ )
386
+
387
+ self.scheduler = self.build_lr_scheduler(cfg, optimizer)
388
+ self.checkpointer = DetectionCheckpointer(
389
+ # Assume you want to save checkpoints together with logs/statistics
390
+ model,
391
+ cfg.OUTPUT_DIR,
392
+ trainer=weakref.proxy(self),
393
+ )
394
+ self.start_iter = 0
395
+ self.max_iter = cfg.SOLVER.MAX_ITER
396
+ self.cfg = cfg
397
+
398
+ self.register_hooks(self.build_hooks())
399
+
400
+ def resume_or_load(self, resume=True):
401
+ """
402
+ If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
403
+ a `last_checkpoint` file), resume from the file. Resuming means loading all
404
+ available states (eg. optimizer and scheduler) and update iteration counter
405
+ from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
406
+
407
+ Otherwise, this is considered as an independent training. The method will load model
408
+ weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
409
+ from iteration 0.
410
+
411
+ Args:
412
+ resume (bool): whether to do resume or not
413
+ """
414
+ self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
415
+ if resume and self.checkpointer.has_checkpoint():
416
+ # The checkpoint stores the training iteration that just finished, thus we start
417
+ # at the next iteration
418
+ self.start_iter = self.iter + 1
419
+
420
+ def build_hooks(self):
421
+ """
422
+ Build a list of default hooks, including timing, evaluation,
423
+ checkpointing, lr scheduling, precise BN, writing events.
424
+
425
+ Returns:
426
+ list[HookBase]:
427
+ """
428
+ cfg = self.cfg.clone()
429
+ cfg.defrost()
430
+ cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
431
+
432
+ ret = [
433
+ hooks.IterationTimer(),
434
+ hooks.LRScheduler(),
435
+ (
436
+ hooks.PreciseBN(
437
+ # Run at the same freq as (but before) evaluation.
438
+ cfg.TEST.EVAL_PERIOD,
439
+ self.model,
440
+ # Build a new data loader to not affect training
441
+ self.build_train_loader(cfg),
442
+ cfg.TEST.PRECISE_BN.NUM_ITER,
443
+ )
444
+ if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
445
+ else None
446
+ ),
447
+ ]
448
+
449
+ # Do PreciseBN before checkpointer, because it updates the model and need to
450
+ # be saved by checkpointer.
451
+ # This is not always the best: if checkpointing has a different frequency,
452
+ # some checkpoints may have more precise statistics than others.
453
+ if comm.is_main_process():
454
+ ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
455
+
456
+ def test_and_save_results():
457
+ self._last_eval_results = self.test(self.cfg, self.model)
458
+ return self._last_eval_results
459
+
460
+ # Do evaluation after checkpointer, because then if it fails,
461
+ # we can use the saved checkpoint to debug.
462
+ ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
463
+
464
+ if comm.is_main_process():
465
+ # Here the default print/log frequency of each writer is used.
466
+ # run writers in the end, so that evaluation metrics are written
467
+ ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
468
+ return ret
469
+
470
+ def build_writers(self):
471
+ """
472
+ Build a list of writers to be used using :func:`default_writers()`.
473
+ If you'd like a different list of writers, you can overwrite it in
474
+ your trainer.
475
+
476
+ Returns:
477
+ list[EventWriter]: a list of :class:`EventWriter` objects.
478
+ """
479
+ return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
480
+
481
+ def train(self):
482
+ """
483
+ Run training.
484
+
485
+ Returns:
486
+ OrderedDict of results, if evaluation is enabled. Otherwise None.
487
+ """
488
+ super().train(self.start_iter, self.max_iter)
489
+ if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
490
+ assert hasattr(
491
+ self, "_last_eval_results"
492
+ ), "No evaluation results obtained during training!"
493
+ verify_results(self.cfg, self._last_eval_results)
494
+ return self._last_eval_results
495
+
496
+ def run_step(self):
497
+ self._trainer.iter = self.iter
498
+ self._trainer.run_step()
499
+
500
+ def state_dict(self):
501
+ ret = super().state_dict()
502
+ ret["_trainer"] = self._trainer.state_dict()
503
+ return ret
504
+
505
+ def load_state_dict(self, state_dict):
506
+ super().load_state_dict(state_dict)
507
+ self._trainer.load_state_dict(state_dict["_trainer"])
508
+
509
+ @classmethod
510
+ def build_model(cls, cfg):
511
+ """
512
+ Returns:
513
+ torch.nn.Module:
514
+
515
+ It now calls :func:`detectron2.modeling.build_model`.
516
+ Overwrite it if you'd like a different model.
517
+ """
518
+ model = build_model(cfg)
519
+ logger = logging.getLogger(__name__)
520
+ logger.info("Model:\n{}".format(model))
521
+ return model
522
+
523
+ @classmethod
524
+ def build_optimizer(cls, cfg, model):
525
+ """
526
+ Returns:
527
+ torch.optim.Optimizer:
528
+
529
+ It now calls :func:`detectron2.solver.build_optimizer`.
530
+ Overwrite it if you'd like a different optimizer.
531
+ """
532
+ return build_optimizer(cfg, model)
533
+
534
+ @classmethod
535
+ def build_lr_scheduler(cls, cfg, optimizer):
536
+ """
537
+ It now calls :func:`detectron2.solver.build_lr_scheduler`.
538
+ Overwrite it if you'd like a different scheduler.
539
+ """
540
+ return build_lr_scheduler(cfg, optimizer)
541
+
542
+ @classmethod
543
+ def build_train_loader(cls, cfg):
544
+ """
545
+ Returns:
546
+ iterable
547
+
548
+ It now calls :func:`detectron2.data.build_detection_train_loader`.
549
+ Overwrite it if you'd like a different data loader.
550
+ """
551
+ return build_detection_train_loader(cfg)
552
+
553
+ @classmethod
554
+ def build_test_loader(cls, cfg, dataset_name):
555
+ """
556
+ Returns:
557
+ iterable
558
+
559
+ It now calls :func:`detectron2.data.build_detection_test_loader`.
560
+ Overwrite it if you'd like a different data loader.
561
+ """
562
+ return build_detection_test_loader(cfg, dataset_name)
563
+
564
+ @classmethod
565
+ def build_evaluator(cls, cfg, dataset_name):
566
+ """
567
+ Returns:
568
+ DatasetEvaluator or None
569
+
570
+ It is not implemented by default.
571
+ """
572
+ raise NotImplementedError(
573
+ """
574
+ If you want DefaultTrainer to automatically run evaluation,
575
+ please implement `build_evaluator()` in subclasses (see train_net.py for example).
576
+ Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
577
+ """
578
+ )
579
+
580
+ @classmethod
581
+ def test(cls, cfg, model, evaluators=None):
582
+ """
583
+ Evaluate the given model. The given model is expected to already contain
584
+ weights to evaluate.
585
+
586
+ Args:
587
+ cfg (CfgNode):
588
+ model (nn.Module):
589
+ evaluators (list[DatasetEvaluator] or None): if None, will call
590
+ :meth:`build_evaluator`. Otherwise, must have the same length as
591
+ ``cfg.DATASETS.TEST``.
592
+
593
+ Returns:
594
+ dict: a dict of result metrics
595
+ """
596
+ logger = logging.getLogger(__name__)
597
+ if isinstance(evaluators, DatasetEvaluator):
598
+ evaluators = [evaluators]
599
+ if evaluators is not None:
600
+ assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
601
+ len(cfg.DATASETS.TEST), len(evaluators)
602
+ )
603
+
604
+ results = OrderedDict()
605
+ for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
+             data_loader = cls.build_test_loader(cfg, dataset_name)
+             # When evaluators are passed in as arguments,
+             # implicitly assume that evaluators can be created before data_loader.
+             if evaluators is not None:
+                 evaluator = evaluators[idx]
+             else:
+                 try:
+                     evaluator = cls.build_evaluator(cfg, dataset_name)
+                 except NotImplementedError:
+                     logger.warning(
+                         "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
+                         "or implement its `build_evaluator` method."
+                     )
+                     results[dataset_name] = {}
+                     continue
+             results_i = inference_on_dataset(model, data_loader, evaluator)
+             results[dataset_name] = results_i
+             if comm.is_main_process():
+                 assert isinstance(
+                     results_i, dict
+                 ), "Evaluator must return a dict on the main process. Got {} instead.".format(
+                     results_i
+                 )
+                 logger.info("Evaluation results for {} in csv format:".format(dataset_name))
+                 print_csv_format(results_i)
+
+         if len(results) == 1:
+             results = list(results.values())[0]
+         return results
+
+     @staticmethod
+     def auto_scale_workers(cfg, num_workers: int):
+         """
+         When the config is defined for a certain number of workers (according to
+         ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
+         workers currently in use, returns a new cfg where the total batch size
+         is scaled so that the per-GPU batch size stays the same as the
+         original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.
+
+         Other config options are also scaled accordingly:
+         * training steps and warmup steps are scaled in inverse proportion.
+         * the learning rate is scaled proportionally, following :paper:`ImageNet in 1h`.
+
+         For example, with an original config like the following:
+
+         .. code-block:: yaml
+
+             IMS_PER_BATCH: 16
+             BASE_LR: 0.1
+             REFERENCE_WORLD_SIZE: 8
+             MAX_ITER: 5000
+             STEPS: (4000,)
+             CHECKPOINT_PERIOD: 1000
+
+         When this config is used on 16 GPUs instead of the reference number 8,
+         calling this method will return a new config with:
+
+         .. code-block:: yaml
+
+             IMS_PER_BATCH: 32
+             BASE_LR: 0.2
+             REFERENCE_WORLD_SIZE: 16
+             MAX_ITER: 2500
+             STEPS: (2000,)
+             CHECKPOINT_PERIOD: 500
+
+         Note that both the original config and this new config can be trained on 16 GPUs.
+         It's up to the user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).
+
+         Returns:
+             CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
+         """
+         old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
+         if old_world_size == 0 or old_world_size == num_workers:
+             return cfg
+         cfg = cfg.clone()
+         frozen = cfg.is_frozen()
+         cfg.defrost()
+
+         assert (
+             cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
+         ), "Invalid REFERENCE_WORLD_SIZE in config!"
+         scale = num_workers / old_world_size
+         bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
+         lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
+         max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
+         warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
+         cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
+         cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
+         cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
+         cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers  # maintain invariant
+         logger = logging.getLogger(__name__)
+         logger.info(
+             f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
+             f"max_iter={max_iter}, warmup={warmup_iter}."
+         )
+
+         if frozen:
+             cfg.freeze()
+         return cfg
+
+
+ # Access basic attributes from the underlying trainer
+ for _attr in ["model", "data_loader", "optimizer"]:
+     setattr(
+         DefaultTrainer,
+         _attr,
+         property(
+             # getter
+             lambda self, x=_attr: getattr(self._trainer, x),
+             # setter
+             lambda self, value, x=_attr: setattr(self._trainer, x, value),
+         ),
+     )
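
Since `auto_scale_workers` is a plain static method, it can also be applied to a config by hand before a trainer is built. A minimal sketch (the solver values below are illustrative assumptions, not part of this diff):

from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer

cfg = get_cfg()
cfg.SOLVER.IMS_PER_BATCH = 16
cfg.SOLVER.BASE_LR = 0.1
cfg.SOLVER.MAX_ITER = 5000
cfg.SOLVER.REFERENCE_WORLD_SIZE = 8  # schedule was tuned for 8 workers

# Running on 16 workers: batch size and LR double, iteration counts halve.
scaled = DefaultTrainer.auto_scale_workers(cfg, num_workers=16)
assert scaled.SOLVER.IMS_PER_BATCH == 32 and scaled.SOLVER.MAX_ITER == 2500
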
CatVTON/detectron2/engine/hooks.py ADDED
@@ -0,0 +1,690 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ import datetime
+ import itertools
+ import logging
+ import math
+ import operator
+ import os
+ import tempfile
+ import time
+ import warnings
+ from collections import Counter
+ import torch
+ from fvcore.common.checkpoint import Checkpointer
+ from fvcore.common.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
+ from fvcore.common.param_scheduler import ParamScheduler
+ from fvcore.common.timer import Timer
+ from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats
+
+ import detectron2.utils.comm as comm
+ from detectron2.evaluation.testing import flatten_results_dict
+ from detectron2.solver import LRMultiplier
+ from detectron2.solver import LRScheduler as _LRScheduler
+ from detectron2.utils.events import EventStorage, EventWriter
+ from detectron2.utils.file_io import PathManager
+
+ from .train_loop import HookBase
+
+ __all__ = [
+     "CallbackHook",
+     "IterationTimer",
+     "PeriodicWriter",
+     "PeriodicCheckpointer",
+     "BestCheckpointer",
+     "LRScheduler",
+     "AutogradProfiler",
+     "EvalHook",
+     "PreciseBN",
+     "TorchProfiler",
+     "TorchMemoryStats",
+ ]
+
+
+ """
+ Implement some common hooks.
+ """
+
+
+ class CallbackHook(HookBase):
+     """
+     Create a hook using callback functions provided by the user.
+     """
+
+     def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
+         """
+         Each argument is a function that takes one argument: the trainer.
+         """
+         self._before_train = before_train
+         self._before_step = before_step
+         self._after_step = after_step
+         self._after_train = after_train
+
+     def before_train(self):
+         if self._before_train:
+             self._before_train(self.trainer)
+
+     def after_train(self):
+         if self._after_train:
+             self._after_train(self.trainer)
+         # The functions may be closures that hold references to the trainer.
+         # Therefore, delete them to avoid circular references.
+         del self._before_train, self._after_train
+         del self._before_step, self._after_step
+
+     def before_step(self):
+         if self._before_step:
+             self._before_step(self.trainer)
+
+     def after_step(self):
+         if self._after_step:
+             self._after_step(self.trainer)
+
+
+ class IterationTimer(HookBase):
+     """
+     Track the time spent for each iteration (each run_step call in the trainer).
+     Print a summary at the end of training.
+
+     This hook uses the time between the call to its :meth:`before_step`
+     and :meth:`after_step` methods.
+     Under the convention that :meth:`before_step` of all hooks should only
+     take a negligible amount of time, the :class:`IterationTimer` hook should be
+     placed at the beginning of the list of hooks to obtain accurate timing.
+     """
+
+     def __init__(self, warmup_iter=3):
+         """
+         Args:
+             warmup_iter (int): the number of iterations at the beginning to exclude
+                 from timing.
+         """
+         self._warmup_iter = warmup_iter
+         self._step_timer = Timer()
+         self._start_time = time.perf_counter()
+         self._total_timer = Timer()
+
+     def before_train(self):
+         self._start_time = time.perf_counter()
+         self._total_timer.reset()
+         self._total_timer.pause()
+
+     def after_train(self):
+         logger = logging.getLogger(__name__)
+         total_time = time.perf_counter() - self._start_time
+         total_time_minus_hooks = self._total_timer.seconds()
+         hook_time = total_time - total_time_minus_hooks
+
+         num_iter = self.trainer.storage.iter + 1 - self.trainer.start_iter - self._warmup_iter
+
+         if num_iter > 0 and total_time_minus_hooks > 0:
+             # Speed is meaningful only after warmup
+             # NOTE this format is parsed by grep in some scripts
+             logger.info(
+                 "Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
+                     num_iter,
+                     str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
+                     total_time_minus_hooks / num_iter,
+                 )
+             )
+
+         logger.info(
+             "Total training time: {} ({} on hooks)".format(
+                 str(datetime.timedelta(seconds=int(total_time))),
+                 str(datetime.timedelta(seconds=int(hook_time))),
+             )
+         )
+
+     def before_step(self):
+         self._step_timer.reset()
+         self._total_timer.resume()
+
+     def after_step(self):
+         # +1 because we're in after_step, the current step is done
+         # but not yet counted
+         iter_done = self.trainer.storage.iter - self.trainer.start_iter + 1
+         if iter_done >= self._warmup_iter:
+             sec = self._step_timer.seconds()
+             self.trainer.storage.put_scalars(time=sec)
+         else:
+             self._start_time = time.perf_counter()
+             self._total_timer.reset()
+
+         self._total_timer.pause()
+
+
+ class PeriodicWriter(HookBase):
+     """
+     Write events to EventStorage (by calling ``writer.write()``) periodically.
+
+     It is executed every ``period`` iterations and after the last iteration.
+     Note that ``period`` does not affect how data is smoothed by each writer.
+     """
+
+     def __init__(self, writers, period=20):
+         """
+         Args:
+             writers (list[EventWriter]): a list of EventWriter objects
+             period (int): the period (in iterations) at which to write out events
+         """
+         self._writers = writers
+         for w in writers:
+             assert isinstance(w, EventWriter), w
+         self._period = period
+
+     def after_step(self):
+         if (self.trainer.iter + 1) % self._period == 0 or (
+             self.trainer.iter == self.trainer.max_iter - 1
+         ):
+             for writer in self._writers:
+                 writer.write()
+
+     def after_train(self):
+         for writer in self._writers:
+             # If any new data is found (e.g. produced by other after_train),
+             # write it before closing
+             writer.write()
+             writer.close()
+
+
+ class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
+     """
+     Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook.
+
+     Note that when used as a hook,
+     it is unable to save additional data other than what's defined
+     by the given `checkpointer`.
+
+     It is executed every ``period`` iterations and after the last iteration.
+     """
+
+     def before_train(self):
+         self.max_iter = self.trainer.max_iter
+
+     def after_step(self):
+         # No way to use **kwargs
+         self.step(self.trainer.iter)
+
+
+ class BestCheckpointer(HookBase):
+     """
+     Checkpoints the best weights based on a given metric.
+
+     This hook should be used in conjunction with, and executed after, the hook
+     that produces the metric, e.g. `EvalHook`.
+     """
+
+     def __init__(
+         self,
+         eval_period: int,
+         checkpointer: Checkpointer,
+         val_metric: str,
+         mode: str = "max",
+         file_prefix: str = "model_best",
+     ) -> None:
+         """
+         Args:
+             eval_period (int): the period `EvalHook` is set to run.
+             checkpointer: the checkpointer object used to save checkpoints.
+             val_metric (str): validation metric to track for best checkpoint, e.g. "bbox/AP50"
+             mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be
+                 maximized or minimized, e.g. for "bbox/AP50" it should be "max"
+             file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best"
+         """
+         self._logger = logging.getLogger(__name__)
+         self._period = eval_period
+         self._val_metric = val_metric
+         assert mode in [
+             "max",
+             "min",
+         ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.'
+         if mode == "max":
+             self._compare = operator.gt
+         else:
+             self._compare = operator.lt
+         self._checkpointer = checkpointer
+         self._file_prefix = file_prefix
+         self.best_metric = None
+         self.best_iter = None
+
+     def _update_best(self, val, iteration):
+         if math.isnan(val) or math.isinf(val):
+             return False
+         self.best_metric = val
+         self.best_iter = iteration
+         return True
+
+     def _best_checking(self):
+         metric_tuple = self.trainer.storage.latest().get(self._val_metric)
+         if metric_tuple is None:
+             self._logger.warning(
+                 f"Given val metric {self._val_metric} does not seem to be computed/stored. "
+                 "Will not be checkpointing based on it."
+             )
+             return
+         else:
+             latest_metric, metric_iter = metric_tuple
+
+         if self.best_metric is None:
+             if self._update_best(latest_metric, metric_iter):
+                 additional_state = {"iteration": metric_iter}
+                 self._checkpointer.save(f"{self._file_prefix}", **additional_state)
+                 self._logger.info(
+                     f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps"
+                 )
+         elif self._compare(latest_metric, self.best_metric):
+             additional_state = {"iteration": metric_iter}
+             self._checkpointer.save(f"{self._file_prefix}", **additional_state)
+             self._logger.info(
+                 f"Saved best model as latest eval score for {self._val_metric} is "
+                 f"{latest_metric:0.5f}, better than last best score "
+                 f"{self.best_metric:0.5f} @ iteration {self.best_iter}."
+             )
+             self._update_best(latest_metric, metric_iter)
+         else:
+             self._logger.info(
+                 f"Not saving as latest eval score for {self._val_metric} is {latest_metric:0.5f}, "
+                 f"not better than best score {self.best_metric:0.5f} @ iteration {self.best_iter}."
+             )
+
+     def after_step(self):
+         # same conditions as `EvalHook`
+         next_iter = self.trainer.iter + 1
+         if (
+             self._period > 0
+             and next_iter % self._period == 0
+             and next_iter != self.trainer.max_iter
+         ):
+             self._best_checking()
+
+     def after_train(self):
+         # same conditions as `EvalHook`
+         if self.trainer.iter + 1 >= self.trainer.max_iter:
+             self._best_checking()
+
+
+ class LRScheduler(HookBase):
+     """
+     A hook which executes a torch builtin LR scheduler and summarizes the LR.
+     It is executed after every iteration.
+     """
+
+     def __init__(self, optimizer=None, scheduler=None):
+         """
+         Args:
+             optimizer (torch.optim.Optimizer):
+             scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler):
+                 if a :class:`ParamScheduler` object, it defines the multiplier over the base LR
+                 in the optimizer.
+
+         If any argument is not given, will try to obtain it from the trainer.
+         """
+         self._optimizer = optimizer
+         self._scheduler = scheduler
+
+     def before_train(self):
+         self._optimizer = self._optimizer or self.trainer.optimizer
+         if isinstance(self.scheduler, ParamScheduler):
+             self._scheduler = LRMultiplier(
+                 self._optimizer,
+                 self.scheduler,
+                 self.trainer.max_iter,
+                 last_iter=self.trainer.iter - 1,
+             )
+         self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer)
+
+     @staticmethod
+     def get_best_param_group_id(optimizer):
+         # NOTE: some heuristics on which LR to summarize
+         # summarize the param group with the most parameters
+         largest_group = max(len(g["params"]) for g in optimizer.param_groups)
+
+         if largest_group == 1:
+             # If all groups have one parameter,
+             # then find the most common initial LR, and use it for summary
+             lr_count = Counter([g["lr"] for g in optimizer.param_groups])
+             lr = lr_count.most_common()[0][0]
+             for i, g in enumerate(optimizer.param_groups):
+                 if g["lr"] == lr:
+                     return i
+         else:
+             for i, g in enumerate(optimizer.param_groups):
+                 if len(g["params"]) == largest_group:
+                     return i
+
+     def after_step(self):
+         lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
+         self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
+         self.scheduler.step()
+
+     @property
+     def scheduler(self):
+         return self._scheduler or self.trainer.scheduler
+
+     def state_dict(self):
+         if isinstance(self.scheduler, _LRScheduler):
+             return self.scheduler.state_dict()
+         return {}
+
+     def load_state_dict(self, state_dict):
+         if isinstance(self.scheduler, _LRScheduler):
+             logger = logging.getLogger(__name__)
+             logger.info("Loading scheduler from state_dict ...")
+             self.scheduler.load_state_dict(state_dict)
+
+
+ class TorchProfiler(HookBase):
+     """
+     A hook which runs `torch.profiler.profile`.
+
+     Examples:
+     ::
+         hooks.TorchProfiler(
+              lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
+         )
+
+     The above example will run the profiler for iterations 10~20 and dump
+     results to ``OUTPUT_DIR``. We did not profile the first few iterations
+     because they are typically slower than the rest.
+     The result files can be loaded in the ``chrome://tracing`` page in the Chrome browser,
+     and the tensorboard visualizations can be viewed using
+     ``tensorboard --logdir OUTPUT_DIR/log``
+     """
+
+     def __init__(self, enable_predicate, output_dir, *, activities=None, save_tensorboard=True):
+         """
+         Args:
+             enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
+                 and returns whether to enable the profiler.
+                 It will be called once every step, and can be used to select which steps to profile.
+             output_dir (str): the output directory to dump tracing files.
+             activities (iterable): same as in `torch.profiler.profile`.
+             save_tensorboard (bool): whether to save tensorboard visualizations at (output_dir)/log/
+         """
+         self._enable_predicate = enable_predicate
+         self._activities = activities
+         self._output_dir = output_dir
+         self._save_tensorboard = save_tensorboard
+
+     def before_step(self):
+         if self._enable_predicate(self.trainer):
+             if self._save_tensorboard:
+                 on_trace_ready = torch.profiler.tensorboard_trace_handler(
+                     os.path.join(
+                         self._output_dir,
+                         "log",
+                         "profiler-tensorboard-iter{}".format(self.trainer.iter),
+                     ),
+                     f"worker{comm.get_rank()}",
+                 )
+             else:
+                 on_trace_ready = None
+             self._profiler = torch.profiler.profile(
+                 activities=self._activities,
+                 on_trace_ready=on_trace_ready,
+                 record_shapes=True,
+                 profile_memory=True,
+                 with_stack=True,
+                 with_flops=True,
+             )
+             self._profiler.__enter__()
+         else:
+             self._profiler = None
+
+     def after_step(self):
+         if self._profiler is None:
+             return
+         self._profiler.__exit__(None, None, None)
+         if not self._save_tensorboard:
+             PathManager.mkdirs(self._output_dir)
+             out_file = os.path.join(
+                 self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter)
+             )
+             if "://" not in out_file:
+                 self._profiler.export_chrome_trace(out_file)
+             else:
+                 # Support non-posix filesystems
+                 with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
+                     tmp_file = os.path.join(d, "tmp.json")
+                     self._profiler.export_chrome_trace(tmp_file)
+                     with open(tmp_file) as f:
+                         content = f.read()
+                 with PathManager.open(out_file, "w") as f:
+                     f.write(content)
+
+
+ class AutogradProfiler(TorchProfiler):
+     """
+     A hook which runs `torch.autograd.profiler.profile`.
+
+     Examples:
+     ::
+         hooks.AutogradProfiler(
+              lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
+         )
+
+     The above example will run the profiler for iterations 10~20 and dump
+     results to ``OUTPUT_DIR``. We did not profile the first few iterations
+     because they are typically slower than the rest.
+     The result files can be loaded in the ``chrome://tracing`` page in the Chrome browser.
+
+     Note:
+         When used together with NCCL on older versions of GPUs,
+         autograd profiler may cause deadlock because it unnecessarily allocates
+         memory on every device it sees. The memory management calls, if
+         interleaved with NCCL calls, lead to deadlock on GPUs that do not
+         support ``cudaLaunchCooperativeKernelMultiDevice``.
+     """
+
+     def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
+         """
+         Args:
+             enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
+                 and returns whether to enable the profiler.
+                 It will be called once every step, and can be used to select which steps to profile.
+             output_dir (str): the output directory to dump tracing files.
+             use_cuda (bool): same as in `torch.autograd.profiler.profile`.
+         """
+         warnings.warn("AutogradProfiler has been deprecated in favor of TorchProfiler.")
+         self._enable_predicate = enable_predicate
+         self._use_cuda = use_cuda
+         self._output_dir = output_dir
+
+     def before_step(self):
+         if self._enable_predicate(self.trainer):
+             self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
+             self._profiler.__enter__()
+         else:
+             self._profiler = None
+
+
+ class EvalHook(HookBase):
+     """
+     Run an evaluation function periodically, and at the end of training.
+
+     It is executed every ``eval_period`` iterations and after the last iteration.
+     """
+
+     def __init__(self, eval_period, eval_function, eval_after_train=True):
+         """
+         Args:
+             eval_period (int): the period to run `eval_function`. Set to 0 to
+                 not evaluate periodically (but still evaluate after the last iteration
+                 if `eval_after_train` is True).
+             eval_function (callable): a function which takes no arguments, and
+                 returns a nested dict of evaluation metrics.
+             eval_after_train (bool): whether to evaluate after the last iteration
+
+         Note:
+             This hook must be enabled either in all workers or in none.
+             If you would like only certain workers to perform evaluation,
+             give the other workers a no-op function (`eval_function=lambda: None`).
+         """
+         self._period = eval_period
+         self._func = eval_function
+         self._eval_after_train = eval_after_train
+
+     def _do_eval(self):
+         results = self._func()
+
+         if results:
+             assert isinstance(
+                 results, dict
+             ), "Eval function must return a dict. Got {} instead.".format(results)
+
+             flattened_results = flatten_results_dict(results)
+             for k, v in flattened_results.items():
+                 try:
+                     v = float(v)
+                 except Exception as e:
+                     raise ValueError(
+                         "[EvalHook] eval_function should return a nested dict of float. "
+                         "Got '{}: {}' instead.".format(k, v)
+                     ) from e
+             self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
+
+         # Evaluation may take different time among workers.
+         # A barrier makes them start the next iteration together.
+         comm.synchronize()
+
+     def after_step(self):
+         next_iter = self.trainer.iter + 1
+         if self._period > 0 and next_iter % self._period == 0:
+             # do the last eval in after_train
+             if next_iter != self.trainer.max_iter:
+                 self._do_eval()
+
+     def after_train(self):
+         # This condition is to prevent the eval from running after a failed training
+         if self._eval_after_train and self.trainer.iter + 1 >= self.trainer.max_iter:
+             self._do_eval()
+         # func is likely a closure that holds a reference to the trainer,
+         # therefore we clean it up to avoid a circular reference in the end
+         del self._func
+
+
+ class PreciseBN(HookBase):
+     """
+     The standard implementation of BatchNorm uses EMA in inference, which is
+     sometimes suboptimal.
+     This class computes the true average of statistics rather than the moving average,
+     and puts the true averages into every BN layer in the given model.
+
+     It is executed every ``period`` iterations and after the last iteration.
+     """
+
+     def __init__(self, period, model, data_loader, num_iter):
+         """
+         Args:
+             period (int): the period this hook is run, or 0 to not run during training.
+                 The hook will always run at the end of training.
+             model (nn.Module): a module whose BN layers in training mode will all be
+                 updated by precise BN.
+                 Note that the user is responsible for ensuring that the BN layers to be
+                 updated are in training mode when this hook is triggered.
+             data_loader (iterable): it will produce data to be run by `model(data)`.
+             num_iter (int): number of iterations used to compute the precise
+                 statistics.
+         """
+         self._logger = logging.getLogger(__name__)
+         if len(get_bn_modules(model)) == 0:
+             self._logger.info(
+                 "PreciseBN is disabled because model does not contain BN layers in training mode."
+             )
+             self._disabled = True
+             return
+
+         self._model = model
+         self._data_loader = data_loader
+         self._num_iter = num_iter
+         self._period = period
+         self._disabled = False
+
+         self._data_iter = None
+
+     def after_step(self):
+         next_iter = self.trainer.iter + 1
+         is_final = next_iter == self.trainer.max_iter
+         if is_final or (self._period > 0 and next_iter % self._period == 0):
+             self.update_stats()
+
+     def update_stats(self):
+         """
+         Update the model with precise statistics. Users can manually call this method.
+         """
+         if self._disabled:
+             return
+
+         if self._data_iter is None:
+             self._data_iter = iter(self._data_loader)
+
+         def data_loader():
+             for num_iter in itertools.count(1):
+                 if num_iter % 100 == 0:
+                     self._logger.info(
+                         "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
+                     )
+                 # This way we can reuse the same iterator
+                 yield next(self._data_iter)
+
+         with EventStorage():  # capture events in a new storage to discard them
+             self._logger.info(
+                 "Running precise-BN for {} iterations...  ".format(self._num_iter)
+                 + "Note that this could produce different statistics every time."
+             )
+             update_bn_stats(self._model, data_loader(), self._num_iter)
+
+
+ class TorchMemoryStats(HookBase):
+     """
+     Writes pytorch's cuda memory statistics periodically.
+     """
+
+     def __init__(self, period=20, max_runs=10):
+         """
+         Args:
+             period (int): Output stats every 'period' iterations
+             max_runs (int): Stop the logging after 'max_runs'
+         """
+
+         self._logger = logging.getLogger(__name__)
+         self._period = period
+         self._max_runs = max_runs
+         self._runs = 0
+
+     def after_step(self):
+         if self._runs > self._max_runs:
+             return
+
+         if (self.trainer.iter + 1) % self._period == 0 or (
+             self.trainer.iter == self.trainer.max_iter - 1
+         ):
+             if torch.cuda.is_available():
+                 max_reserved_mb = torch.cuda.max_memory_reserved() / 1024.0 / 1024.0
+                 reserved_mb = torch.cuda.memory_reserved() / 1024.0 / 1024.0
+                 max_allocated_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
+                 allocated_mb = torch.cuda.memory_allocated() / 1024.0 / 1024.0
+
+                 self._logger.info(
+                     (
+                         " iter: {} "
+                         " max_reserved_mem: {:.0f}MB "
+                         " reserved_mem: {:.0f}MB "
+                         " max_allocated_mem: {:.0f}MB "
+                         " allocated_mem: {:.0f}MB "
+                     ).format(
+                         self.trainer.iter,
+                         max_reserved_mb,
+                         reserved_mb,
+                         max_allocated_mb,
+                         allocated_mb,
+                     )
+                 )
+
+                 self._runs += 1
+                 if self._runs == self._max_runs:
+                     mem_summary = torch.cuda.memory_summary()
+                     self._logger.info("\n" + mem_summary)
+
+                 torch.cuda.reset_peak_memory_stats()
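
For context, here is a hedged sketch of how these hooks are typically wired on top of `DefaultTrainer`; the `cfg` object and the metric name "bbox/AP" are assumptions that depend on the dataset and evaluator actually in use:

from detectron2.engine import DefaultTrainer, hooks

trainer = DefaultTrainer(cfg)  # cfg: a fully populated detectron2 config (assumed)
trainer.register_hooks(
    [
        # Runs after the EvalHook that DefaultTrainer already registers,
        # so the metric is in storage by the time it is checked.
        hooks.BestCheckpointer(
            eval_period=cfg.TEST.EVAL_PERIOD,
            checkpointer=trainer.checkpointer,
            val_metric="bbox/AP",
            mode="max",
        )
    ]
)
trainer.resume_or_load(resume=False)
trainer.train()
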
CatVTON/detectron2/engine/launch.py ADDED
@@ -0,0 +1,123 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import logging
+ from datetime import timedelta
+ import torch
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+
+ from detectron2.utils import comm
+
+ __all__ = ["DEFAULT_TIMEOUT", "launch"]
+
+ DEFAULT_TIMEOUT = timedelta(minutes=30)
+
+
+ def _find_free_port():
+     import socket
+
+     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+     # Binding to port 0 will cause the OS to find an available port for us
+     sock.bind(("", 0))
+     port = sock.getsockname()[1]
+     sock.close()
+     # NOTE: there is still a chance the port could be taken by other processes.
+     return port
+
+
+ def launch(
+     main_func,
+     # Should be num_processes_per_machine, but kept for compatibility.
+     num_gpus_per_machine,
+     num_machines=1,
+     machine_rank=0,
+     dist_url=None,
+     args=(),
+     timeout=DEFAULT_TIMEOUT,
+ ):
+     """
+     Launch multi-process or distributed training.
+     This function must be called on all machines involved in the training.
+     It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
+
+     Args:
+         main_func: a function that will be called by `main_func(*args)`
+         num_gpus_per_machine (int): number of processes per machine. When
+             using GPUs, this should be the number of GPUs.
+         num_machines (int): the total number of machines
+         machine_rank (int): the rank of this machine
+         dist_url (str): url to connect to for distributed jobs, including protocol
+             e.g. "tcp://127.0.0.1:8686".
+             Can be set to "auto" to automatically select a free port on localhost
+         args (tuple): arguments passed to main_func
+         timeout (timedelta): timeout of the distributed workers
+     """
+     world_size = num_machines * num_gpus_per_machine
+     if world_size > 1:
+         # https://github.com/pytorch/pytorch/pull/14391
+         # TODO prctl in spawned processes
+
+         if dist_url == "auto":
+             assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
+             port = _find_free_port()
+             dist_url = f"tcp://127.0.0.1:{port}"
+         if num_machines > 1 and dist_url.startswith("file://"):
+             logger = logging.getLogger(__name__)
+             logger.warning(
+                 "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
+             )
+
+         mp.start_processes(
+             _distributed_worker,
+             nprocs=num_gpus_per_machine,
+             args=(
+                 main_func,
+                 world_size,
+                 num_gpus_per_machine,
+                 machine_rank,
+                 dist_url,
+                 args,
+                 timeout,
+             ),
+             daemon=False,
+         )
+     else:
+         main_func(*args)
+
+
+ def _distributed_worker(
+     local_rank,
+     main_func,
+     world_size,
+     num_gpus_per_machine,
+     machine_rank,
+     dist_url,
+     args,
+     timeout=DEFAULT_TIMEOUT,
+ ):
+     has_gpu = torch.cuda.is_available()
+     if has_gpu:
+         assert num_gpus_per_machine <= torch.cuda.device_count()
+     global_rank = machine_rank * num_gpus_per_machine + local_rank
+     try:
+         dist.init_process_group(
+             backend="NCCL" if has_gpu else "GLOO",
+             init_method=dist_url,
+             world_size=world_size,
+             rank=global_rank,
+             timeout=timeout,
+         )
+     except Exception as e:
+         logger = logging.getLogger(__name__)
+         logger.error("Process group URL: {}".format(dist_url))
+         raise e
+
+     # Set up the local process group.
+     comm.create_local_process_group(num_gpus_per_machine)
+     if has_gpu:
+         torch.cuda.set_device(local_rank)
+
+     # synchronize is needed here to prevent a possible timeout after calling init_process_group
+     # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
+     comm.synchronize()
+
+     main_func(*args)
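
A single-machine sketch of the intended calling convention; the `main` body and GPU count are placeholders chosen only for illustration:

from detectron2.engine import launch

def main(args):
    # Executed once per spawned process; the process group is already
    # initialized here, so detectron2.utils.comm can be used directly.
    print("hello from a worker", args)

if __name__ == "__main__":
    launch(
        main,
        num_gpus_per_machine=2,  # illustrative GPU count
        num_machines=1,
        machine_rank=0,
        dist_url="auto",         # picks a free local port
        args=({"lr": 0.1},),
    )
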
CatVTON/detectron2/engine/train_loop.py ADDED
@@ -0,0 +1,530 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import concurrent.futures
+ import logging
+ import numpy as np
+ import time
+ import weakref
+ from typing import List, Mapping, Optional
+ import torch
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+ import detectron2.utils.comm as comm
+ from detectron2.utils.events import EventStorage, get_event_storage
+ from detectron2.utils.logger import _log_api_usage
+
+ __all__ = ["HookBase", "TrainerBase", "SimpleTrainer", "AMPTrainer"]
+
+
+ class HookBase:
+     """
+     Base class for hooks that can be registered with :class:`TrainerBase`.
+
+     Each hook can implement 4 methods. The way they are called is demonstrated
+     in the following snippet:
+     ::
+         hook.before_train()
+         for iter in range(start_iter, max_iter):
+             hook.before_step()
+             trainer.run_step()
+             hook.after_step()
+         iter += 1
+         hook.after_train()
+
+     Notes:
+         1. In the hook method, users can access ``self.trainer`` to access more
+            properties about the context (e.g., model, current iteration, or config
+            if using :class:`DefaultTrainer`).
+
+         2. A hook that does something in :meth:`before_step` can often be
+            implemented equivalently in :meth:`after_step`.
+            If the hook takes non-trivial time, it is strongly recommended to
+            implement the hook in :meth:`after_step` instead of :meth:`before_step`.
+            The convention is that :meth:`before_step` should only take negligible time.
+
+            Following this convention will allow hooks that do care about the difference
+            between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
+            function properly.
+
+     """
+
+     trainer: "TrainerBase" = None
+     """
+     A weak reference to the trainer object. Set by the trainer when the hook is registered.
+     """
+
+     def before_train(self):
+         """
+         Called before the first iteration.
+         """
+         pass
+
+     def after_train(self):
+         """
+         Called after the last iteration.
+         """
+         pass
+
+     def before_step(self):
+         """
+         Called before each iteration.
+         """
+         pass
+
+     def after_backward(self):
+         """
+         Called after the backward pass of each iteration.
+         """
+         pass
+
+     def after_step(self):
+         """
+         Called after each iteration.
+         """
+         pass
+
+     def state_dict(self):
+         """
+         Hooks are stateless by default, but can be made checkpointable by
+         implementing `state_dict` and `load_state_dict`.
+         """
+         return {}
+
+
+ class TrainerBase:
+     """
+     Base class for an iterative trainer with hooks.
+
+     The only assumption made here is that the training runs in a loop.
+     A subclass can implement what the loop is.
+     No assumptions are made about the existence of a dataloader, optimizer, model, etc.
+
+     Attributes:
+         iter(int): the current iteration.
+
+         start_iter(int): The iteration to start with.
+             By convention the minimum possible value is 0.
+
+         max_iter(int): The iteration to end training.
+
+         storage(EventStorage): An EventStorage that's opened during the course of training.
+     """
+
+     def __init__(self) -> None:
+         self._hooks: List[HookBase] = []
+         self.iter: int = 0
+         self.start_iter: int = 0
+         self.max_iter: int
+         self.storage: EventStorage
+         _log_api_usage("trainer." + self.__class__.__name__)
+
+     def register_hooks(self, hooks: List[Optional[HookBase]]) -> None:
+         """
+         Register hooks to the trainer. The hooks are executed in the order
+         they are registered.
+
+         Args:
+             hooks (list[Optional[HookBase]]): list of hooks
+         """
+         hooks = [h for h in hooks if h is not None]
+         for h in hooks:
+             assert isinstance(h, HookBase)
+             # To avoid circular references, hooks and trainer cannot own each other.
+             # This normally does not matter, but will cause memory leak if the
+             # involved objects contain __del__:
+             # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
+             h.trainer = weakref.proxy(self)
+         self._hooks.extend(hooks)
+
+     def train(self, start_iter: int, max_iter: int):
+         """
+         Args:
+             start_iter, max_iter (int): See docs above
+         """
+         logger = logging.getLogger(__name__)
+         logger.info("Starting training from iteration {}".format(start_iter))
+
+         self.iter = self.start_iter = start_iter
+         self.max_iter = max_iter
+
+         with EventStorage(start_iter) as self.storage:
+             try:
+                 self.before_train()
+                 for self.iter in range(start_iter, max_iter):
+                     self.before_step()
+                     self.run_step()
+                     self.after_step()
+                 # self.iter == max_iter can be used by `after_train` to
+                 # tell whether the training successfully finished or failed
+                 # due to exceptions.
+                 self.iter += 1
+             except Exception:
+                 logger.exception("Exception during training:")
+                 raise
+             finally:
+                 self.after_train()
+
+     def before_train(self):
+         for h in self._hooks:
+             h.before_train()
+
+     def after_train(self):
+         self.storage.iter = self.iter
+         for h in self._hooks:
+             h.after_train()
+
+     def before_step(self):
+         # Maintain the invariant that storage.iter == trainer.iter
+         # for the entire execution of each step
+         self.storage.iter = self.iter
+
+         for h in self._hooks:
+             h.before_step()
+
+     def after_backward(self):
+         for h in self._hooks:
+             h.after_backward()
+
+     def after_step(self):
+         for h in self._hooks:
+             h.after_step()
+
+     def run_step(self):
+         raise NotImplementedError
+
+     def state_dict(self):
+         ret = {"iteration": self.iter}
+         hooks_state = {}
+         for h in self._hooks:
+             sd = h.state_dict()
+             if sd:
+                 name = type(h).__qualname__
+                 if name in hooks_state:
+                     # TODO handle repetitive stateful hooks
+                     continue
+                 hooks_state[name] = sd
+         if hooks_state:
+             ret["hooks"] = hooks_state
+         return ret
+
+     def load_state_dict(self, state_dict):
+         logger = logging.getLogger(__name__)
+         self.iter = state_dict["iteration"]
+         for key, value in state_dict.get("hooks", {}).items():
+             for h in self._hooks:
+                 try:
+                     name = type(h).__qualname__
+                 except AttributeError:
+                     continue
+                 if name == key:
+                     h.load_state_dict(value)
+                     break
+             else:
+                 logger.warning(f"Cannot find the hook '{key}', its state_dict is ignored.")
+
+
+ class SimpleTrainer(TrainerBase):
+     """
+     A simple trainer for the most common type of task:
+     single-cost single-optimizer single-data-source iterative optimization,
+     optionally using data-parallelism.
+     It assumes that in every step you:
+
+     1. Compute the loss with data from the data_loader.
+     2. Compute the gradients with the above loss.
+     3. Update the model with the optimizer.
+
+     All other tasks during training (checkpointing, logging, evaluation, LR schedule)
+     are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`.
+
+     If you want to do anything fancier than this,
+     either subclass TrainerBase and implement your own `run_step`,
+     or write your own training loop.
+     """
+
+     def __init__(
+         self,
+         model,
+         data_loader,
+         optimizer,
+         gather_metric_period=1,
+         zero_grad_before_forward=False,
+         async_write_metrics=False,
+     ):
+         """
+         Args:
+             model: a torch Module. Takes data from data_loader and returns a
+                 dict of losses.
+             data_loader: an iterable. Contains data to be used to call model.
+             optimizer: a torch optimizer.
+             gather_metric_period: an int. Every gather_metric_period iterations
+                 the metrics are gathered from all the ranks to rank 0 and logged.
+             zero_grad_before_forward: whether to zero the gradients before the forward pass.
+             async_write_metrics: bool. If True, write metrics asynchronously to improve
+                 training speed.
+         """
+         super().__init__()
+
+         """
+         We set the model to training mode in the trainer.
+         However it's valid to train a model that's in eval mode.
+         If you want your model (or a submodule of it) to behave
+         like evaluation during training, you can overwrite its train() method.
+         """
+         model.train()
+
+         self.model = model
+         self.data_loader = data_loader
+         # to access the data loader iterator, call `self._data_loader_iter`
+         self._data_loader_iter_obj = None
+         self.optimizer = optimizer
+         self.gather_metric_period = gather_metric_period
+         self.zero_grad_before_forward = zero_grad_before_forward
+         self.async_write_metrics = async_write_metrics
+         # create a thread pool that can execute non-critical logic in run_step asynchronously
+         # use only 1 worker so tasks are executed in the order they are submitted.
+         self.concurrent_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+
+     def run_step(self):
+         """
+         Implement the standard training logic described above.
+         """
+         assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
+         start = time.perf_counter()
+         """
+         If you want to do something with the data, you can wrap the dataloader.
+         """
+         data = next(self._data_loader_iter)
+         data_time = time.perf_counter() - start
+
+         if self.zero_grad_before_forward:
+             """
+             If you need to accumulate gradients or do something similar, you can
+             wrap the optimizer with your custom `zero_grad()` method.
+             """
+             self.optimizer.zero_grad()
+
+         """
+         If you want to do something with the losses, you can wrap the model.
+         """
+         loss_dict = self.model(data)
+         if isinstance(loss_dict, torch.Tensor):
+             losses = loss_dict
+             loss_dict = {"total_loss": loss_dict}
+         else:
+             losses = sum(loss_dict.values())
+         if not self.zero_grad_before_forward:
+             """
+             If you need to accumulate gradients or do something similar, you can
+             wrap the optimizer with your custom `zero_grad()` method.
+             """
+             self.optimizer.zero_grad()
+         losses.backward()
+
+         self.after_backward()
+
+         if self.async_write_metrics:
+             # write metrics asynchronously
+             self.concurrent_executor.submit(
+                 self._write_metrics, loss_dict, data_time, iter=self.iter
+             )
+         else:
+             self._write_metrics(loss_dict, data_time)
+
+         """
+         If you need gradient clipping/scaling or other processing, you can
+         wrap the optimizer with your custom `step()` method. But it is
+         suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4
+         """
+         self.optimizer.step()
+
+     @property
+     def _data_loader_iter(self):
+         # only create the data loader iterator when it is used
+         if self._data_loader_iter_obj is None:
+             self._data_loader_iter_obj = iter(self.data_loader)
+         return self._data_loader_iter_obj
+
+     def reset_data_loader(self, data_loader_builder):
+         """
+         Delete and replace the current data loader with a new one, which will be created
+         by calling `data_loader_builder` (without argument).
+         """
+         del self.data_loader
+         data_loader = data_loader_builder()
+         self.data_loader = data_loader
+         self._data_loader_iter_obj = None
+
+     def _write_metrics(
+         self,
+         loss_dict: Mapping[str, torch.Tensor],
+         data_time: float,
+         prefix: str = "",
+         iter: Optional[int] = None,
+     ) -> None:
+         logger = logging.getLogger(__name__)
+
+         iter = self.iter if iter is None else iter
+         if (iter + 1) % self.gather_metric_period == 0:
+             try:
+                 SimpleTrainer.write_metrics(loss_dict, data_time, iter, prefix)
+             except Exception:
+                 logger.exception("Exception in writing metrics: ")
+                 raise
+
+     @staticmethod
+     def write_metrics(
+         loss_dict: Mapping[str, torch.Tensor],
+         data_time: float,
+         cur_iter: int,
+         prefix: str = "",
+     ) -> None:
+         """
+         Args:
+             loss_dict (dict): dict of scalar losses
+             data_time (float): time taken by the dataloader iteration
+             cur_iter (int): the current iteration to log the metrics under
+             prefix (str): prefix for logging keys
+         """
+         metrics_dict = {k: v.detach().cpu().item() for k, v in loss_dict.items()}
+         metrics_dict["data_time"] = data_time
+
+         storage = get_event_storage()
+         # Keep track of data time per rank
+         storage.put_scalar("rank_data_time", data_time, cur_iter=cur_iter)
+
+         # Gather metrics among all workers for logging
+         # This assumes we do DDP-style training, which is currently the only
+         # supported method in detectron2.
+         all_metrics_dict = comm.gather(metrics_dict)
+
+         if comm.is_main_process():
+             # data_time among workers can have high variance. The actual latency
+             # caused by data_time is the maximum among workers.
+             data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
+             storage.put_scalar("data_time", data_time, cur_iter=cur_iter)
+
+             # average the rest of the metrics
+             metrics_dict = {
+                 k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
+             }
+             total_losses_reduced = sum(metrics_dict.values())
+             if not np.isfinite(total_losses_reduced):
+                 raise FloatingPointError(
+                     f"Loss became infinite or NaN at iteration={cur_iter}!\n"
+                     f"loss_dict = {metrics_dict}"
+                 )
+
+             storage.put_scalar(
+                 "{}total_loss".format(prefix), total_losses_reduced, cur_iter=cur_iter
+             )
+             if len(metrics_dict) > 1:
+                 storage.put_scalars(cur_iter=cur_iter, **metrics_dict)
+
+     def state_dict(self):
+         ret = super().state_dict()
+         ret["optimizer"] = self.optimizer.state_dict()
+         return ret
+
+     def load_state_dict(self, state_dict):
+         super().load_state_dict(state_dict)
+         self.optimizer.load_state_dict(state_dict["optimizer"])
+
+     def after_train(self):
+         super().after_train()
+         self.concurrent_executor.shutdown(wait=True)
+
+
+ class AMPTrainer(SimpleTrainer):
+     """
+     Like :class:`SimpleTrainer`, but uses PyTorch's native automatic mixed precision
+     in the training loop.
+     """
+
+     def __init__(
+         self,
+         model,
+         data_loader,
+         optimizer,
+         gather_metric_period=1,
+         zero_grad_before_forward=False,
+         grad_scaler=None,
+         precision: torch.dtype = torch.float16,
+         log_grad_scaler: bool = False,
+         async_write_metrics=False,
+     ):
+         """
+         Args:
+             model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward,
+                 async_write_metrics: same as in :class:`SimpleTrainer`.
+             grad_scaler: torch GradScaler to automatically scale gradients.
+             precision: torch.dtype as the target precision to cast to in computations
+         """
+         unsupported = "AMPTrainer does not support single-process multi-device training!"
+         if isinstance(model, DistributedDataParallel):
+             assert not (model.device_ids and len(model.device_ids) > 1), unsupported
+         assert not isinstance(model, DataParallel), unsupported
+
+         super().__init__(
+             model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward
+         )
+
+         if grad_scaler is None:
+             from torch.cuda.amp import GradScaler
+
+             grad_scaler = GradScaler()
+         self.grad_scaler = grad_scaler
+         self.precision = precision
+         self.log_grad_scaler = log_grad_scaler
+
+     def run_step(self):
+         """
+         Implement the AMP training logic.
+         """
+         assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
+         assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
+         from torch.cuda.amp import autocast
+
+         start = time.perf_counter()
+         data = next(self._data_loader_iter)
+         data_time = time.perf_counter() - start
+
+         if self.zero_grad_before_forward:
+             self.optimizer.zero_grad()
+         with autocast(dtype=self.precision):
+             loss_dict = self.model(data)
+             if isinstance(loss_dict, torch.Tensor):
+                 losses = loss_dict
+                 loss_dict = {"total_loss": loss_dict}
+             else:
+                 losses = sum(loss_dict.values())
+
+         if not self.zero_grad_before_forward:
+             self.optimizer.zero_grad()
+
+         self.grad_scaler.scale(losses).backward()
+
+         if self.log_grad_scaler:
+             storage = get_event_storage()
+             storage.put_scalar("[metric]grad_scaler", self.grad_scaler.get_scale())
+
+         self.after_backward()
+
+         if self.async_write_metrics:
+             # write metrics asynchronously
+             self.concurrent_executor.submit(
+                 self._write_metrics, loss_dict, data_time, iter=self.iter
+             )
+         else:
+             self._write_metrics(loss_dict, data_time)
+
+         self.grad_scaler.step(self.optimizer)
+         self.grad_scaler.update()
+
+     def state_dict(self):
+         ret = super().state_dict()
+         ret["grad_scaler"] = self.grad_scaler.state_dict()
+         return ret
+
+     def load_state_dict(self, state_dict):
+         super().load_state_dict(state_dict)
+         self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
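
To make the contract above concrete, here is a hedged, self-contained sketch of `SimpleTrainer` driving a toy regression model; the model and data stream are stand-ins invented for illustration, not part of this file:

import torch
from torch import nn
from detectron2.engine import SimpleTrainer, hooks


class ToyModel(nn.Module):
    # Trainers expect the model to return a dict of scalar losses.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, batch):
        x, y = batch
        return {"loss_mse": nn.functional.mse_loss(self.linear(x), y)}


def data_stream():
    # An infinite iterable, since run_step() calls next() every iteration.
    while True:
        x = torch.randn(8, 4)
        yield x, x.sum(dim=1, keepdim=True)


model = ToyModel()
trainer = SimpleTrainer(model, data_stream(), torch.optim.SGD(model.parameters(), lr=0.01))
trainer.register_hooks([hooks.CallbackHook(after_step=lambda t: None)])  # no-op example hook
trainer.train(start_iter=0, max_iter=20)
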
CatVTON/detectron2/modeling/__init__.py ADDED
@@ -0,0 +1,64 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from detectron2.layers import ShapeSpec
+
+ from .anchor_generator import build_anchor_generator, ANCHOR_GENERATOR_REGISTRY
+ from .backbone import (
+     BACKBONE_REGISTRY,
+     FPN,
+     Backbone,
+     ResNet,
+     ResNetBlockBase,
+     build_backbone,
+     build_resnet_backbone,
+     make_stage,
+     ViT,
+     SimpleFeaturePyramid,
+     get_vit_lr_decay_rate,
+     MViT,
+     SwinTransformer,
+ )
+ from .meta_arch import (
+     META_ARCH_REGISTRY,
+     SEM_SEG_HEADS_REGISTRY,
+     GeneralizedRCNN,
+     PanopticFPN,
+     ProposalNetwork,
+     RetinaNet,
+     SemanticSegmentor,
+     build_model,
+     build_sem_seg_head,
+     FCOS,
+ )
+ from .postprocessing import detector_postprocess
+ from .proposal_generator import (
+     PROPOSAL_GENERATOR_REGISTRY,
+     build_proposal_generator,
+     RPN_HEAD_REGISTRY,
+     build_rpn_head,
+ )
+ from .roi_heads import (
+     ROI_BOX_HEAD_REGISTRY,
+     ROI_HEADS_REGISTRY,
+     ROI_KEYPOINT_HEAD_REGISTRY,
+     ROI_MASK_HEAD_REGISTRY,
+     ROIHeads,
+     StandardROIHeads,
+     BaseMaskRCNNHead,
+     BaseKeypointRCNNHead,
+     FastRCNNOutputLayers,
+     build_box_head,
+     build_keypoint_head,
+     build_mask_head,
+     build_roi_heads,
+ )
+ from .test_time_augmentation import DatasetMapperTTA, GeneralizedRCNNWithTTA
+ from .mmdet_wrapper import MMDetBackbone, MMDetDetector
+
+ _EXCLUDE = {"ShapeSpec"}
+ __all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
+
+
+ from detectron2.utils.env import fixup_module_metadata
+
+ fixup_module_metadata(__name__, globals(), __all__)
+ del fixup_module_metadata
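
As a quick orientation to these entry points, a sketch of building a model purely from a config; the model-zoo config path is an assumption for illustration:

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.merge_from_file(
    model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
)
cfg.MODEL.DEVICE = "cpu"  # avoid requiring a GPU for this sketch
model = build_model(cfg)  # a GeneralizedRCNN instance, still randomly initialized
model.eval()
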
CatVTON/detectron2/modeling/anchor_generator.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import collections
3
+ import math
4
+ from typing import List
5
+ import torch
6
+ from torch import nn
7
+
8
+ from detectron2.config import configurable
9
+ from detectron2.layers import ShapeSpec, move_device_like
10
+ from detectron2.structures import Boxes, RotatedBoxes
11
+ from detectron2.utils.registry import Registry
12
+
13
+ ANCHOR_GENERATOR_REGISTRY = Registry("ANCHOR_GENERATOR")
14
+ ANCHOR_GENERATOR_REGISTRY.__doc__ = """
15
+ Registry for modules that creates object detection anchors for feature maps.
16
+
17
+ The registered object will be called with `obj(cfg, input_shape)`.
18
+ """
19
+
20
+
21
+ class BufferList(nn.Module):
22
+ """
23
+ Similar to nn.ParameterList, but for buffers
24
+ """
25
+
26
+ def __init__(self, buffers):
27
+ super().__init__()
28
+ for i, buffer in enumerate(buffers):
29
+ # Use non-persistent buffer so the values are not saved in checkpoint
30
+ self.register_buffer(str(i), buffer, persistent=False)
31
+
32
+ def __len__(self):
33
+ return len(self._buffers)
34
+
35
+ def __iter__(self):
36
+ return iter(self._buffers.values())
37
+
38
+
39
+ def _create_grid_offsets(
40
+ size: List[int], stride: int, offset: float, target_device_tensor: torch.Tensor
41
+ ):
42
+ grid_height, grid_width = size
43
+ shifts_x = move_device_like(
44
+ torch.arange(offset * stride, grid_width * stride, step=stride, dtype=torch.float32),
45
+ target_device_tensor,
46
+ )
47
+ shifts_y = move_device_like(
48
+ torch.arange(offset * stride, grid_height * stride, step=stride, dtype=torch.float32),
49
+ target_device_tensor,
50
+ )
51
+
52
+ shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
53
+ shift_x = shift_x.reshape(-1)
54
+ shift_y = shift_y.reshape(-1)
55
+ return shift_x, shift_y
56
+
57
+
58
+ def _broadcast_params(params, num_features, name):
59
+ """
60
+ If one size (or aspect ratio) is specified and there are multiple feature
61
+ maps, we "broadcast" anchors of that single size (or aspect ratio)
62
+ over all feature maps.
63
+
64
+ If params is list[float], or list[list[float]] with len(params) == 1, repeat
65
+ it num_features time.
66
+
67
+ Returns:
68
+ list[list[float]]: param for each feature
69
+ """
70
+ assert isinstance(
71
+ params, collections.abc.Sequence
72
+ ), f"{name} in anchor generator has to be a list! Got {params}."
73
+ assert len(params), f"{name} in anchor generator cannot be empty!"
74
+ if not isinstance(params[0], collections.abc.Sequence): # params is list[float]
75
+ return [params] * num_features
76
+ if len(params) == 1:
77
+ return list(params) * num_features
78
+ assert len(params) == num_features, (
79
+ f"Got {name} of length {len(params)} in anchor generator, "
80
+ f"but the number of input features is {num_features}!"
81
+ )
82
+ return params
83
+
+
+@ANCHOR_GENERATOR_REGISTRY.register()
+class DefaultAnchorGenerator(nn.Module):
+    """
+    Compute anchors in the standard ways described in
+    "Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks".
+    """
+
+    box_dim: torch.jit.Final[int] = 4
+    """
+    the dimension of each anchor box.
+    """
+
+    @configurable
+    def __init__(self, *, sizes, aspect_ratios, strides, offset=0.5):
+        """
+        This interface is experimental.
+
+        Args:
+            sizes (list[list[float]] or list[float]):
+                If ``sizes`` is list[list[float]], ``sizes[i]`` is the list of anchor sizes
+                (i.e. sqrt of anchor area) to use for the i-th feature map.
+                If ``sizes`` is list[float], ``sizes`` is used for all feature maps.
+                Anchor sizes are given in absolute lengths in units of
+                the input image; they do not dynamically scale if the input image size changes.
+            aspect_ratios (list[list[float]] or list[float]): list of aspect ratios
+                (i.e. height / width) to use for anchors. Same "broadcast" rule for `sizes` applies.
+            strides (list[int]): stride of each input feature.
+            offset (float): Relative offset between the center of the first anchor and the top-left
+                corner of the image. Value has to be in [0, 1).
+                0.5 is recommended, which means half a stride.
+        """
+        super().__init__()
+
+        self.strides = strides
+        self.num_features = len(self.strides)
+        sizes = _broadcast_params(sizes, self.num_features, "sizes")
+        aspect_ratios = _broadcast_params(aspect_ratios, self.num_features, "aspect_ratios")
+        self.cell_anchors = self._calculate_anchors(sizes, aspect_ratios)
+
+        self.offset = offset
+        assert 0.0 <= self.offset < 1.0, self.offset
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: List[ShapeSpec]):
+        return {
+            "sizes": cfg.MODEL.ANCHOR_GENERATOR.SIZES,
+            "aspect_ratios": cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS,
+            "strides": [x.stride for x in input_shape],
+            "offset": cfg.MODEL.ANCHOR_GENERATOR.OFFSET,
+        }
+
+    def _calculate_anchors(self, sizes, aspect_ratios):
+        cell_anchors = [
+            self.generate_cell_anchors(s, a).float() for s, a in zip(sizes, aspect_ratios)
+        ]
+        return BufferList(cell_anchors)
+
+    @property
+    @torch.jit.unused
+    def num_cell_anchors(self):
+        """
+        Alias of `num_anchors`.
+        """
+        return self.num_anchors
+
+    @property
+    @torch.jit.unused
+    def num_anchors(self):
+        """
+        Returns:
+            list[int]: Each int is the number of anchors at every pixel
+                location, on that feature map.
+                For example, if at every pixel we use anchors of 3 aspect
+                ratios and 5 sizes, the number of anchors is 15.
+                (See also ANCHOR_GENERATOR.SIZES and ANCHOR_GENERATOR.ASPECT_RATIOS in config)
+
+                In standard RPN models, `num_anchors` on every feature map is the same.
+        """
+        return [len(cell_anchors) for cell_anchors in self.cell_anchors]
+
+    def _grid_anchors(self, grid_sizes: List[List[int]]):
+        """
+        Returns:
+            list[Tensor]: #featuremap tensors, each is (#locations x #cell_anchors) x 4
+        """
+        anchors = []
+        # buffers() not supported by torchscript. use named_buffers() instead
+        buffers: List[torch.Tensor] = [x[1] for x in self.cell_anchors.named_buffers()]
+        for size, stride, base_anchors in zip(grid_sizes, self.strides, buffers):
+            shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors)
+            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
+
+            anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
+
+        return anchors
+
+    def generate_cell_anchors(self, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)):
+        """
+        Generate a tensor storing canonical anchor boxes, which are all anchor
+        boxes of different sizes and aspect_ratios centered at (0, 0).
+        We can later build the set of anchors for a full feature map by
+        shifting and tiling these tensors (see `meth:_grid_anchors`).
+
+        Args:
+            sizes (tuple[float]):
+            aspect_ratios (tuple[float]):
+
+        Returns:
+            Tensor of shape (len(sizes) * len(aspect_ratios), 4) storing anchor boxes
+                in XYXY format.
+        """
+
+        # This is different from the anchor generator defined in the original Faster R-CNN
+        # code or Detectron. They yield the same AP, however the old version defines cell
+        # anchors in a less natural way with a shift relative to the feature grid and
+        # quantization that results in slightly different sizes for different aspect ratios.
+        # See also https://github.com/facebookresearch/Detectron/issues/227
+
+        anchors = []
+        for size in sizes:
+            area = size**2.0
+            for aspect_ratio in aspect_ratios:
+                # s * s = w * h
+                # a = h / w
+                # ... some algebra ...
+                # w = sqrt(s * s / a)
+                # h = a * w
+                w = math.sqrt(area / aspect_ratio)
+                h = aspect_ratio * w
+                x0, y0, x1, y1 = -w / 2.0, -h / 2.0, w / 2.0, h / 2.0
+                anchors.append([x0, y0, x1, y1])
+        return torch.tensor(anchors)
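+# Worked example of the width/height algebra above (derived from the code, rounded):
+# for size=64 and aspect_ratio=0.5: area=4096, w=sqrt(4096/0.5)≈90.5, h=0.5*90.5≈45.3,
+# giving the XYXY cell anchor (-45.3, -22.6, 45.3, 22.6), centered at (0, 0).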
+
+    def forward(self, features: List[torch.Tensor]):
+        """
+        Args:
+            features (list[Tensor]): list of backbone feature maps on which to generate anchors.
+
+        Returns:
+            list[Boxes]: a list of Boxes containing all the anchors for each feature map
+                (i.e. the cell anchors repeated over all locations in the feature map).
+                The number of anchors of each feature map is Hi x Wi x num_cell_anchors,
+                where Hi, Wi are resolution of the feature map divided by anchor stride.
+        """
+        grid_sizes = [feature_map.shape[-2:] for feature_map in features]
+        anchors_over_all_feature_maps = self._grid_anchors(grid_sizes)  # pyre-ignore
+        return [Boxes(x) for x in anchors_over_all_feature_maps]
+
+
+@ANCHOR_GENERATOR_REGISTRY.register()
+class RotatedAnchorGenerator(nn.Module):
+    """
+    Compute rotated anchors used by Rotated RPN (RRPN), described in
+    "Arbitrary-Oriented Scene Text Detection via Rotation Proposals".
+    """
+
+    box_dim: int = 5
+    """
+    the dimension of each anchor box.
+    """
+
+    @configurable
+    def __init__(self, *, sizes, aspect_ratios, strides, angles, offset=0.5):
+        """
+        This interface is experimental.
+
+        Args:
+            sizes (list[list[float]] or list[float]):
+                If sizes is list[list[float]], sizes[i] is the list of anchor sizes
+                (i.e. sqrt of anchor area) to use for the i-th feature map.
+                If sizes is list[float], the sizes are used for all feature maps.
+                Anchor sizes are given in absolute lengths in units of
+                the input image; they do not dynamically scale if the input image size changes.
+            aspect_ratios (list[list[float]] or list[float]): list of aspect ratios
+                (i.e. height / width) to use for anchors. Same "broadcast" rule for `sizes` applies.
+            strides (list[int]): stride of each input feature.
+            angles (list[list[float]] or list[float]): list of angles (in degrees CCW)
+                to use for anchors. Same "broadcast" rule for `sizes` applies.
+            offset (float): Relative offset between the center of the first anchor and the top-left
+                corner of the image. Value has to be in [0, 1).
+                0.5 is recommended, which means half a stride.
+        """
+        super().__init__()
+
+        self.strides = strides
+        self.num_features = len(self.strides)
+        sizes = _broadcast_params(sizes, self.num_features, "sizes")
+        aspect_ratios = _broadcast_params(aspect_ratios, self.num_features, "aspect_ratios")
+        angles = _broadcast_params(angles, self.num_features, "angles")
+        self.cell_anchors = self._calculate_anchors(sizes, aspect_ratios, angles)
+
+        self.offset = offset
+        assert 0.0 <= self.offset < 1.0, self.offset
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: List[ShapeSpec]):
+        return {
+            "sizes": cfg.MODEL.ANCHOR_GENERATOR.SIZES,
+            "aspect_ratios": cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS,
+            "strides": [x.stride for x in input_shape],
+            "offset": cfg.MODEL.ANCHOR_GENERATOR.OFFSET,
+            "angles": cfg.MODEL.ANCHOR_GENERATOR.ANGLES,
+        }
+
+    def _calculate_anchors(self, sizes, aspect_ratios, angles):
+        cell_anchors = [
+            self.generate_cell_anchors(size, aspect_ratio, angle).float()
+            for size, aspect_ratio, angle in zip(sizes, aspect_ratios, angles)
+        ]
+        return BufferList(cell_anchors)
+
+    @property
+    def num_cell_anchors(self):
+        """
+        Alias of `num_anchors`.
+        """
+        return self.num_anchors
+
+    @property
+    def num_anchors(self):
+        """
+        Returns:
+            list[int]: Each int is the number of anchors at every pixel
+                location, on that feature map.
+                For example, if at every pixel we use anchors of 3 aspect
+                ratios, 2 sizes and 5 angles, the number of anchors is 30.
+                (See also ANCHOR_GENERATOR.SIZES, ANCHOR_GENERATOR.ASPECT_RATIOS
+                and ANCHOR_GENERATOR.ANGLES in config)
+
+                In standard RRPN models, `num_anchors` on every feature map is the same.
+        """
+        return [len(cell_anchors) for cell_anchors in self.cell_anchors]
+
+    def _grid_anchors(self, grid_sizes: List[List[int]]):
+        anchors = []
+        for size, stride, base_anchors in zip(
+            grid_sizes,
+            self.strides,
+            self.cell_anchors._buffers.values(),
+        ):
+            shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors)
+            zeros = torch.zeros_like(shift_x)
+            shifts = torch.stack((shift_x, shift_y, zeros, zeros, zeros), dim=1)
+
+            anchors.append((shifts.view(-1, 1, 5) + base_anchors.view(1, -1, 5)).reshape(-1, 5))
+
+        return anchors
+
+    def generate_cell_anchors(
+        self,
+        sizes=(32, 64, 128, 256, 512),
+        aspect_ratios=(0.5, 1, 2),
+        angles=(-90, -60, -30, 0, 30, 60, 90),
+    ):
+        """
+        Generate a tensor storing canonical anchor boxes, which are all anchor
+        boxes of different sizes, aspect_ratios, angles centered at (0, 0).
+        We can later build the set of anchors for a full feature map by
+        shifting and tiling these tensors (see `meth:_grid_anchors`).
+
+        Args:
+            sizes (tuple[float]):
+            aspect_ratios (tuple[float]):
+            angles (tuple[float]):
+
+        Returns:
+            Tensor of shape (len(sizes) * len(aspect_ratios) * len(angles), 5)
+                storing anchor boxes in (x_ctr, y_ctr, w, h, angle) format.
+        """
+        anchors = []
+        for size in sizes:
+            area = size**2.0
+            for aspect_ratio in aspect_ratios:
+                # s * s = w * h
+                # a = h / w
+                # ... some algebra ...
+                # w = sqrt(s * s / a)
+                # h = a * w
+                w = math.sqrt(area / aspect_ratio)
+                h = aspect_ratio * w
+                anchors.extend([0, 0, w, h, a] for a in angles)
+
+        return torch.tensor(anchors)
+
+    def forward(self, features):
+        """
+        Args:
+            features (list[Tensor]): list of backbone feature maps on which to generate anchors.
+
+        Returns:
+            list[RotatedBoxes]: a list of Boxes containing all the anchors for each feature map
+                (i.e. the cell anchors repeated over all locations in the feature map).
+                The number of anchors of each feature map is Hi x Wi x num_cell_anchors,
+                where Hi, Wi are resolution of the feature map divided by anchor stride.
+        """
+        grid_sizes = [feature_map.shape[-2:] for feature_map in features]
+        anchors_over_all_feature_maps = self._grid_anchors(grid_sizes)
+        return [RotatedBoxes(x) for x in anchors_over_all_feature_maps]
+
+
+def build_anchor_generator(cfg, input_shape):
+    """
+    Build an anchor generator from `cfg.MODEL.ANCHOR_GENERATOR.NAME`.
+    """
+    anchor_generator = cfg.MODEL.ANCHOR_GENERATOR.NAME
+    return ANCHOR_GENERATOR_REGISTRY.get(anchor_generator)(cfg, input_shape)
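+# Minimal usage sketch (hypothetical shapes, for illustration only):
+#   gen = DefaultAnchorGenerator(sizes=[32, 64], aspect_ratios=[0.5, 1.0], strides=[4, 8])
+#   anchors = gen([torch.zeros(1, 256, 200, 304), torch.zeros(1, 256, 100, 152)])
+#   # -> list of 2 Boxes with 200*304*4 and 100*152*4 anchors respectively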
CatVTON/detectron2/modeling/box_regression.py ADDED
@@ -0,0 +1,369 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import math
+from typing import List, Tuple, Union
+import torch
+from fvcore.nn import giou_loss, smooth_l1_loss
+from torch.nn import functional as F
+
+from detectron2.layers import cat, ciou_loss, diou_loss
+from detectron2.structures import Boxes
+
+# Value for clamping large dw and dh predictions. The heuristic is that we clamp
+# such that dw and dh are no larger than what would transform a 16px box into a
+# 1000px box (based on a small anchor, 16px, and a typical image size, 1000px).
+_DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16)
+
+
+__all__ = ["Box2BoxTransform", "Box2BoxTransformRotated", "Box2BoxTransformLinear"]
+
+
+@torch.jit.script
+class Box2BoxTransform:
+    """
+    The box-to-box transform defined in R-CNN. The transformation is parameterized
+    by 4 deltas: (dx, dy, dw, dh). The transformation scales the box's width and height
+    by exp(dw), exp(dh) and shifts a box's center by the offset (dx * width, dy * height).
+    """
+
+    def __init__(
+        self, weights: Tuple[float, float, float, float], scale_clamp: float = _DEFAULT_SCALE_CLAMP
+    ):
+        """
+        Args:
+            weights (4-element tuple): Scaling factors that are applied to the
+                (dx, dy, dw, dh) deltas. In Fast R-CNN, these were originally set
+                such that the deltas have unit variance; now they are treated as
+                hyperparameters of the system.
+            scale_clamp (float): When predicting deltas, the predicted box scaling
+                factors (dw and dh) are clamped such that they are <= scale_clamp.
+        """
+        self.weights = weights
+        self.scale_clamp = scale_clamp
+
+    def get_deltas(self, src_boxes, target_boxes):
+        """
+        Get box regression transformation deltas (dx, dy, dw, dh) that can be used
+        to transform the `src_boxes` into the `target_boxes`. That is, the relation
+        ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless
+        any delta is too large and is clamped).
+
+        Args:
+            src_boxes (Tensor): source boxes, e.g., object proposals
+            target_boxes (Tensor): target of the transformation, e.g., ground-truth
+                boxes.
+        """
+        assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
+        assert isinstance(target_boxes, torch.Tensor), type(target_boxes)
+
+        src_widths = src_boxes[:, 2] - src_boxes[:, 0]
+        src_heights = src_boxes[:, 3] - src_boxes[:, 1]
+        src_ctr_x = src_boxes[:, 0] + 0.5 * src_widths
+        src_ctr_y = src_boxes[:, 1] + 0.5 * src_heights
+
+        target_widths = target_boxes[:, 2] - target_boxes[:, 0]
+        target_heights = target_boxes[:, 3] - target_boxes[:, 1]
+        target_ctr_x = target_boxes[:, 0] + 0.5 * target_widths
+        target_ctr_y = target_boxes[:, 1] + 0.5 * target_heights
+
+        wx, wy, ww, wh = self.weights
+        dx = wx * (target_ctr_x - src_ctr_x) / src_widths
+        dy = wy * (target_ctr_y - src_ctr_y) / src_heights
+        dw = ww * torch.log(target_widths / src_widths)
+        dh = wh * torch.log(target_heights / src_heights)
+
+        deltas = torch.stack((dx, dy, dw, dh), dim=1)
+        assert (src_widths > 0).all().item(), "Input boxes to Box2BoxTransform are not valid!"
+        return deltas
+
+    def apply_deltas(self, deltas, boxes):
+        """
+        Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.
+
+        Args:
+            deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
+                deltas[i] represents k potentially different class-specific
+                box transformations for the single box boxes[i].
+            boxes (Tensor): boxes to transform, of shape (N, 4)
+        """
+        deltas = deltas.float()  # ensure fp32 for decoding precision
+        boxes = boxes.to(deltas.dtype)
+
+        widths = boxes[:, 2] - boxes[:, 0]
+        heights = boxes[:, 3] - boxes[:, 1]
+        ctr_x = boxes[:, 0] + 0.5 * widths
+        ctr_y = boxes[:, 1] + 0.5 * heights
+
+        wx, wy, ww, wh = self.weights
+        dx = deltas[:, 0::4] / wx
+        dy = deltas[:, 1::4] / wy
+        dw = deltas[:, 2::4] / ww
+        dh = deltas[:, 3::4] / wh
+
+        # Prevent sending too large values into torch.exp()
+        dw = torch.clamp(dw, max=self.scale_clamp)
+        dh = torch.clamp(dh, max=self.scale_clamp)
+
+        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+        pred_w = torch.exp(dw) * widths[:, None]
+        pred_h = torch.exp(dh) * heights[:, None]
+
+        x1 = pred_ctr_x - 0.5 * pred_w
+        y1 = pred_ctr_y - 0.5 * pred_h
+        x2 = pred_ctr_x + 0.5 * pred_w
+        y2 = pred_ctr_y + 0.5 * pred_h
+        pred_boxes = torch.stack((x1, y1, x2, y2), dim=-1)
+        return pred_boxes.reshape(deltas.shape)
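+# Round-trip sketch (derived from the two methods above; weights (1, 1, 1, 1) for clarity):
+#   tfm = Box2BoxTransform(weights=(1.0, 1.0, 1.0, 1.0))
+#   src = torch.tensor([[0.0, 0.0, 10.0, 10.0]]); tgt = torch.tensor([[5.0, 5.0, 25.0, 25.0]])
+#   tfm.get_deltas(src, tgt)                         # -> [[1.0, 1.0, log(2), log(2)]]
+#   tfm.apply_deltas(tfm.get_deltas(src, tgt), src)  # recovers tgt (up to clamping)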
+
+
+@torch.jit.script
+class Box2BoxTransformRotated:
+    """
+    The box-to-box transform defined in Rotated R-CNN. The transformation is parameterized
+    by 5 deltas: (dx, dy, dw, dh, da). The transformation scales the box's width and height
+    by exp(dw), exp(dh), shifts a box's center by the offset (dx * width, dy * height),
+    and rotates a box's angle by da (radians).
+    Note: angles of deltas are in radians while angles of boxes are in degrees.
+    """
+
+    def __init__(
+        self,
+        weights: Tuple[float, float, float, float, float],
+        scale_clamp: float = _DEFAULT_SCALE_CLAMP,
+    ):
+        """
+        Args:
+            weights (5-element tuple): Scaling factors that are applied to the
+                (dx, dy, dw, dh, da) deltas. These are treated as
+                hyperparameters of the system.
+            scale_clamp (float): When predicting deltas, the predicted box scaling
+                factors (dw and dh) are clamped such that they are <= scale_clamp.
+        """
+        self.weights = weights
+        self.scale_clamp = scale_clamp
+
+    def get_deltas(self, src_boxes, target_boxes):
+        """
+        Get box regression transformation deltas (dx, dy, dw, dh, da) that can be used
+        to transform the `src_boxes` into the `target_boxes`. That is, the relation
+        ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless
+        any delta is too large and is clamped).
+
+        Args:
+            src_boxes (Tensor): Nx5 source boxes, e.g., object proposals
+            target_boxes (Tensor): Nx5 target of the transformation, e.g., ground-truth
+                boxes.
+        """
+        assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
+        assert isinstance(target_boxes, torch.Tensor), type(target_boxes)
+
+        src_ctr_x, src_ctr_y, src_widths, src_heights, src_angles = torch.unbind(src_boxes, dim=1)
+
+        target_ctr_x, target_ctr_y, target_widths, target_heights, target_angles = torch.unbind(
+            target_boxes, dim=1
+        )
+
+        wx, wy, ww, wh, wa = self.weights
+        dx = wx * (target_ctr_x - src_ctr_x) / src_widths
+        dy = wy * (target_ctr_y - src_ctr_y) / src_heights
+        dw = ww * torch.log(target_widths / src_widths)
+        dh = wh * torch.log(target_heights / src_heights)
+        # Angles of deltas are in radians while angles of boxes are in degrees.
+        # The conversion to radians serves as a way to normalize the values.
+        da = target_angles - src_angles
+        da = (da + 180.0) % 360.0 - 180.0  # make it in [-180, 180)
+        da *= wa * math.pi / 180.0
+
+        deltas = torch.stack((dx, dy, dw, dh, da), dim=1)
+        assert (
+            (src_widths > 0).all().item()
+        ), "Input boxes to Box2BoxTransformRotated are not valid!"
+        return deltas
+
+    def apply_deltas(self, deltas, boxes):
+        """
+        Apply transformation `deltas` (dx, dy, dw, dh, da) to `boxes`.
+
+        Args:
+            deltas (Tensor): transformation deltas of shape (N, k*5).
+                deltas[i] represents box transformation for the single box boxes[i].
+            boxes (Tensor): boxes to transform, of shape (N, 5)
+        """
+        assert deltas.shape[1] % 5 == 0 and boxes.shape[1] == 5
+
+        boxes = boxes.to(deltas.dtype).unsqueeze(2)
+
+        ctr_x = boxes[:, 0]
+        ctr_y = boxes[:, 1]
+        widths = boxes[:, 2]
+        heights = boxes[:, 3]
+        angles = boxes[:, 4]
+
+        wx, wy, ww, wh, wa = self.weights
+
+        dx = deltas[:, 0::5] / wx
+        dy = deltas[:, 1::5] / wy
+        dw = deltas[:, 2::5] / ww
+        dh = deltas[:, 3::5] / wh
+        da = deltas[:, 4::5] / wa
+
+        # Prevent sending too large values into torch.exp()
+        dw = torch.clamp(dw, max=self.scale_clamp)
+        dh = torch.clamp(dh, max=self.scale_clamp)
+
+        pred_boxes = torch.zeros_like(deltas)
+        pred_boxes[:, 0::5] = dx * widths + ctr_x  # x_ctr
+        pred_boxes[:, 1::5] = dy * heights + ctr_y  # y_ctr
+        pred_boxes[:, 2::5] = torch.exp(dw) * widths  # width
+        pred_boxes[:, 3::5] = torch.exp(dh) * heights  # height
+
+        # Following original RRPN implementation,
+        # angles of deltas are in radians while angles of boxes are in degrees.
+        pred_angle = da * 180.0 / math.pi + angles
+        pred_angle = (pred_angle + 180.0) % 360.0 - 180.0  # make it in [-180, 180)
+
+        pred_boxes[:, 4::5] = pred_angle
+
+        return pred_boxes
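+# Angle-wrapping sketch (follows the modular arithmetic above): a raw difference of
+# 350 degrees maps to (350 + 180) % 360 - 180 = -10 degrees, so the delta always
+# encodes the shortest rotation between source and target boxes.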
+
+
+class Box2BoxTransformLinear:
+    """
+    The linear box-to-box transform defined in FCOS. The transformation is parameterized
+    by the distance from the center of (square) src box to 4 edges of the target box.
+    """
+
+    def __init__(self, normalize_by_size=True):
+        """
+        Args:
+            normalize_by_size: normalize deltas by the size of src (anchor) boxes.
+        """
+        self.normalize_by_size = normalize_by_size
+
+    def get_deltas(self, src_boxes, target_boxes):
+        """
+        Get box regression transformation deltas (dx1, dy1, dx2, dy2) that can be used
+        to transform the `src_boxes` into the `target_boxes`. That is, the relation
+        ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true.
+        The center of src must be inside target boxes.
+
+        Args:
+            src_boxes (Tensor): square source boxes, e.g., anchors
+            target_boxes (Tensor): target of the transformation, e.g., ground-truth
+                boxes.
+        """
+        assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
+        assert isinstance(target_boxes, torch.Tensor), type(target_boxes)
+
+        src_ctr_x = 0.5 * (src_boxes[:, 0] + src_boxes[:, 2])
+        src_ctr_y = 0.5 * (src_boxes[:, 1] + src_boxes[:, 3])
+
+        target_l = src_ctr_x - target_boxes[:, 0]
+        target_t = src_ctr_y - target_boxes[:, 1]
+        target_r = target_boxes[:, 2] - src_ctr_x
+        target_b = target_boxes[:, 3] - src_ctr_y
+
+        deltas = torch.stack((target_l, target_t, target_r, target_b), dim=1)
+        if self.normalize_by_size:
+            stride_w = src_boxes[:, 2] - src_boxes[:, 0]
+            stride_h = src_boxes[:, 3] - src_boxes[:, 1]
+            strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1)
+            deltas = deltas / strides
+
+        return deltas
+
+    def apply_deltas(self, deltas, boxes):
+        """
+        Apply transformation `deltas` (dx1, dy1, dx2, dy2) to `boxes`.
+
+        Args:
+            deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
+                deltas[i] represents k potentially different class-specific
+                box transformations for the single box boxes[i].
+            boxes (Tensor): boxes to transform, of shape (N, 4)
+        """
+        # Ensure the output is a valid box. See Sec 2.1 of https://arxiv.org/abs/2006.09214
+        deltas = F.relu(deltas)
+        boxes = boxes.to(deltas.dtype)
+
+        ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2])
+        ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3])
+        if self.normalize_by_size:
+            stride_w = boxes[:, 2] - boxes[:, 0]
+            stride_h = boxes[:, 3] - boxes[:, 1]
+            strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1)
+            deltas = deltas * strides
+
+        l = deltas[:, 0::4]
+        t = deltas[:, 1::4]
+        r = deltas[:, 2::4]
+        b = deltas[:, 3::4]
+
+        pred_boxes = torch.zeros_like(deltas)
+        pred_boxes[:, 0::4] = ctr_x[:, None] - l  # x1
+        pred_boxes[:, 1::4] = ctr_y[:, None] - t  # y1
+        pred_boxes[:, 2::4] = ctr_x[:, None] + r  # x2
+        pred_boxes[:, 3::4] = ctr_y[:, None] + b  # y2
+        return pred_boxes
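+# FCOS-style encoding sketch (derived from the code above, normalize_by_size=False):
+# an anchor centered at (8, 8) and a target box (2, 4, 20, 18) encode as the
+# center-to-edge distances (l, t, r, b) = (6, 4, 12, 10).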
+
+
+def _dense_box_regression_loss(
+    anchors: List[Union[Boxes, torch.Tensor]],
+    box2box_transform: Box2BoxTransform,
+    pred_anchor_deltas: List[torch.Tensor],
+    gt_boxes: List[torch.Tensor],
+    fg_mask: torch.Tensor,
+    box_reg_loss_type="smooth_l1",
+    smooth_l1_beta=0.0,
+):
+    """
+    Compute loss for dense multi-level box regression.
+    Loss is accumulated over ``fg_mask``.
+
+    Args:
+        anchors: #lvl anchor boxes, each is (HixWixA, 4)
+        pred_anchor_deltas: #lvl predictions, each is (N, HixWixA, 4)
+        gt_boxes: N ground truth boxes, each has shape (R, 4) (R = sum(Hi * Wi * A))
+        fg_mask: the foreground boolean mask of shape (N, R) to compute loss on
+        box_reg_loss_type (str): Loss type to use. Supported losses: "smooth_l1", "giou",
+            "diou", "ciou".
+        smooth_l1_beta (float): beta parameter for the smooth L1 regression loss. Default to
+            use L1 loss. Only used when `box_reg_loss_type` is "smooth_l1"
+    """
+    if isinstance(anchors[0], Boxes):
+        anchors = type(anchors[0]).cat(anchors).tensor  # (R, 4)
+    else:
+        anchors = cat(anchors)
+    if box_reg_loss_type == "smooth_l1":
+        gt_anchor_deltas = [box2box_transform.get_deltas(anchors, k) for k in gt_boxes]
+        gt_anchor_deltas = torch.stack(gt_anchor_deltas)  # (N, R, 4)
+        loss_box_reg = smooth_l1_loss(
+            cat(pred_anchor_deltas, dim=1)[fg_mask],
+            gt_anchor_deltas[fg_mask],
+            beta=smooth_l1_beta,
+            reduction="sum",
+        )
+    elif box_reg_loss_type == "giou":
+        pred_boxes = [
+            box2box_transform.apply_deltas(k, anchors) for k in cat(pred_anchor_deltas, dim=1)
+        ]
+        loss_box_reg = giou_loss(
+            torch.stack(pred_boxes)[fg_mask], torch.stack(gt_boxes)[fg_mask], reduction="sum"
+        )
+    elif box_reg_loss_type == "diou":
+        pred_boxes = [
+            box2box_transform.apply_deltas(k, anchors) for k in cat(pred_anchor_deltas, dim=1)
+        ]
+        loss_box_reg = diou_loss(
+            torch.stack(pred_boxes)[fg_mask], torch.stack(gt_boxes)[fg_mask], reduction="sum"
+        )
+    elif box_reg_loss_type == "ciou":
+        pred_boxes = [
+            box2box_transform.apply_deltas(k, anchors) for k in cat(pred_anchor_deltas, dim=1)
+        ]
+        loss_box_reg = ciou_loss(
+            torch.stack(pred_boxes)[fg_mask], torch.stack(gt_boxes)[fg_mask], reduction="sum"
+        )
+    else:
+        raise ValueError(f"Invalid dense box regression loss type '{box_reg_loss_type}'")
+    return loss_box_reg
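+# Call sketch (hypothetical shapes; one feature level, batch of N=2 images):
+#   anchors: [Boxes of shape (R, 4)]; pred_anchor_deltas: [(2, R, 4)];
+#   gt_boxes: two (R, 4) tensors; fg_mask: (2, R) bool.
+# The "smooth_l1" branch compares encoded deltas directly, while the IoU-based
+# branches ("giou"/"diou"/"ciou") first decode predictions back into boxes.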
CatVTON/detectron2/modeling/matcher.py ADDED
@@ -0,0 +1,127 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from typing import List
+import torch
+
+from detectron2.layers import nonzero_tuple
+
+
+# TODO: the name is too general
+class Matcher:
+    """
+    This class assigns to each predicted "element" (e.g., a box) a ground-truth
+    element. Each predicted element will have exactly zero or one matches; each
+    ground-truth element may be matched to zero or more predicted elements.
+
+    The matching is determined by the MxN match_quality_matrix, that characterizes
+    how well each (ground-truth, prediction)-pair matches each other. For example,
+    if the elements are boxes, this matrix may contain box intersection-over-union
+    overlap values.
+
+    The matcher returns (a) a vector of length N containing the index of the
+    ground-truth element m in [0, M) that matches to prediction n in [0, N).
+    (b) a vector of length N containing the labels for each prediction.
+    """
+
+    def __init__(
+        self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False
+    ):
+        """
+        Args:
+            thresholds (list): a list of thresholds used to stratify predictions
+                into levels.
+            labels (list): a list of values to label predictions belonging at
+                each level. A label can be one of {-1, 0, 1} signifying
+                {ignore, negative class, positive class}, respectively.
+            allow_low_quality_matches (bool): if True, produce additional matches
+                for predictions with maximum match quality lower than high_threshold.
+                See set_low_quality_matches_ for more details.
+
+            For example,
+                thresholds = [0.3, 0.5]
+                labels = [0, -1, 1]
+                All predictions with iou < 0.3 will be marked with 0 and
+                thus will be considered as false positives while training.
+                All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
+                thus will be ignored.
+                All predictions with 0.5 <= iou will be marked with 1 and
+                thus will be considered as true positives.
+        """
+        # Add -inf and +inf to first and last position in thresholds
+        thresholds = thresholds[:]
+        assert thresholds[0] > 0
+        thresholds.insert(0, -float("inf"))
+        thresholds.append(float("inf"))
+        # Currently torchscript does not support all + generator
+        assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])])
+        assert all([l in [-1, 0, 1] for l in labels])
+        assert len(labels) == len(thresholds) - 1
+        self.thresholds = thresholds
+        self.labels = labels
+        self.allow_low_quality_matches = allow_low_quality_matches
+
+    def __call__(self, match_quality_matrix):
+        """
+        Args:
+            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
+                pairwise quality between M ground-truth elements and N predicted
+                elements. All elements must be >= 0 (due to the use of `torch.nonzero`
+                for selecting indices in :meth:`set_low_quality_matches_`).
+
+        Returns:
+            matches (Tensor[int64]): a vector of length N, where matches[i] is a matched
+                ground-truth index in [0, M)
+            match_labels (Tensor[int8]): a vector of length N, where pred_labels[i] indicates
+                whether a prediction is a true or false positive or ignored
+        """
+        assert match_quality_matrix.dim() == 2
+        if match_quality_matrix.numel() == 0:
+            default_matches = match_quality_matrix.new_full(
+                (match_quality_matrix.size(1),), 0, dtype=torch.int64
+            )
+            # When no gt boxes exist, we define IOU = 0 and therefore set labels
+            # to `self.labels[0]`, which usually defaults to background class 0
+            # To choose to ignore instead, can make labels=[-1,0,-1,1] + set appropriate thresholds
+            default_match_labels = match_quality_matrix.new_full(
+                (match_quality_matrix.size(1),), self.labels[0], dtype=torch.int8
+            )
+            return default_matches, default_match_labels
+
+        assert torch.all(match_quality_matrix >= 0)
+
+        # match_quality_matrix is M (gt) x N (predicted)
+        # Max over gt elements (dim 0) to find best gt candidate for each prediction
+        matched_vals, matches = match_quality_matrix.max(dim=0)
+
+        match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
+
+        for l, low, high in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
+            low_high = (matched_vals >= low) & (matched_vals < high)
+            match_labels[low_high] = l
+
+        if self.allow_low_quality_matches:
+            self.set_low_quality_matches_(match_labels, match_quality_matrix)
+
+        return matches, match_labels
+
+    def set_low_quality_matches_(self, match_labels, match_quality_matrix):
+        """
+        Produce additional matches for predictions that have only low-quality matches.
+        Specifically, for each ground-truth G find the set of predictions that have
+        maximum overlap with it (including ties); for each prediction in that set, if
+        it is unmatched, then match it to the ground-truth G.
+
+        This function implements the RPN assignment case (i) in Sec. 3.1.2 of
+        :paper:`Faster R-CNN`.
+        """
+        # For each gt, find the prediction with which it has highest quality
+        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
+        # Find the highest quality match available, even if it is low, including ties.
+        # Note that the match qualities must be positive due to the use of
+        # `torch.nonzero`.
+        _, pred_inds_with_highest_quality = nonzero_tuple(
+            match_quality_matrix == highest_quality_foreach_gt[:, None]
+        )
+        # If an anchor was labeled positive only due to a low-quality match
+        # with gt_A, but it has larger overlap with gt_B, its matched index will still be gt_B.
+        # This follows the implementation in Detectron, and is found to have no significant impact.
+        match_labels[pred_inds_with_highest_quality] = 1
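+# Worked example (derived from __call__ above, using the docstring's thresholds):
+#   matcher = Matcher(thresholds=[0.3, 0.5], labels=[0, -1, 1])
+#   iou = torch.tensor([[0.9, 0.4, 0.1]])  # 1 gt x 3 predictions
+#   matcher(iou)  # -> matches [0, 0, 0], match_labels [1, -1, 0]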
CatVTON/detectron2/modeling/poolers.py ADDED
@@ -0,0 +1,263 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import math
+from typing import List, Optional
+import torch
+from torch import nn
+from torchvision.ops import RoIPool
+
+from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor
+from detectron2.structures import Boxes
+from detectron2.utils.tracing import assert_fx_safe, is_fx_tracing
+
+"""
+To export ROIPooler to torchscript, in this file, variables that should be annotated with
+`Union[List[Boxes], List[RotatedBoxes]]` are only annotated with `List[Boxes]`.
+
+TODO: Correct these annotations when torchscript supports `Union`.
+https://github.com/pytorch/pytorch/issues/41412
+"""
+
+__all__ = ["ROIPooler"]
+
+
+def assign_boxes_to_levels(
+    box_lists: List[Boxes],
+    min_level: int,
+    max_level: int,
+    canonical_box_size: int,
+    canonical_level: int,
+):
+    """
+    Map each box in `box_lists` to a feature map level index and return the assignment
+    vector.
+
+    Args:
+        box_lists (list[Boxes] | list[RotatedBoxes]): A list of N Boxes or N RotatedBoxes,
+            where N is the number of images in the batch.
+        min_level (int): Smallest feature map level index. The input is considered index 0,
+            the output of stage 1 is index 1, and so on.
+        max_level (int): Largest feature map level index.
+        canonical_box_size (int): A canonical box size in pixels (sqrt(box area)).
+        canonical_level (int): The feature map level index on which a canonically-sized box
+            should be placed.
+
+    Returns:
+        A tensor of length M, where M is the total number of boxes aggregated over all
+            N batch images. The memory layout corresponds to the concatenation of boxes
+            from all images. Each element is the feature map index, as an offset from
+            `self.min_level`, for the corresponding box (so value i means the box is at
+            `self.min_level + i`).
+    """
+    box_sizes = torch.sqrt(cat([boxes.area() for boxes in box_lists]))
+    # Eqn.(1) in FPN paper
+    level_assignments = torch.floor(
+        canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8)
+    )
+    # clamp level to (min, max), in case the box size is too large or too small
+    # for the available feature maps
+    level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
+    return level_assignments.to(torch.int64) - min_level
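+# Worked example of Eqn.(1) (canonical_box_size=224, canonical_level=4, levels 2..5):
+# a 448-px box gives floor(4 + log2(448/224)) = 5, a 112-px box gives
+# floor(4 + log2(112/224)) = 3, and anything smaller/larger is clamped into [2, 5].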
+
+
+# script the module to avoid hardcoded device type
+@torch.jit.script_if_tracing
+def _convert_boxes_to_pooler_format(boxes: torch.Tensor, sizes: torch.Tensor) -> torch.Tensor:
+    sizes = sizes.to(device=boxes.device)
+    indices = torch.repeat_interleave(
+        torch.arange(len(sizes), dtype=boxes.dtype, device=boxes.device), sizes
+    )
+    return cat([indices[:, None], boxes], dim=1)
+
+
+def convert_boxes_to_pooler_format(box_lists: List[Boxes]):
+    """
+    Convert all boxes in `box_lists` to the low-level format used by ROI pooling ops
+    (see description under Returns).
+
+    Args:
+        box_lists (list[Boxes] | list[RotatedBoxes]):
+            A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch.
+
+    Returns:
+        When input is list[Boxes]:
+            A tensor of shape (M, 5), where M is the total number of boxes aggregated over all
+            N batch images.
+            The 5 columns are (batch index, x0, y0, x1, y1), where batch index
+            is the index in [0, N) identifying which batch image the box with corners at
+            (x0, y0, x1, y1) comes from.
+        When input is list[RotatedBoxes]:
+            A tensor of shape (M, 6), where M is the total number of boxes aggregated over all
+            N batch images.
+            The 6 columns are (batch index, x_ctr, y_ctr, width, height, angle_degrees),
+            where batch index is the index in [0, N) identifying which batch image the
+            rotated box (x_ctr, y_ctr, width, height, angle_degrees) comes from.
+    """
+    boxes = torch.cat([x.tensor for x in box_lists], dim=0)
+    # __len__ returns Tensor in tracing.
+    sizes = shapes_to_tensor([x.__len__() for x in box_lists])
+    return _convert_boxes_to_pooler_format(boxes, sizes)
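+# Shape sketch (derived from the functions above): for a batch of 2 images with
+# 3 and 1 boxes respectively, the output is a (4, 5) tensor whose first column is
+# the batch index, i.e. [0, 0, 0, 1], followed by each box's (x0, y0, x1, y1).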
+
+
+@torch.jit.script_if_tracing
+def _create_zeros(
+    batch_target: Optional[torch.Tensor],
+    channels: int,
+    height: int,
+    width: int,
+    like_tensor: torch.Tensor,
+) -> torch.Tensor:
+    batches = batch_target.shape[0] if batch_target is not None else 0
+    sizes = (batches, channels, height, width)
+    return torch.zeros(sizes, dtype=like_tensor.dtype, device=like_tensor.device)
+
+
+class ROIPooler(nn.Module):
+    """
+    Region of interest feature map pooler that supports pooling from one or more
+    feature maps.
+    """
+
+    def __init__(
+        self,
+        output_size,
+        scales,
+        sampling_ratio,
+        pooler_type,
+        canonical_box_size=224,
+        canonical_level=4,
+    ):
+        """
+        Args:
+            output_size (int, tuple[int] or list[int]): output size of the pooled region,
+                e.g., 14 x 14. If tuple or list is given, the length must be 2.
+            scales (list[float]): The scale for each low-level pooling op relative to
+                the input image. For a feature map with stride s relative to the input
+                image, scale is defined as 1/s. The stride must be a power of 2.
+                When there are multiple scales, they must form a pyramid, i.e. they must be
+                a monotonically decreasing geometric sequence with a factor of 1/2.
+            sampling_ratio (int): The `sampling_ratio` parameter for the ROIAlign op.
+            pooler_type (string): Name of the type of pooling operation that should be applied.
+                For instance, "ROIPool" or "ROIAlignV2".
+            canonical_box_size (int): A canonical box size in pixels (sqrt(box area)). The default
+                is heuristically defined as 224 pixels in the FPN paper (based on ImageNet
+                pre-training).
+            canonical_level (int): The feature map level index from which a canonically-sized box
+                should be placed. The default is defined as level 4 (stride=16) in the FPN paper,
+                i.e., a box of size 224x224 will be placed on the feature with stride=16.
+                The box placement for all boxes will be determined from their sizes w.r.t
+                canonical_box_size. For example, a box whose area is 4x that of a canonical box
+                should be used to pool features from feature level ``canonical_level+1``.
+
+                Note that the actual input feature maps given to this module may not have
+                sufficiently many levels for the input boxes. If the boxes are too large or too
+                small for the input feature maps, the closest level will be used.
+        """
+        super().__init__()
+
+        if isinstance(output_size, int):
+            output_size = (output_size, output_size)
+        assert len(output_size) == 2
+        assert isinstance(output_size[0], int) and isinstance(output_size[1], int)
+        self.output_size = output_size
+
+        if pooler_type == "ROIAlign":
+            self.level_poolers = nn.ModuleList(
+                ROIAlign(
+                    output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=False
+                )
+                for scale in scales
+            )
+        elif pooler_type == "ROIAlignV2":
+            self.level_poolers = nn.ModuleList(
+                ROIAlign(
+                    output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=True
+                )
+                for scale in scales
+            )
+        elif pooler_type == "ROIPool":
+            self.level_poolers = nn.ModuleList(
+                RoIPool(output_size, spatial_scale=scale) for scale in scales
+            )
+        elif pooler_type == "ROIAlignRotated":
+            self.level_poolers = nn.ModuleList(
+                ROIAlignRotated(output_size, spatial_scale=scale, sampling_ratio=sampling_ratio)
+                for scale in scales
+            )
+        else:
+            raise ValueError("Unknown pooler type: {}".format(pooler_type))
+
+        # Map scale (defined as 1 / stride) to its feature map level under the
+        # assumption that stride is a power of 2.
+        min_level = -(math.log2(scales[0]))
+        max_level = -(math.log2(scales[-1]))
+        assert math.isclose(min_level, int(min_level)) and math.isclose(
+            max_level, int(max_level)
+        ), "Featuremap stride is not power of 2!"
+        self.min_level = int(min_level)
+        self.max_level = int(max_level)
+        assert (
+            len(scales) == self.max_level - self.min_level + 1
+        ), "[ROIPooler] Sizes of input featuremaps do not form a pyramid!"
+        assert 0 <= self.min_level and self.min_level <= self.max_level
+        self.canonical_level = canonical_level
+        assert canonical_box_size > 0
+        self.canonical_box_size = canonical_box_size
+
+    def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
+        """
+        Args:
+            x (list[Tensor]): A list of feature maps of NCHW shape, with scales matching those
+                used to construct this module.
+            box_lists (list[Boxes] | list[RotatedBoxes]):
+                A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch.
+                The box coordinates are defined on the original image and
+                will be scaled by the `scales` argument of :class:`ROIPooler`.
+
+        Returns:
+            Tensor:
+                A tensor of shape (M, C, output_size, output_size) where M is the total number of
+                boxes aggregated over all N batch images and C is the number of channels in `x`.
+        """
+        num_level_assignments = len(self.level_poolers)
+
+        if not is_fx_tracing():
+            torch._assert(
+                isinstance(x, list) and isinstance(box_lists, list),
+                "Arguments to pooler must be lists",
+            )
+        assert_fx_safe(
+            len(x) == num_level_assignments,
+            "unequal value, num_level_assignments={}, but x is list of {} Tensors".format(
+                num_level_assignments, len(x)
+            ),
+        )
+        assert_fx_safe(
+            len(box_lists) == x[0].size(0),
+            "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format(
+                x[0].size(0), len(box_lists)
+            ),
+        )
+        if len(box_lists) == 0:
+            return _create_zeros(None, x[0].shape[1], *self.output_size, x[0])
+
+        pooler_fmt_boxes = convert_boxes_to_pooler_format(box_lists)
+
+        if num_level_assignments == 1:
+            return self.level_poolers[0](x[0], pooler_fmt_boxes)
+
+        level_assignments = assign_boxes_to_levels(
+            box_lists, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level
+        )
+
+        num_channels = x[0].shape[1]
+        output_size = self.output_size[0]
+
+        output = _create_zeros(pooler_fmt_boxes, num_channels, output_size, output_size, x[0])
+
+        for level, pooler in enumerate(self.level_poolers):
+            inds = nonzero_tuple(level_assignments == level)[0]
+            pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
+            # Use index_put_ instead of advance indexing, to avoid pytorch/issues/49852
+            output.index_put_((inds,), pooler(x[level], pooler_fmt_boxes_level))
+
+        return output
CatVTON/detectron2/projects/README.md ADDED
@@ -0,0 +1,2 @@
+
+Projects live in the [`projects` directory](../../projects) under the root of this repository, but not here.
CatVTON/detectron2/projects/__init__.py ADDED
@@ -0,0 +1,34 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import importlib.abc
+import importlib.util
+from pathlib import Path
+
+__all__ = []
+
+_PROJECTS = {
+    "point_rend": "PointRend",
+    "deeplab": "DeepLab",
+    "panoptic_deeplab": "Panoptic-DeepLab",
+}
+_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent / "projects"
+
+if _PROJECT_ROOT.is_dir():
+    # This is true only for in-place installation (pip install -e, setup.py develop),
+    # where setup(package_dir=) does not work: https://github.com/pypa/setuptools/issues/230
+
+    class _D2ProjectsFinder(importlib.abc.MetaPathFinder):
+        def find_spec(self, name, path, target=None):
+            if not name.startswith("detectron2.projects."):
+                return
+            project_name = name.split(".")[-1]
+            project_dir = _PROJECTS.get(project_name)
+            if not project_dir:
+                return
+            target_file = _PROJECT_ROOT / f"{project_dir}/{project_name}/__init__.py"
+            if not target_file.is_file():
+                return
+            return importlib.util.spec_from_file_location(name, target_file)
+
+    import sys
+
+    sys.meta_path.append(_D2ProjectsFinder())
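+# Effect sketch (derived from the finder above): with an in-place install,
+#   import detectron2.projects.point_rend
+# resolves to projects/PointRend/point_rend/__init__.py via this meta path hook,
+# even though that package does not physically live under detectron2/projects/.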
CatVTON/detectron2/solver/__init__.py ADDED
@@ -0,0 +1,11 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .build import build_lr_scheduler, build_optimizer, get_default_optimizer_params
+from .lr_scheduler import (
+    LRMultiplier,
+    LRScheduler,
+    WarmupCosineLR,
+    WarmupMultiStepLR,
+    WarmupParamScheduler,
+)
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
CatVTON/detectron2/solver/build.py ADDED
@@ -0,0 +1,323 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import itertools
+import logging
+from collections import defaultdict
+from enum import Enum
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
+import torch
+from fvcore.common.param_scheduler import (
+    CosineParamScheduler,
+    MultiStepParamScheduler,
+    StepWithFixedGammaParamScheduler,
+)
+
+from detectron2.config import CfgNode
+from detectron2.utils.env import TORCH_VERSION
+
+from .lr_scheduler import LRMultiplier, LRScheduler, WarmupParamScheduler
+
+_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
+_GradientClipper = Callable[[_GradientClipperInput], None]
+
+
+class GradientClipType(Enum):
+    VALUE = "value"
+    NORM = "norm"
+
+
+def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
+    """
+    Creates a gradient clipping closure to clip by value or by norm,
+    according to the provided config.
+    """
+    cfg = copy.deepcopy(cfg)
+
+    def clip_grad_norm(p: _GradientClipperInput):
+        torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE)
+
+    def clip_grad_value(p: _GradientClipperInput):
+        torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE)
+
+    _GRADIENT_CLIP_TYPE_TO_CLIPPER = {
+        GradientClipType.VALUE: clip_grad_value,
+        GradientClipType.NORM: clip_grad_norm,
+    }
+    return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)]
+
+
+def _generate_optimizer_class_with_gradient_clipping(
+    optimizer: Type[torch.optim.Optimizer],
+    *,
+    per_param_clipper: Optional[_GradientClipper] = None,
+    global_clipper: Optional[_GradientClipper] = None,
+) -> Type[torch.optim.Optimizer]:
+    """
+    Dynamically creates a new type that inherits the type of a given instance
+    and overrides the `step` method to add gradient clipping.
+    """
+    assert (
+        per_param_clipper is None or global_clipper is None
+    ), "Not allowed to use both per-parameter clipping and global clipping"
+
+    def optimizer_wgc_step(self, closure=None):
+        if per_param_clipper is not None:
+            for group in self.param_groups:
+                for p in group["params"]:
+                    per_param_clipper(p)
+        else:
+            # global clipper for future use with detr
+            # (https://github.com/facebookresearch/detr/pull/287)
+            all_params = itertools.chain(*[g["params"] for g in self.param_groups])
+            global_clipper(all_params)
+        super(type(self), self).step(closure)
+
+    OptimizerWithGradientClip = type(
+        optimizer.__name__ + "WithGradientClip",
+        (optimizer,),
+        {"step": optimizer_wgc_step},
+    )
+    return OptimizerWithGradientClip
+
+
+def maybe_add_gradient_clipping(
+    cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
+) -> Type[torch.optim.Optimizer]:
+    """
+    If gradient clipping is enabled through config options, wraps the existing
+    optimizer type to become a new dynamically created class OptimizerWithGradientClip
+    that inherits the given optimizer and overrides the `step` method to
+    include gradient clipping.
+
+    Args:
+        cfg: CfgNode, configuration options
+        optimizer: type. A subclass of torch.optim.Optimizer
+
+    Return:
+        type: either the input `optimizer` (if gradient clipping is disabled), or
+            a subclass of it with gradient clipping included in the `step` method.
+    """
+    if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
+        return optimizer
+    if isinstance(optimizer, torch.optim.Optimizer):
+        optimizer_type = type(optimizer)
+    else:
+        assert issubclass(optimizer, torch.optim.Optimizer), optimizer
+        optimizer_type = optimizer
+
+    grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
+    OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
+        optimizer_type, per_param_clipper=grad_clipper
+    )
+    if isinstance(optimizer, torch.optim.Optimizer):
+        optimizer.__class__ = OptimizerWithGradientClip  # a bit hacky, not recommended
+        return optimizer
+    else:
+        return OptimizerWithGradientClip
+
+
+def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
+    """
+    Build an optimizer from config.
+    """
+    params = get_default_optimizer_params(
+        model,
+        base_lr=cfg.SOLVER.BASE_LR,
+        weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
+        bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
+        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
+    )
+    sgd_args = {
+        "params": params,
+        "lr": cfg.SOLVER.BASE_LR,
+        "momentum": cfg.SOLVER.MOMENTUM,
+        "nesterov": cfg.SOLVER.NESTEROV,
+        "weight_decay": cfg.SOLVER.WEIGHT_DECAY,
+    }
+    if TORCH_VERSION >= (1, 12):
+        sgd_args["foreach"] = True
+    return maybe_add_gradient_clipping(cfg, torch.optim.SGD(**sgd_args))
+
+
+def get_default_optimizer_params(
+    model: torch.nn.Module,
+    base_lr: Optional[float] = None,
+    weight_decay: Optional[float] = None,
+    weight_decay_norm: Optional[float] = None,
+    bias_lr_factor: Optional[float] = 1.0,
+    weight_decay_bias: Optional[float] = None,
+    lr_factor_func: Optional[Callable] = None,
+    overrides: Optional[Dict[str, Dict[str, float]]] = None,
+) -> List[Dict[str, Any]]:
+    """
+    Get default param list for optimizer, with support for a few types of
+    overrides. If no overrides are needed, this is equivalent to `model.parameters()`.
+
+    Args:
+        base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
+        weight_decay: weight decay for every group by default. Can be omitted to use the one
+            in optimizer.
+        weight_decay_norm: override weight decay for params in normalization layers
+        bias_lr_factor: multiplier of lr for bias parameters.
+        weight_decay_bias: override weight decay for bias parameters.
+        lr_factor_func: function to calculate lr decay rate by mapping the parameter names to
+            corresponding lr decay rate. Note that setting this option requires
+            also setting ``base_lr``.
+        overrides: if not `None`, provides values for optimizer hyperparameters
+            (LR, weight decay) for module parameters with a given name; e.g.
+            ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
+            weight decay values for all module parameters named `embedding`.
+
+    For common detection models, ``weight_decay_norm`` is the only option
+    needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
+    from Detectron1 that are not found useful.
+
+    Example:
+    ::
+        torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
+                        lr=0.01, weight_decay=1e-4, momentum=0.9)
+    """
+    if overrides is None:
+        overrides = {}
+    defaults = {}
+    if base_lr is not None:
+        defaults["lr"] = base_lr
+    if weight_decay is not None:
+        defaults["weight_decay"] = weight_decay
+    bias_overrides = {}
+    if bias_lr_factor is not None and bias_lr_factor != 1.0:
+        # NOTE: unlike Detectron v1, we now by default make bias hyperparameters
+        # exactly the same as regular weights.
+        if base_lr is None:
+            raise ValueError("bias_lr_factor requires base_lr")
+        bias_overrides["lr"] = base_lr * bias_lr_factor
+    if weight_decay_bias is not None:
+        bias_overrides["weight_decay"] = weight_decay_bias
+    if len(bias_overrides):
+        if "bias" in overrides:
+            raise ValueError("Conflicting overrides for 'bias'")
+        overrides["bias"] = bias_overrides
+    if lr_factor_func is not None:
+        if base_lr is None:
+            raise ValueError("lr_factor_func requires base_lr")
+    norm_module_types = (
+        torch.nn.BatchNorm1d,
+        torch.nn.BatchNorm2d,
+        torch.nn.BatchNorm3d,
+        torch.nn.SyncBatchNorm,
+        # NaiveSyncBatchNorm inherits from BatchNorm2d
+        torch.nn.GroupNorm,
+        torch.nn.InstanceNorm1d,
+        torch.nn.InstanceNorm2d,
+        torch.nn.InstanceNorm3d,
+        torch.nn.LayerNorm,
+        torch.nn.LocalResponseNorm,
+    )
+    params: List[Dict[str, Any]] = []
+    memo: Set[torch.nn.parameter.Parameter] = set()
+    for module_name, module in model.named_modules():
+        for module_param_name, value in module.named_parameters(recurse=False):
+            if not value.requires_grad:
+                continue
+            # Avoid duplicating parameters
+            if value in memo:
+                continue
+            memo.add(value)
+
+            hyperparams = copy.copy(defaults)
+            if isinstance(module, norm_module_types) and weight_decay_norm is not None:
+                hyperparams["weight_decay"] = weight_decay_norm
+            if lr_factor_func is not None:
+                hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}")
+
+            hyperparams.update(overrides.get(module_param_name, {}))
+            params.append({"params": [value], **hyperparams})
+    return reduce_param_groups(params)
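+# Usage sketch (hypothetical values, mirroring the docstring example above):
+#   params = get_default_optimizer_params(
+#       model, base_lr=0.01, weight_decay_norm=0.0, bias_lr_factor=2.0
+#   )
+#   opt = torch.optim.SGD(params, lr=0.01, weight_decay=1e-4, momentum=0.9)
+# Norm-layer weights get weight_decay=0.0 and parameters named "bias" get lr=0.02.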
+
+
+def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    # Transform parameter groups into per-parameter structure.
+    # Later items in `params` can overwrite parameters set in previous items.
+    ret = defaultdict(dict)
+    for item in params:
+        assert "params" in item
+        cur_params = {x: y for x, y in item.items() if x != "params" and x != "param_names"}
+        if "param_names" in item:
+            for param_name, param in zip(item["param_names"], item["params"]):
+                ret[param].update({"param_names": [param_name], "params": [param], **cur_params})
+        else:
+            for param in item["params"]:
+                ret[param].update({"params": [param], **cur_params})
+    return list(ret.values())
+
+
+def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    # Reorganize the parameter groups and merge duplicated groups.
+    # The number of parameter groups needs to be as small as possible in order
+    # to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
+    # of using a parameter_group per single parameter, we reorganize the
+    # parameter groups and merge duplicated groups. This approach speeds
+    # up multi-tensor optimizer significantly.
+    params = _expand_param_groups(params)
+    groups = defaultdict(list)  # re-group all parameter groups by their hyperparams
+    for item in params:
+        cur_params = tuple((x, y) for x, y in item.items() if x != "params" and x != "param_names")
+        groups[cur_params].append({"params": item["params"]})
+        if "param_names" in item:
+            groups[cur_params][-1]["param_names"] = item["param_names"]
+
+    ret = []
+    for param_keys, param_values in groups.items():
+        cur = {kv[0]: kv[1] for kv in param_keys}
+        cur["params"] = list(
+            itertools.chain.from_iterable([params["params"] for params in param_values])
+        )
+        if len(param_values) > 0 and "param_names" in param_values[0]:
+            cur["param_names"] = list(
+                itertools.chain.from_iterable([params["param_names"] for params in param_values])
+            )
+        ret.append(cur)
+    return ret
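+# Merge sketch (derived from reduce_param_groups above): two per-parameter groups
+# [{"params": [p1], "lr": 0.01}, {"params": [p2], "lr": 0.01}] collapse into the
+# single group [{"lr": 0.01, "params": [p1, p2]}], so multi-tensor optimizers can
+# process all parameters with identical hyperparameters together.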
+
+
+def build_lr_scheduler(cfg: CfgNode, optimizer: torch.optim.Optimizer) -> LRScheduler:
+    """
+    Build a LR scheduler from config.
+    """
+    name = cfg.SOLVER.LR_SCHEDULER_NAME
+
+    if name == "WarmupMultiStepLR":
+        steps = [x for x in cfg.SOLVER.STEPS if x <= cfg.SOLVER.MAX_ITER]
+        if len(steps) != len(cfg.SOLVER.STEPS):
+            logger = logging.getLogger(__name__)
+            logger.warning(
+                "SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. "
+                "These values will be ignored."
+            )
+        sched = MultiStepParamScheduler(
+            values=[cfg.SOLVER.GAMMA**k for k in range(len(steps) + 1)],
+            milestones=steps,
+            num_updates=cfg.SOLVER.MAX_ITER,
+        )
+    elif name == "WarmupCosineLR":
+        end_value = cfg.SOLVER.BASE_LR_END / cfg.SOLVER.BASE_LR
+        assert end_value >= 0.0 and end_value <= 1.0, end_value
+        sched = CosineParamScheduler(1, end_value)
+    elif name == "WarmupStepWithFixedGammaLR":
+        sched = StepWithFixedGammaParamScheduler(
+            base_value=1.0,
+            gamma=cfg.SOLVER.GAMMA,
+            num_decays=cfg.SOLVER.NUM_DECAYS,
+            num_updates=cfg.SOLVER.MAX_ITER,
+        )
+    else:
+        raise ValueError("Unknown LR scheduler: {}".format(name))
+
+    sched = WarmupParamScheduler(
+        sched,
+        cfg.SOLVER.WARMUP_FACTOR,
+        min(cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, 1.0),
+        cfg.SOLVER.WARMUP_METHOD,
+        cfg.SOLVER.RESCALE_INTERVAL,
+    )
+    return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)
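A sketch of driving this builder from a config, assuming detectron2's default config keys (the values here are illustrative, and the optimizer is assumed to be built elsewhere):

    from detectron2.config import get_cfg

    cfg = get_cfg()
    cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"
    cfg.SOLVER.MAX_ITER = 90000
    cfg.SOLVER.STEPS = (60000, 80000)
    cfg.SOLVER.GAMMA = 0.1
    cfg.SOLVER.WARMUP_ITERS = 1000
    scheduler = build_lr_scheduler(cfg, optimizer)  # optimizer built elsewhere (assumption)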
CatVTON/detectron2/solver/lr_scheduler.py ADDED
@@ -0,0 +1,247 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import math
+from bisect import bisect_right
+from typing import List
+import torch
+from fvcore.common.param_scheduler import (
+    CompositeParamScheduler,
+    ConstantParamScheduler,
+    LinearParamScheduler,
+    ParamScheduler,
+)
+
+try:
+    from torch.optim.lr_scheduler import LRScheduler
+except ImportError:
+    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
+logger = logging.getLogger(__name__)
+
+
+class WarmupParamScheduler(CompositeParamScheduler):
+    """
+    Add an initial warmup stage to another scheduler.
+    """
+
+    def __init__(
+        self,
+        scheduler: ParamScheduler,
+        warmup_factor: float,
+        warmup_length: float,
+        warmup_method: str = "linear",
+        rescale_interval: bool = False,
+    ):
+        """
+        Args:
+            scheduler: warmup will be added at the beginning of this scheduler
+            warmup_factor: the factor w.r.t the initial value of ``scheduler``, e.g. 0.001
+            warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire
+                training, e.g. 0.01
+            warmup_method: one of "linear" or "constant"
+            rescale_interval: whether we will rescale the interval of the scheduler after
+                warmup
+        """
+        # the value to reach when warmup ends
+        end_value = scheduler(0.0) if rescale_interval else scheduler(warmup_length)
+        start_value = warmup_factor * scheduler(0.0)
+        if warmup_method == "constant":
+            warmup = ConstantParamScheduler(start_value)
+        elif warmup_method == "linear":
+            warmup = LinearParamScheduler(start_value, end_value)
+        else:
+            raise ValueError("Unknown warmup method: {}".format(warmup_method))
+        super().__init__(
+            [warmup, scheduler],
+            interval_scaling=["rescaled", "rescaled" if rescale_interval else "fixed"],
+            lengths=[warmup_length, 1 - warmup_length],
+        )
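For intuition, a sketch of the multiplier values this composition produces (the approximate numbers follow from the linear-warmup math above; MultiStepParamScheduler comes from fvcore):

    from fvcore.common.param_scheduler import MultiStepParamScheduler

    sched = WarmupParamScheduler(
        MultiStepParamScheduler([1, 0.1, 0.01], milestones=[60000, 80000], num_updates=90000),
        warmup_factor=0.001,
        warmup_length=0.01,
    )
    sched(0.0)    # ~0.001: start of warmup (warmup_factor * scheduler(0.0))
    sched(0.005)  # ~0.5:   halfway through the linear warmup
    sched(0.5)    # 1.0:    main schedule, before the first milestone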
+
+
+class LRMultiplier(LRScheduler):
+    """
+    An LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the
+    learning rate of each param in the optimizer.
+    Every step, the learning rate of each parameter becomes its initial value
+    multiplied by the output of the given :class:`ParamScheduler`.
+
+    The absolute learning rate value of each parameter can be different.
+    This scheduler can be used as long as the relative scale among them does
+    not change during training.
+
+    Examples:
+    ::
+        LRMultiplier(
+            opt,
+            WarmupParamScheduler(
+                MultiStepParamScheduler(
+                    [1, 0.1, 0.01],
+                    milestones=[60000, 80000],
+                    num_updates=90000,
+                ), 0.001, 100 / 90000
+            ),
+            max_iter=90000
+        )
+    """
+
+    # NOTES: in the most general case, every LR can use its own scheduler.
+    # Supporting this requires interaction with the optimizer when its parameter
+    # group is initialized. For example, classyvision implements its own optimizer
+    # that allows different schedulers for every parameter group.
+    # To avoid this complexity, we use this class to support the most common cases
+    # where the relative scale among all LRs stays unchanged during training. In this
+    # case we only need a total of one scheduler that defines the relative LR multiplier.
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        multiplier: ParamScheduler,
+        max_iter: int,
+        last_iter: int = -1,
+    ):
+        """
+        Args:
+            optimizer, last_iter: See ``torch.optim.lr_scheduler.LRScheduler``.
+                ``last_iter`` is the same as ``last_epoch``.
+            multiplier: an fvcore ParamScheduler that defines the multiplier on
+                every LR of the optimizer
+            max_iter: the total number of training iterations
+        """
+        if not isinstance(multiplier, ParamScheduler):
+            raise ValueError(
+                "LRMultiplier(multiplier=) must be an instance of fvcore "
+                f"ParamScheduler. Got {multiplier} instead."
+            )
+        self._multiplier = multiplier
+        self._max_iter = max_iter
+        super().__init__(optimizer, last_epoch=last_iter)
+
+    def state_dict(self):
+        # fvcore schedulers are stateless. Only keep pytorch scheduler states
+        return {"base_lrs": self.base_lrs, "last_epoch": self.last_epoch}
+
+    def get_lr(self) -> List[float]:
+        multiplier = self._multiplier(self.last_epoch / self._max_iter)
+        return [base_lr * multiplier for base_lr in self.base_lrs]
+
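A minimal training-loop sketch (compute_loss is a hypothetical stand-in); as with any PyTorch scheduler, step() is called once per iteration after optimizer.step():

    for _ in range(max_iter):
        loss = compute_loss()  # hypothetical
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advances last_epoch, i.e. the iteration count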
+
+"""
+Content below is no longer needed!
+"""
+
+# NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes
+# only on epoch boundaries. We typically use iteration based schedules instead.
+# As a result, "epoch" (e.g., as in self.last_epoch) should be understood to mean
+# "iteration" instead.
+
+# FIXME: ideally this would be achieved with a CombinedLRScheduler, separating
+# MultiStepLR from WarmupLR, but the current LRScheduler design doesn't allow it.
+
+
+class WarmupMultiStepLR(LRScheduler):
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        milestones: List[int],
+        gamma: float = 0.1,
+        warmup_factor: float = 0.001,
+        warmup_iters: int = 1000,
+        warmup_method: str = "linear",
+        last_epoch: int = -1,
+    ):
+        logger.warning(
+            "WarmupMultiStepLR is deprecated! Use LRMultiplier with fvcore ParamScheduler instead!"
+        )
+        if not list(milestones) == sorted(milestones):
+            raise ValueError(
+                "Milestones should be a list of increasing integers. Got {}".format(milestones)
+            )
+        self.milestones = milestones
+        self.gamma = gamma
+        self.warmup_factor = warmup_factor
+        self.warmup_iters = warmup_iters
+        self.warmup_method = warmup_method
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self) -> List[float]:
+        warmup_factor = _get_warmup_factor_at_iter(
+            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
+        )
+        return [
+            base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
+            for base_lr in self.base_lrs
+        ]
+
+    def _compute_values(self) -> List[float]:
+        # The new interface
+        return self.get_lr()
+
+
+class WarmupCosineLR(LRScheduler):
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        max_iters: int,
+        warmup_factor: float = 0.001,
+        warmup_iters: int = 1000,
+        warmup_method: str = "linear",
+        last_epoch: int = -1,
+    ):
+        logger.warning(
+            "WarmupCosineLR is deprecated! Use LRMultiplier with fvcore ParamScheduler instead!"
+        )
+        self.max_iters = max_iters
+        self.warmup_factor = warmup_factor
+        self.warmup_iters = warmup_iters
+        self.warmup_method = warmup_method
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self) -> List[float]:
+        warmup_factor = _get_warmup_factor_at_iter(
+            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
+        )
+        # Different definitions of half-cosine with warmup are possible. For
+        # simplicity we multiply the standard half-cosine schedule by the warmup
+        # factor. An alternative is to start the period of the cosine at warmup_iters
+        # instead of at 0. In the case that warmup_iters << max_iters the two are
+        # very close to each other.
+        return [
+            base_lr
+            * warmup_factor
+            * 0.5
+            * (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters))
+            for base_lr in self.base_lrs
+        ]
+
+    def _compute_values(self) -> List[float]:
+        # The new interface
+        return self.get_lr()
+
+
+def _get_warmup_factor_at_iter(
+    method: str, iter: int, warmup_iters: int, warmup_factor: float
+) -> float:
+    """
+    Return the learning rate warmup factor at a specific iteration.
+    See :paper:`ImageNet in 1h` for more details.
+
+    Args:
+        method (str): warmup method; either "constant" or "linear".
+        iter (int): iteration at which to calculate the warmup factor.
+        warmup_iters (int): the number of warmup iterations.
+        warmup_factor (float): the base warmup factor (the meaning changes according
+            to the method used).
+
+    Returns:
+        float: the effective warmup factor at the given iteration.
+    """
+    if iter >= warmup_iters:
+        return 1.0
+
+    if method == "constant":
+        return warmup_factor
+    elif method == "linear":
+        alpha = iter / warmup_iters
+        return warmup_factor * (1 - alpha) + alpha
+    else:
+        raise ValueError("Unknown warmup method: {}".format(method))
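Worked values for the branches above (the linear case interpolates from warmup_factor to 1, since warmup_factor * (1 - alpha) + alpha):

    _get_warmup_factor_at_iter("linear", 0, 1000, 0.001)      # 0.001
    _get_warmup_factor_at_iter("linear", 500, 1000, 0.001)    # 0.5005
    _get_warmup_factor_at_iter("linear", 1000, 1000, 0.001)   # 1.0 (warmup over)
    _get_warmup_factor_at_iter("constant", 500, 1000, 0.001)  # 0.001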