luoxue-star committed
Commit 1966925 · Parent: 7a90a56

Init Commit

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +4 -0
  2. README.md +11 -0
  3. amr/__init__.py +0 -0
  4. amr/__pycache__/__init__.cpython-310.pyc +0 -0
  5. amr/configs/__init__.py +112 -0
  6. amr/configs/__pycache__/__init__.cpython-310.pyc +0 -0
  7. amr/datasets/__init__.py +0 -0
  8. amr/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  9. amr/datasets/__pycache__/utils.cpython-310.pyc +0 -0
  10. amr/datasets/__pycache__/vitdet_dataset.cpython-310.pyc +0 -0
  11. amr/datasets/utils.py +1098 -0
  12. amr/datasets/vitdet_dataset.py +102 -0
  13. amr/models/__init__.py +26 -0
  14. amr/models/__pycache__/__init__.cpython-310.pyc +0 -0
  15. amr/models/__pycache__/__init__.cpython-312.pyc +0 -0
  16. amr/models/__pycache__/__init__.cpython-39.pyc +0 -0
  17. amr/models/__pycache__/amr.cpython-310.pyc +0 -0
  18. amr/models/__pycache__/amr.cpython-312.pyc +0 -0
  19. amr/models/__pycache__/amr.cpython-39.pyc +0 -0
  20. amr/models/__pycache__/animerpp.cpython-310.pyc +0 -0
  21. amr/models/__pycache__/animerpp.cpython-312.pyc +0 -0
  22. amr/models/__pycache__/aves_hmr.cpython-310.pyc +0 -0
  23. amr/models/__pycache__/aves_hmr.cpython-312.pyc +0 -0
  24. amr/models/__pycache__/aves_warapper.cpython-310.pyc +0 -0
  25. amr/models/__pycache__/aves_warapper.cpython-312.pyc +0 -0
  26. amr/models/__pycache__/discriminator.cpython-310.pyc +0 -0
  27. amr/models/__pycache__/discriminator.cpython-312.pyc +0 -0
  28. amr/models/__pycache__/discriminator.cpython-39.pyc +0 -0
  29. amr/models/__pycache__/dyamr.cpython-310.pyc +0 -0
  30. amr/models/__pycache__/dyamr.cpython-312.pyc +0 -0
  31. amr/models/__pycache__/losses.cpython-310.pyc +0 -0
  32. amr/models/__pycache__/losses.cpython-312.pyc +0 -0
  33. amr/models/__pycache__/losses.cpython-39.pyc +0 -0
  34. amr/models/__pycache__/predictor.cpython-310.pyc +0 -0
  35. amr/models/__pycache__/smal_warapper.cpython-310.pyc +0 -0
  36. amr/models/__pycache__/smal_warapper.cpython-312.pyc +0 -0
  37. amr/models/__pycache__/smal_warapper.cpython-39.pyc +0 -0
  38. amr/models/__pycache__/smooth_amr.cpython-310.pyc +0 -0
  39. amr/models/__pycache__/smooth_amr.cpython-312.pyc +0 -0
  40. amr/models/__pycache__/smooth_netv2.cpython-310.pyc +0 -0
  41. amr/models/__pycache__/stamr.cpython-310.pyc +0 -0
  42. amr/models/__pycache__/stamr.cpython-312.pyc +0 -0
  43. amr/models/animerpp.py +508 -0
  44. amr/models/aves_warapper.py +136 -0
  45. amr/models/backbones/__init__.py +9 -0
  46. amr/models/backbones/__pycache__/__init__.cpython-310.pyc +0 -0
  47. amr/models/backbones/__pycache__/__init__.cpython-312.pyc +0 -0
  48. amr/models/backbones/__pycache__/__init__.cpython-39.pyc +0 -0
  49. amr/models/backbones/__pycache__/rope_deit.cpython-310.pyc +0 -0
  50. amr/models/backbones/__pycache__/vit.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33   *.zip filter=lfs diff=lfs merge=lfs -text
34   *.zst filter=lfs diff=lfs merge=lfs -text
35   *tfevents* filter=lfs diff=lfs merge=lfs -text
36 + *.jpg filter=lfs diff=lfs merge=lfs -text
37 + *.png filter=lfs diff=lfs merge=lfs -text
38 + *.jpeg filter=lfs diff=lfs merge=lfs -text
39 + *.JPEG filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,5 +1,6 @@
1   ---
2   title: AniMerPlus
3 + <<<<<<< HEAD
4   emoji: 📊
5   colorFrom: gray
6   colorTo: indigo
@@ -7,6 +8,16 @@ sdk: gradio
8   sdk_version: 5.39.0
9   app_file: app.py
10  pinned: false
11 + =======
12 + emoji: 🔥
13 + colorFrom: pink
14 + colorTo: blue
15 + sdk: gradio
16 + sdk_version: 5.1.0
17 + app_file: app.py
18 + pinned: false
19 + license: mit
20 + >>>>>>> Init Commit
21  ---
22
23  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
amr/__init__.py ADDED
File without changes
amr/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (146 Bytes).
amr/configs/__init__.py ADDED
@@ -0,0 +1,112 @@
1
+ import os
2
+ from typing import Dict
3
+ from yacs.config import CfgNode as CN
4
+
5
+ CACHE_DIR_AniMer = "./_DATA"
6
+
7
+
8
+ def to_lower(x: Dict) -> Dict:
9
+ """
10
+ Convert all dictionary keys to lowercase
11
+ Args:
12
+ x (dict): Input dictionary
13
+ Returns:
14
+ dict: Output dictionary with all keys converted to lowercase
15
+ """
16
+ return {k.lower(): v for k, v in x.items()}
17
+
18
+
19
+ _C = CN(new_allowed=True)
20
+
21
+ _C.GENERAL = CN(new_allowed=True)
22
+ _C.GENERAL.RESUME = True
23
+ _C.GENERAL.TIME_TO_RUN = 3300
24
+ _C.GENERAL.VAL_STEPS = 100
25
+ _C.GENERAL.LOG_STEPS = 100
26
+ _C.GENERAL.CHECKPOINT_STEPS = 20000
27
+ _C.GENERAL.CHECKPOINT_DIR = "checkpoints"
28
+ _C.GENERAL.SUMMARY_DIR = "tensorboard"
29
+ _C.GENERAL.NUM_GPUS = 1
30
+ _C.GENERAL.NUM_WORKERS = 4
31
+ _C.GENERAL.MIXED_PRECISION = True
32
+ _C.GENERAL.ALLOW_CUDA = True
33
+ _C.GENERAL.PIN_MEMORY = False
34
+ _C.GENERAL.DISTRIBUTED = False
35
+ _C.GENERAL.LOCAL_RANK = 0
36
+ _C.GENERAL.USE_SYNCBN = False
37
+ _C.GENERAL.WORLD_SIZE = 1
38
+
39
+ _C.TRAIN = CN(new_allowed=True)
40
+ _C.TRAIN.NUM_EPOCHS = 100
41
+ _C.TRAIN.SHUFFLE = True
42
+ _C.TRAIN.WARMUP = False
43
+ _C.TRAIN.NORMALIZE_PER_IMAGE = False
44
+ _C.TRAIN.CLIP_GRAD = False
45
+ _C.TRAIN.CLIP_GRAD_VALUE = 1.0
46
+ _C.LOSS_WEIGHTS = CN(new_allowed=True)
47
+
48
+ _C.DATASETS = CN(new_allowed=True)
49
+
50
+ _C.MODEL = CN(new_allowed=True)
51
+ _C.MODEL.IMAGE_SIZE = 224
52
+
53
+ _C.EXTRA = CN(new_allowed=True)
54
+ _C.EXTRA.FOCAL_LENGTH = 5000
55
+
56
+ _C.DATASETS.CONFIG = CN(new_allowed=True)
57
+ _C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
58
+ _C.DATASETS.CONFIG.ROT_FACTOR = 30
59
+ _C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
60
+ _C.DATASETS.CONFIG.COLOR_SCALE = 0.2
61
+ _C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
62
+ _C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
63
+ _C.DATASETS.CONFIG.DO_FLIP = False
64
+ _C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
65
+ _C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
66
+
67
+
68
+ def default_config() -> CN:
69
+ """
70
+ Get a yacs CfgNode object with the default config values.
71
+ """
72
+ # Return a clone so that the defaults will not be altered
73
+ # This is for the "local variable" use pattern
74
+ return _C.clone()
75
+
76
+
77
+ def dataset_config() -> CN:
78
+ """
79
+ Get dataset config file
80
+ Returns:
81
+ CfgNode: Dataset config as a yacs CfgNode object.
82
+ """
83
+ cfg = CN(new_allowed=True)
84
+ config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets_tar.yaml')
85
+ cfg.merge_from_file(config_file)
86
+ cfg.freeze()
87
+ return cfg
88
+
89
+
90
+ def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN:
91
+ """
92
+ Read a config file and optionally merge it with the default config file.
93
+ Args:
94
+ config_file (str): Path to config file.
95
+ merge (bool): Whether to merge with the default config or not.
96
+ Returns:
97
+ CfgNode: Config as a yacs CfgNode object.
98
+ """
99
+ if merge:
100
+ cfg = default_config()
101
+ else:
102
+ cfg = CN(new_allowed=True)
103
+ cfg.merge_from_file(config_file)
104
+
105
+ if update_cachedir:
106
+ def update_path(path: str) -> str:
107
+ if os.path.isabs(path):
108
+ return path
109
+ return os.path.join(CACHE_DIR_AniMer, path)
110
+
111
+ cfg.freeze()
112
+ return cfg
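A minimal usage sketch of the config helpers defined above (illustrative only, not part of this commit; "my_experiment.yaml" is a placeholder filename):

# Illustrative sketch; assumes the amr package is importable and the YAML files exist.
from amr.configs import default_config, dataset_config, get_config

cfg = default_config()                   # clone of the built-in defaults
print(cfg.MODEL.IMAGE_SIZE)              # 224
print(cfg.EXTRA.FOCAL_LENGTH)            # 5000

# Merge an experiment YAML on top of the defaults; the returned config is frozen.
exp_cfg = get_config("my_experiment.yaml", merge=True)

# Reads amr/configs/datasets_tar.yaml, which must sit alongside this module.
data_cfg = dataset_config()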
amr/configs/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.21 kB).
amr/datasets/__init__.py ADDED
File without changes
amr/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (155 Bytes).
amr/datasets/__pycache__/utils.cpython-310.pyc ADDED
Binary file (36 kB).
amr/datasets/__pycache__/vitdet_dataset.cpython-310.pyc ADDED
Binary file (3.19 kB).
amr/datasets/utils.py ADDED
@@ -0,0 +1,1098 @@
1
+ """
2
+ Parts of the code are taken or adapted from
3
+ https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py
4
+ """
5
+ import torch
6
+ import numpy as np
7
+ from skimage.transform import rotate, resize
8
+ from skimage.filters import gaussian
9
+ import random
10
+ import cv2
11
+ from typing import List, Dict, Tuple
12
+ from yacs.config import CfgNode
13
+ from typing import Union
14
+
15
+
16
+ def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
17
+ """Increase the size of the bounding box to match the target shape."""
18
+ if target_aspect_ratio is None:
19
+ return input_shape
20
+
21
+ try:
22
+ w, h = input_shape
23
+ except (ValueError, TypeError):
24
+ return input_shape
25
+
26
+ w_t, h_t = target_aspect_ratio
27
+ if h / w < h_t / w_t:
28
+ h_new = max(w * h_t / w_t, h)
29
+ w_new = w
30
+ else:
31
+ h_new = h
32
+ w_new = max(h * w_t / h_t, w)
33
+ if h_new < h or w_new < w:
34
+ breakpoint()
35
+ return np.array([w_new, h_new])
36
+
37
+
38
+ def do_augmentation(aug_config: CfgNode) -> Tuple:
39
+ """
40
+ Compute random augmentation parameters.
41
+ Args:
42
+ aug_config (CfgNode): Config containing augmentation parameters.
43
+ Returns:
44
+ scale (float): Box rescaling factor.
45
+ rot (float): Random image rotation.
46
+ do_flip (bool): Whether to flip image or not.
47
+ do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT).
48
+ color_scale (List): Color rescaling factor
49
+ tx (float): Random translation along the x axis.
50
+ ty (float): Random translation along the y axis.
51
+ """
52
+
53
+ tx = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
54
+ ty = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
55
+ scale = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.SCALE_FACTOR + 1.0
56
+ rot = np.clip(np.random.randn(), -2.0,
57
+ 2.0) * aug_config.ROT_FACTOR if random.random() <= aug_config.ROT_AUG_RATE else 0
58
+ do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE
59
+ do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE
60
+ extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0)
61
+ # extreme_crop_lvl = 0
62
+ c_up = 1.0 + aug_config.COLOR_SCALE
63
+ c_low = 1.0 - aug_config.COLOR_SCALE
64
+ color_scale = [random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)]
65
+ return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty
66
+
67
+
68
+ def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array:
69
+ """
70
+ Rotate a 2D point on the x-y plane.
71
+ Args:
72
+ pt_2d (np.array): Input 2D point with shape (2,).
73
+ rot_rad (float): Rotation angle
74
+ Returns:
75
+ np.array: Rotated 2D point.
76
+ """
77
+ x = pt_2d[0]
78
+ y = pt_2d[1]
79
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
80
+ xx = x * cs - y * sn
81
+ yy = x * sn + y * cs
82
+ return np.array([xx, yy], dtype=np.float32)
83
+
84
+
85
+ def gen_trans_from_patch_cv(c_x: float, c_y: float,
86
+ src_width: float, src_height: float,
87
+ dst_width: float, dst_height: float,
88
+ scale: float, rot: float) -> np.array:
89
+ """
90
+ Create transformation matrix for the bounding box crop.
91
+ Args:
92
+ c_x (float): Bounding box center x coordinate in the original image.
93
+ c_y (float): Bounding box center y coordinate in the original image.
94
+ src_width (float): Bounding box width.
95
+ src_height (float): Bounding box height.
96
+ dst_width (float): Output box width.
97
+ dst_height (float): Output box height.
98
+ scale (float): Rescaling factor for the bounding box (augmentation).
99
+ rot (float): Random rotation applied to the box.
100
+ Returns:
101
+ trans (np.array): Target geometric transformation.
102
+ """
103
+ # augment size with scale
104
+ src_w = src_width * scale
105
+ src_h = src_height * scale
106
+ src_center = np.zeros(2)
107
+ src_center[0] = c_x
108
+ src_center[1] = c_y
109
+ # augment rotation
110
+ rot_rad = np.pi * rot / 180
111
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
112
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
113
+
114
+ dst_w = dst_width
115
+ dst_h = dst_height
116
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
117
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
118
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
119
+
120
+ src = np.zeros((3, 2), dtype=np.float32)
121
+ src[0, :] = src_center
122
+ src[1, :] = src_center + src_downdir
123
+ src[2, :] = src_center + src_rightdir
124
+
125
+ dst = np.zeros((3, 2), dtype=np.float32)
126
+ dst[0, :] = dst_center
127
+ dst[1, :] = dst_center + dst_downdir
128
+ dst[2, :] = dst_center + dst_rightdir
129
+
130
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
131
+
132
+ return trans
133
+
134
+
135
+ def trans_point2d(pt_2d: np.array, trans: np.array):
136
+ """
137
+ Transform a 2D point using translation matrix trans.
138
+ Args:
139
+ pt_2d (np.array): Input 2D point with shape (2,).
140
+ trans (np.array): Transformation matrix.
141
+ Returns:
142
+ np.array: Transformed 2D point.
143
+ """
144
+ src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T
145
+ dst_pt = np.dot(trans, src_pt)
146
+ return dst_pt[0:2]
147
+
148
+
149
+ def get_transform(center, scale, res, rot=0):
150
+ """Generate transformation matrix."""
151
+ """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
152
+ h = 200 * scale
153
+ t = np.zeros((3, 3))
154
+ t[0, 0] = float(res[1]) / h
155
+ t[1, 1] = float(res[0]) / h
156
+ t[0, 2] = res[1] * (-float(center[0]) / h + .5)
157
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
158
+ t[2, 2] = 1
159
+ if not rot == 0:
160
+ rot = -rot # To match direction of rotation from cropping
161
+ rot_mat = np.zeros((3, 3))
162
+ rot_rad = rot * np.pi / 180
163
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
164
+ rot_mat[0, :2] = [cs, -sn]
165
+ rot_mat[1, :2] = [sn, cs]
166
+ rot_mat[2, 2] = 1
167
+ # Need to rotate around center
168
+ t_mat = np.eye(3)
169
+ t_mat[0, 2] = -res[1] / 2
170
+ t_mat[1, 2] = -res[0] / 2
171
+ t_inv = t_mat.copy()
172
+ t_inv[:2, 2] *= -1
173
+ t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
174
+ return t
175
+
176
+
177
+ def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
178
+ """Transform pixel location to different reference."""
179
+ """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
180
+ t = get_transform(center, scale, res, rot=rot)
181
+ if invert:
182
+ t = np.linalg.inv(t)
183
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
184
+ new_pt = np.dot(t, new_pt)
185
+ if as_int:
186
+ new_pt = new_pt.astype(int)
187
+ return new_pt[:2] + 1
188
+
189
+
190
+ def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0):
191
+ c_x = (ul[0] + br[0]) / 2
192
+ c_y = (ul[1] + br[1]) / 2
193
+ bb_width = patch_width = br[0] - ul[0]
194
+ bb_height = patch_height = br[1] - ul[1]
195
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, 1.0, 0)
196
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
197
+ flags=cv2.INTER_LINEAR,
198
+ borderMode=border_mode,
199
+ borderValue=border_value
200
+ )
201
+
202
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
203
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
204
+ img_patch[:, :, 3] = cv2.warpAffine(img[:, :, 3], trans, (int(patch_width), int(patch_height)),
205
+ flags=cv2.INTER_LINEAR,
206
+ borderMode=cv2.BORDER_CONSTANT,
207
+ )
208
+
209
+ return img_patch
210
+
211
+
212
+ def generate_image_patch_skimage(img: np.array, c_x: float, c_y: float,
213
+ bb_width: float, bb_height: float,
214
+ patch_width: float, patch_height: float,
215
+ do_flip: bool, scale: float, rot: float,
216
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
217
+ """
218
+ Crop image according to the supplied bounding box.
219
+ Args:
220
+ img (np.array): Input image of shape (H, W, 3)
221
+ c_x (float): Bounding box center x coordinate in the original image.
222
+ c_y (float): Bounding box center y coordinate in the original image.
223
+ bb_width (float): Bounding box width.
224
+ bb_height (float): Bounding box height.
225
+ patch_width (float): Output box width.
226
+ patch_height (float): Output box height.
227
+ do_flip (bool): Whether to flip image or not.
228
+ scale (float): Rescaling factor for the bounding box (augmentation).
229
+ rot (float): Random rotation applied to the box.
230
+ Returns:
231
+ img_patch (np.array): Cropped image patch of shape (patch_height, patch_height, 3)
232
+ trans (np.array): Transformation matrix.
233
+ """
234
+
235
+ img_height, img_width, img_channels = img.shape
236
+ if do_flip:
237
+ img = img[:, ::-1, :]
238
+ c_x = img_width - c_x - 1
239
+
240
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
241
+
242
+ # img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), flags=cv2.INTER_LINEAR)
243
+
244
+ # skimage
245
+ center = np.zeros(2)
246
+ center[0] = c_x
247
+ center[1] = c_y
248
+ res = np.zeros(2)
249
+ res[0] = patch_width
250
+ res[1] = patch_height
251
+ # assumes bb_width = bb_height
252
+ # assumes patch_width = patch_height
253
+ assert bb_width == bb_height, f'{bb_width=} != {bb_height=}'
254
+ assert patch_width == patch_height, f'{patch_width=} != {patch_height=}'
255
+ scale1 = scale * bb_width / 200.
256
+
257
+ # Upper left point
258
+ ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1
259
+ # Bottom right point
260
+ br = np.array(transform([res[0] + 1,
261
+ res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1
262
+
263
+ # Padding so that when rotated proper amount of context is included
264
+ try:
265
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1
266
+ except:
267
+ breakpoint()
268
+ if not rot == 0:
269
+ ul -= pad
270
+ br += pad
271
+
272
+ if False:
273
+ # Old way of cropping image
274
+ ul_int = ul.astype(int)
275
+ br_int = br.astype(int)
276
+ new_shape = [br_int[1] - ul_int[1], br_int[0] - ul_int[0]]
277
+ if len(img.shape) > 2:
278
+ new_shape += [img.shape[2]]
279
+ new_img = np.zeros(new_shape)
280
+
281
+ # Range to fill new array
282
+ new_x = max(0, -ul_int[0]), min(br_int[0], len(img[0])) - ul_int[0]
283
+ new_y = max(0, -ul_int[1]), min(br_int[1], len(img)) - ul_int[1]
284
+ # Range to sample from original image
285
+ old_x = max(0, ul_int[0]), min(len(img[0]), br_int[0])
286
+ old_y = max(0, ul_int[1]), min(len(img), br_int[1])
287
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
288
+ old_x[0]:old_x[1]]
289
+
290
+ # New way of cropping image
291
+ new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32)
292
+
293
+ # print(f'{new_img.shape=}')
294
+ # print(f'{new_img1.shape=}')
295
+ # print(f'{np.allclose(new_img, new_img1)=}')
296
+ # print(f'{img.dtype=}')
297
+
298
+ if not rot == 0:
299
+ # Remove padding
300
+
301
+ new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot)
302
+ new_img = new_img[pad:-pad, pad:-pad]
303
+
304
+ if new_img.shape[0] < 1 or new_img.shape[1] < 1:
305
+ print(f'{img.shape=}')
306
+ print(f'{new_img.shape=}')
307
+ print(f'{ul=}')
308
+ print(f'{br=}')
309
+ print(f'{pad=}')
310
+ print(f'{rot=}')
311
+
312
+ breakpoint()
313
+
314
+ # resize image
315
+ new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res)
316
+
317
+ new_img = np.clip(new_img, 0, 255).astype(np.uint8)
318
+
319
+ return new_img, trans
320
+
321
+
322
+ def generate_image_patch_cv2(img: np.array, c_x: float, c_y: float,
323
+ bb_width: float, bb_height: float,
324
+ patch_width: float, patch_height: float,
325
+ do_flip: bool, scale: float, rot: float,
326
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
327
+ """
328
+ Crop the input image and return the crop and the corresponding transformation matrix.
329
+ Args:
330
+ img (np.array): Input image of shape (H, W, 3)
331
+ c_x (float): Bounding box center x coordinate in the original image.
332
+ c_y (float): Bounding box center y coordinate in the original image.
333
+ bb_width (float): Bounding box width.
334
+ bb_height (float): Bounding box height.
335
+ patch_width (float): Output box width.
336
+ patch_height (float): Output box height.
337
+ do_flip (bool): Whether to flip image or not.
338
+ scale (float): Rescaling factor for the bounding box (augmentation).
339
+ rot (float): Random rotation applied to the box.
340
+ Returns:
341
+ img_patch (np.array): Cropped image patch of shape (patch_height, patch_height, 3)
342
+ trans (np.array): Transformation matrix.
343
+ """
344
+
345
+ img_height, img_width, img_channels = img.shape
346
+ if do_flip:
347
+ img = img[:, ::-1, :]
348
+ c_x = img_width - c_x - 1
349
+
350
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
351
+
352
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
353
+ flags=cv2.INTER_LINEAR,
354
+ borderMode=border_mode,
355
+ borderValue=border_value,
356
+ )
357
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
358
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
359
+ img_patch[:, :, 3] = cv2.warpAffine(img[:, :, 3], trans, (int(patch_width), int(patch_height)),
360
+ flags=cv2.INTER_LINEAR,
361
+ borderMode=cv2.BORDER_CONSTANT,
362
+ )
363
+
364
+ is_border = np.all(img_patch[:, :, :-1] == border_value, axis=2) if img_patch.shape[2] == 4 else np.all(img_patch == 0, axis=2)
365
+ img_border_mask = ~is_border
366
+ return img_patch, trans, img_border_mask
367
+
368
+
369
+ def convert_cvimg_to_tensor(cvimg: np.array):
370
+ """
371
+ Convert image from HWC to CHW format.
372
+ Args:
373
+ cvimg (np.array): Image of shape (H, W, 3) as loaded by OpenCV.
374
+ Returns:
375
+ np.array: Output image of shape (3, H, W).
376
+ """
377
+ # from h,w,c(OpenCV) to c,h,w
378
+ img = cvimg.copy()
379
+ img = np.transpose(img, (2, 0, 1))
380
+ # from int to float
381
+ img = img.astype(np.float32)
382
+ return img
383
+
384
+
385
+ def fliplr_params(smal_params: Dict, has_smal_params: Dict) -> Tuple[Dict, Dict]:
386
+ """
387
+ Flip SMAL parameters when flipping the image.
388
+ Args:
389
+ smal_params (Dict): SMAL parameter annotations.
390
+ has_smal_params (Dict): Whether SMAL annotations are valid.
391
+ Returns:
392
+ Dict, Dict: Flipped SMAL parameters and valid flags.
393
+ """
394
+ global_orient = smal_params['global_orient'].copy()
395
+ pose = smal_params['pose'].copy()
396
+ betas = smal_params['betas'].copy()
397
+ translation = smal_params['translation'].copy()
398
+ has_global_orient = has_smal_params['global_orient'].copy()
399
+ has_pose = has_smal_params['pose'].copy()
400
+ has_betas = has_smal_params['betas'].copy()
401
+ has_translation = has_smal_params['translation'].copy()
402
+
403
+ global_orient[1::3] *= -1
404
+ global_orient[2::3] *= -1
405
+ pose[1::3] *= -1
406
+ pose[2::3] *= -1
407
+ translation[1::3] *= -1
408
+ translation[2::3] *= -1
409
+
410
+ smal_params = {'global_orient': global_orient.astype(np.float32),
411
+ 'pose': pose.astype(np.float32),
412
+ 'betas': betas.astype(np.float32),
413
+ 'translation': translation.astype(np.float32)
414
+ }
415
+
416
+ has_smal_params = {'global_orient': has_global_orient,
417
+ 'pose': has_pose,
418
+ 'betas': has_betas,
419
+ 'translation': has_translation
420
+ }
421
+
422
+ return smal_params, has_smal_params
423
+
424
+
425
+ def fliplr_keypoints(joints: np.array, width: float, flip_permutation: List[int]) -> np.array:
426
+ """
427
+ Flip 2D or 3D keypoints.
428
+ Args:
429
+ joints (np.array): Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence.
430
+ flip_permutation (List): Permutation to apply after flipping.
431
+ Returns:
432
+ np.array: Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively.
433
+ """
434
+ joints = joints.copy()
435
+ # Flip horizontal
436
+ joints[:, 0] = width - joints[:, 0] - 1
437
+ joints = joints[flip_permutation, :]
438
+
439
+ return joints
440
+
441
+
442
+ def keypoint_3d_processing(keypoints_3d: np.array, rot: float, filp: bool) -> np.array:
443
+ """
444
+ Process 3D keypoints (rotation/flipping).
445
+ Args:
446
+ keypoints_3d (np.array): Input array of shape (N, 4) containing the 3D keypoints and confidence.
447
+ rot (float): Random rotation applied to the keypoints.
448
+ Returns:
449
+ np.array: Transformed 3D keypoints with shape (N, 4).
450
+ """
451
+ # in-plane rotation
452
+ rot_mat = np.eye(3, dtype=np.float32)
453
+ if not rot == 0:
454
+ rot_rad = -rot * np.pi / 180
455
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
456
+ rot_mat[0, :2] = [cs, -sn]
457
+ rot_mat[1, :2] = [sn, cs]
458
+ keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1])
459
+ # flip the x coordinates
460
+ if filp:
461
+ keypoints_3d = fliplr_keypoints(keypoints_3d, list(range(len(keypoints_3d))))
462
+ keypoints_3d = keypoints_3d.astype('float32')
463
+ return keypoints_3d
464
+
465
+
466
+ def rot_aa(aa: np.array, rot: float) -> np.array:
467
+ """
468
+ Rotate axis angle parameters.
469
+ Args:
470
+ aa (np.array): Axis-angle vector of shape (3,).
471
+ rot (np.array): Rotation angle in degrees.
472
+ Returns:
473
+ np.array: Rotated axis-angle vector.
474
+ """
475
+ # pose parameters
476
+ R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
477
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
478
+ [0, 0, 1]])
479
+ # find the rotation of the hand in camera frame
480
+ per_rdg, _ = cv2.Rodrigues(aa)
481
+ # apply the global rotation to the global orientation
482
+ resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
483
+ aa = (resrot.T)[0]
484
+ return aa.astype(np.float32)
485
+
486
+
487
+ def smal_param_processing(smal_params: Dict, has_smal_params: Dict, rot: float, do_flip: bool) -> Tuple[Dict, Dict]:
488
+ """
489
+ Apply random augmentations to the SMAL parameters.
490
+ Args:
491
+ smal_params (Dict): SMAL parameter annotations.
492
+ has_smal_params (Dict): Whether SMAL annotations are valid.
493
+ rot (float): Random rotation applied to the keypoints.
494
+ do_flip (bool): Whether to flip keypoints or not.
495
+ Returns:
496
+ Dict, Dict: Transformed SMAL parameters and valid flags.
497
+ """
498
+ if do_flip:
499
+ smal_params, has_smal_params = fliplr_params(smal_params, has_smal_params)
500
+ smal_params['global_orient'] = rot_aa(smal_params['global_orient'], rot)
501
+ # The camera location does not change, so the translation does not change either.
502
+ # smal_params['transl'] = np.dot(np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
503
+ # [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
504
+ # [0, 0, 1]], dtype=np.float32), smal_params['transl'])
505
+ return smal_params, has_smal_params
506
+
507
+
508
+ def get_example(img_path: Union[str,np.ndarray], center_x: float, center_y: float,
509
+ width: float, height: float,
510
+ keypoints_2d: np.array, keypoints_3d: np.array,
511
+ smal_params: Dict, has_smal_params: Dict,
512
+ patch_width: int, patch_height: int,
513
+ mean: np.array, std: np.array,
514
+ do_augment: bool, augm_config: CfgNode,
515
+ is_bgr: bool = True,
516
+ use_skimage_antialias: bool = False,
517
+ border_mode: int = cv2.BORDER_CONSTANT,
518
+ return_trans: bool = False,) -> Tuple:
519
+ """
520
+ Get an example from the dataset and (possibly) apply random augmentations.
521
+ Args:
522
+ img_path (str): Image filename
523
+ center_x (float): Bounding box center x coordinate in the original image.
524
+ center_y (float): Bounding box center y coordinate in the original image.
525
+ width (float): Bounding box width.
526
+ height (float): Bounding box height.
527
+ keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
528
+ keypoints_3d (np.array): Array with shape (N,4) containing the 3D keypoints.
529
+ smal_params (Dict): SMAL parameter annotations.
530
+ has_smal_params (Dict): Whether SMAL annotations are valid.
531
+ patch_width (float): Output box width.
532
+ patch_height (float): Output box height.
533
+ mean (np.array): Array of shape (3,) containing the mean for normalizing the input image.
534
+ std (np.array): Array of shape (3,) containing the std for normalizing the input image.
535
+ do_augment (bool): Whether to apply data augmentation or not.
536
+ aug_config (CfgNode): Config containing augmentation parameters.
537
+ Returns:
538
+ return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size
539
+ img_patch (np.array): Cropped image patch of shape (3, patch_height, patch_height)
540
+ keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints.
541
+ keypoints_3d (np.array): Array with shape (N,4) containing the transformed 3D keypoints.
542
+ smal_params (Dict): Transformed SMAL parameters.
543
+ has_smal_params (Dict): Valid flag for transformed SMAL parameters.
544
+ img_size (np.array): Image size of the original image.
545
+ """
546
+ if isinstance(img_path, str):
547
+ # 1. load image
548
+ cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
549
+ if not isinstance(cvimg, np.ndarray):
550
+ raise IOError("Fail to read %s" % img_path)
551
+ elif isinstance(img_path, np.ndarray):
552
+ cvimg = img_path
553
+ else:
554
+ raise TypeError('img_path must be either a string or a numpy array')
555
+ img_height, img_width, img_channels = cvimg.shape
556
+
557
+ img_size = np.array([img_height, img_width], dtype=np.int32)
558
+
559
+ # 2. get augmentation params
560
+ if do_augment:
561
+ # box rescale factor, rotation angle, flip or not flip, crop or not crop, ..., color scale, translation x, ...
562
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
563
+ else:
564
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0,
565
+ 1.0,
566
+ 1.0], 0., 0.
567
+ if width < 1 or height < 1:
568
+ breakpoint()
569
+
570
+ if do_extreme_crop:
571
+ if extreme_crop_lvl == 0:
572
+ center_x1, center_y1, width1, height1 = extreme_cropping(center_x, center_y, width, height, keypoints_2d)
573
+ elif extreme_crop_lvl == 1:
574
+ center_x1, center_y1, width1, height1 = extreme_cropping_aggressive(center_x, center_y, width, height,
575
+ keypoints_2d)
576
+
577
+ THRESH = 4
578
+ if width1 < THRESH or height1 < THRESH:
579
+ pass
580
+ else:
581
+ center_x, center_y, width, height = center_x1, center_y1, width1, height1
582
+
583
+ center_x += width * tx
584
+ center_y += height * ty
585
+
586
+ # Process 3D keypoints
587
+ keypoints_3d = keypoint_3d_processing(keypoints_3d, rot, do_flip)
588
+
589
+ # 3. generate image patch
590
+ if use_skimage_antialias:
591
+ # Blur image to avoid aliasing artifacts
592
+ downsampling_factor = (patch_width / (width * scale))
593
+ if downsampling_factor > 1.1:
594
+ cvimg = gaussian(cvimg, sigma=(downsampling_factor - 1) / 2, channel_axis=2, preserve_range=True,
595
+ truncate=3.0)
596
+ # augmentation image, translation matrix
597
+ img_patch_cv, trans, img_border_mask = generate_image_patch_cv2(cvimg,
598
+ center_x, center_y,
599
+ width, height,
600
+ patch_width, patch_height,
601
+ do_flip, scale, rot,
602
+ border_mode=border_mode)
603
+
604
+ image = img_patch_cv.copy()
605
+ if is_bgr:
606
+ image = image[:, :, ::-1]
607
+ img_patch_cv = image.copy()
608
+ img_patch = convert_cvimg_to_tensor(image) # [h, w, 4] -> [4, h, w]
609
+
610
+ smal_params, has_smal_params = smal_param_processing(smal_params, has_smal_params, rot, do_flip)
611
+
612
+ # apply normalization
613
+ for n_c in range(min(img_channels, 3)):
614
+ img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
615
+ if mean is not None and std is not None:
616
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]
617
+
618
+ if do_flip:
619
+ keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, list(range(len(keypoints_2d))))
620
+
621
+ for n_jt in range(len(keypoints_2d)):
622
+ keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
623
+ keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5
624
+
625
+ if not return_trans:
626
+ return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, img_border_mask
627
+ else:
628
+ return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, trans, img_border_mask
629
+
630
+
631
+ def get_cub17_example(cvimg: np.array,
632
+ keypoints_2d: np.array,
633
+ center_x: float, center_y: float,
634
+ width: float, height: float,
635
+ patch_width: int, patch_height: int,
636
+ mean: np.array, std: np.array,
637
+ do_augment: bool, augm_config: CfgNode,
638
+ return_trans=True) -> Tuple:
639
+ """
640
+ Get an example from the dataset and (possibly) apply random augmentations.
641
+ Args:
642
+ cvimg (np.ndarray): Image
643
+ keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
644
+ center_x (float): Bounding box center x coordinate in the original image.
645
+ center_y (float): Bounding box center y coordinate in the original image.
646
+ width (float): Bounding box width.
647
+ height (float): Bounding box height.
648
+ patch_width (int): Output box width.
649
+ patch_height (int): Output box height.
650
+ mean (np.array): Array of shape (3,) containing the mean for normalizing the input image.
651
+ std (np.array): Array of shape (3,) containing the std for normalizing the input image.
652
+ do_augment (bool): Whether to apply data augmentation or not.
653
+ aug_config (CfgNode): Config containing augmentation parameters.
654
+ Returns:
655
+ return img_patch, keypoints_2d
656
+ img_patch (np.array): Cropped image patch of shape (3, patch_height, patch_height)
657
+ keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints.
658
+ """
659
+ img_height, img_width, img_channels = cvimg.shape
660
+
661
+ img_size = np.array([img_height, img_width], dtype=np.int32)
662
+
663
+ # 2. get augmentation params
664
+ if do_augment:
665
+ # box rescale factor, rotation angle, flip or not flip, crop or not crop, ..., color scale, translation x, ...
666
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
667
+ else:
668
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0,
669
+ 1.0,
670
+ 1.0], 0., 0.
671
+ # bounding box height and width
672
+ center_x += width * tx
673
+ center_y += height * ty
674
+ # augmentation image, translation matrix
675
+ img_patch_cv, trans, img_border_mask = generate_image_patch_cv2(cvimg,
676
+ center_x, center_y,
677
+ width, height,
678
+ patch_width, patch_height,
679
+ do_flip, scale, rot,
680
+ border_mode=cv2.BORDER_CONSTANT)
681
+
682
+ image = img_patch_cv.copy()
683
+ img_patch = convert_cvimg_to_tensor(image) # [h, w, 4] -> [4, h, w]
684
+
685
+ # apply normalization
686
+ for n_c in range(min(img_channels, 3)):
687
+ img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
688
+ if mean is not None and std is not None:
689
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]
690
+
691
+ if do_flip:
692
+ keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, list(range(len(keypoints_2d))))
693
+
694
+ for n_jt in range(len(keypoints_2d)):
695
+ keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
696
+ keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5
697
+
698
+ if return_trans:
699
+ return img_patch, keypoints_2d, img_size, trans, img_border_mask
700
+ else:
701
+ return img_patch, keypoints_2d, img_size, img_border_mask
702
+
703
+
704
+ def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
705
+ """
706
+ Extreme cropping: Crop the box up to the hip locations.
707
+ Args:
708
+ center_x (float): x coordinate of the bounding box center.
709
+ center_y (float): y coordinate of the bounding box center.
710
+ width (float): Bounding box width.
711
+ height (float): Bounding box height.
712
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
713
+ Returns:
714
+ center_x (float): x coordinate of the new bounding box center.
715
+ center_y (float): y coordinate of the new bounding box center.
716
+ width (float): New bounding box width.
717
+ height (float): New bounding box height.
718
+ """
719
+ keypoints_2d = keypoints_2d.copy()
720
+ lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25 + 0, 25 + 1, 25 + 4, 25 + 5]
721
+ keypoints_2d[lower_body_keypoints, :] = 0
722
+ if keypoints_2d[:, -1].sum() > 1:
723
+ center, scale = get_bbox(keypoints_2d)
724
+ center_x = center[0]
725
+ center_y = center[1]
726
+ width = 1.1 * scale[0]
727
+ height = 1.1 * scale[1]
728
+ return center_x, center_y, width, height
729
+
730
+
731
+ def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
732
+ """
733
+ Extreme cropping: Crop the box up to the shoulder locations.
734
+ Args:
735
+ center_x (float): x coordinate of the bounding box center.
736
+ center_y (float): y coordinate of the bounding box center.
737
+ width (float): Bounding box width.
738
+ height (float): Bounding box height.
739
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
740
+ Returns:
741
+ center_x (float): x coordinate of the new bounding box center.
742
+ center_y (float): y coordinate of the new bounding box center.
743
+ width (float): New bounding box width.
744
+ height (float): New bounding box height.
745
+ """
746
+ keypoints_2d = keypoints_2d.copy()
747
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in
748
+ [0, 1, 2, 3, 4, 5, 6, 7,
749
+ 10, 11, 14, 15, 16]]
750
+ keypoints_2d[lower_body_keypoints, :] = 0
751
+ center, scale = get_bbox(keypoints_2d)
752
+ if keypoints_2d[:, -1].sum() > 1:
753
+ center, scale = get_bbox(keypoints_2d)
754
+ center_x = center[0]
755
+ center_y = center[1]
756
+ width = 1.2 * scale[0]
757
+ height = 1.2 * scale[1]
758
+ return center_x, center_y, width, height
759
+
760
+
761
+ def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
762
+ """
763
+ Extreme cropping: Crop the box and keep on only the head.
764
+ Args:
765
+ center_x (float): x coordinate of the bounding box center.
766
+ center_y (float): y coordinate of the bounding box center.
767
+ width (float): Bounding box width.
768
+ height (float): Bounding box height.
769
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
770
+ Returns:
771
+ center_x (float): x coordinate of the new bounding box center.
772
+ center_y (float): y coordinate of the new bounding box center.
773
+ width (float): New bounding box width.
774
+ height (float): New bounding box height.
775
+ """
776
+ keypoints_2d = keypoints_2d.copy()
777
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in
778
+ [0, 1, 2, 3, 4, 5, 6, 7, 8,
779
+ 9, 10, 11, 14, 15, 16]]
780
+ keypoints_2d[lower_body_keypoints, :] = 0
781
+ if keypoints_2d[:, -1].sum() > 1:
782
+ center, scale = get_bbox(keypoints_2d)
783
+ center_x = center[0]
784
+ center_y = center[1]
785
+ width = 1.3 * scale[0]
786
+ height = 1.3 * scale[1]
787
+ return center_x, center_y, width, height
788
+
789
+
790
+ def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
791
+ """
792
+ Extreme cropping: Crop the box and keep on only the torso.
793
+ Args:
794
+ center_x (float): x coordinate of the bounding box center.
795
+ center_y (float): y coordinate of the bounding box center.
796
+ width (float): Bounding box width.
797
+ height (float): Bounding box height.
798
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
799
+ Returns:
800
+ center_x (float): x coordinate of the new bounding box center.
801
+ center_y (float): y coordinate of the new bounding box center.
802
+ width (float): New bounding box width.
803
+ height (float): New bounding box height.
804
+ """
805
+ keypoints_2d = keypoints_2d.copy()
806
+ nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in
807
+ [0, 1, 4, 5, 6,
808
+ 7, 10, 11, 13,
809
+ 17, 18]]
810
+ keypoints_2d[nontorso_body_keypoints, :] = 0
811
+ if keypoints_2d[:, -1].sum() > 1:
812
+ center, scale = get_bbox(keypoints_2d)
813
+ center_x = center[0]
814
+ center_y = center[1]
815
+ width = 1.1 * scale[0]
816
+ height = 1.1 * scale[1]
817
+ return center_x, center_y, width, height
818
+
819
+
820
+ def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
821
+ """
822
+ Extreme cropping: Crop the box and keep on only the right arm.
823
+ Args:
824
+ center_x (float): x coordinate of the bounding box center.
825
+ center_y (float): y coordinate of the bounding box center.
826
+ width (float): Bounding box width.
827
+ height (float): Bounding box height.
828
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
829
+ Returns:
830
+ center_x (float): x coordinate of the new bounding box center.
831
+ center_y (float): y coordinate of the new bounding box center.
832
+ width (float): New bounding box width.
833
+ height (float): New bounding box height.
834
+ """
835
+ keypoints_2d = keypoints_2d.copy()
836
+ nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [
837
+ 25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
838
+ keypoints_2d[nonrightarm_body_keypoints, :] = 0
839
+ if keypoints_2d[:, -1].sum() > 1:
840
+ center, scale = get_bbox(keypoints_2d)
841
+ center_x = center[0]
842
+ center_y = center[1]
843
+ width = 1.1 * scale[0]
844
+ height = 1.1 * scale[1]
845
+ return center_x, center_y, width, height
846
+
847
+
848
+ def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
849
+ """
850
+ Extreme cropping: Crop the box and keep on only the left arm.
851
+ Args:
852
+ center_x (float): x coordinate of the bounding box center.
853
+ center_y (float): y coordinate of the bounding box center.
854
+ width (float): Bounding box width.
855
+ height (float): Bounding box height.
856
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
857
+ Returns:
858
+ center_x (float): x coordinate of the new bounding box center.
859
+ center_y (float): y coordinate of the new bounding box center.
860
+ width (float): New bounding box width.
861
+ height (float): New bounding box height.
862
+ """
863
+ keypoints_2d = keypoints_2d.copy()
864
+ nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [
865
+ 25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]]
866
+ keypoints_2d[nonleftarm_body_keypoints, :] = 0
867
+ if keypoints_2d[:, -1].sum() > 1:
868
+ center, scale = get_bbox(keypoints_2d)
869
+ center_x = center[0]
870
+ center_y = center[1]
871
+ width = 1.1 * scale[0]
872
+ height = 1.1 * scale[1]
873
+ return center_x, center_y, width, height
874
+
875
+
876
+ def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
877
+ """
878
+ Extreme cropping: Crop the box and keep on only the legs.
879
+ Args:
880
+ center_x (float): x coordinate of the bounding box center.
881
+ center_y (float): y coordinate of the bounding box center.
882
+ width (float): Bounding box width.
883
+ height (float): Bounding box height.
884
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
885
+ Returns:
886
+ center_x (float): x coordinate of the new bounding box center.
887
+ center_y (float): y coordinate of the new bounding box center.
888
+ width (float): New bounding box width.
889
+ height (float): New bounding box height.
890
+ """
891
+ keypoints_2d = keypoints_2d.copy()
892
+ nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in
893
+ [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]]
894
+ keypoints_2d[nonlegs_body_keypoints, :] = 0
895
+ if keypoints_2d[:, -1].sum() > 1:
896
+ center, scale = get_bbox(keypoints_2d)
897
+ center_x = center[0]
898
+ center_y = center[1]
899
+ width = 1.1 * scale[0]
900
+ height = 1.1 * scale[1]
901
+ return center_x, center_y, width, height
902
+
903
+
904
+ def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
905
+ """
906
+ Extreme cropping: Crop the box and keep on only the right leg.
907
+ Args:
908
+ center_x (float): x coordinate of the bounding box center.
909
+ center_y (float): y coordinate of the bounding box center.
910
+ width (float): Bounding box width.
911
+ height (float): Bounding box height.
912
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
913
+ Returns:
914
+ center_x (float): x coordinate of the new bounding box center.
915
+ center_y (float): y coordinate of the new bounding box center.
916
+ width (float): New bounding box width.
917
+ height (float): New bounding box height.
918
+ """
919
+ keypoints_2d = keypoints_2d.copy()
920
+ nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + [25 + i for i in
921
+ [3, 4, 5, 6, 7,
922
+ 8, 9, 10, 11,
923
+ 12, 13, 14, 15,
924
+ 16, 17, 18]]
925
+ keypoints_2d[nonrightleg_body_keypoints, :] = 0
926
+ if keypoints_2d[:, -1].sum() > 1:
927
+ center, scale = get_bbox(keypoints_2d)
928
+ center_x = center[0]
929
+ center_y = center[1]
930
+ width = 1.1 * scale[0]
931
+ height = 1.1 * scale[1]
932
+ return center_x, center_y, width, height
933
+
934
+
935
+ def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
936
+ """
937
+ Extreme cropping: Crop the box and keep on only the left leg.
938
+ Args:
939
+ center_x (float): x coordinate of the bounding box center.
940
+ center_y (float): y coordinate of the bounding box center.
941
+ width (float): Bounding box width.
942
+ height (float): Bounding box height.
943
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
944
+ Returns:
945
+ center_x (float): x coordinate of the new bounding box center.
946
+ center_y (float): y coordinate of the new bounding box center.
947
+ width (float): New bounding box width.
948
+ height (float): New bounding box height.
949
+ """
950
+ keypoints_2d = keypoints_2d.copy()
951
+ nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in
952
+ [0, 1, 2, 6, 7, 8,
953
+ 9, 10, 11, 12,
954
+ 13, 14, 15, 16,
955
+ 17, 18]]
956
+ keypoints_2d[nonleftleg_body_keypoints, :] = 0
957
+ if keypoints_2d[:, -1].sum() > 1:
958
+ center, scale = get_bbox(keypoints_2d)
959
+ center_x = center[0]
960
+ center_y = center[1]
961
+ width = 1.1 * scale[0]
962
+ height = 1.1 * scale[1]
963
+ return center_x, center_y, width, height
964
+
965
+
966
+ def full_body(keypoints_2d: np.array) -> bool:
967
+ """
968
+ Check if all main body joints are visible.
969
+ Args:
970
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
971
+ Returns:
972
+ bool: True if all main body joints are visible.
973
+ """
974
+
975
+ body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
976
+ body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]]
977
+ return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(
978
+ body_keypoints)
979
+
980
+
981
+ def upper_body(keypoints_2d: np.array):
982
+ """
983
+ Check if all upper body joints are visible.
984
+ Args:
985
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
986
+ Returns:
987
+ bool: True if all main body joints are visible.
988
+ """
989
+ lower_body_keypoints_openpose = [10, 11, 13, 14]
990
+ lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]]
991
+ upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18]
992
+ upper_body_keypoints = [25 + 8, 25 + 9, 25 + 12, 25 + 13, 25 + 17, 25 + 18]
993
+ return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0) \
994
+ and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2)
995
+
996
+
997
+ def get_bbox(keypoints_2d: np.array, rescale: float = 1.2) -> Tuple:
998
+ """
999
+ Get center and scale for bounding box from openpose detections.
1000
+ Args:
1001
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
1002
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
1003
+ Returns:
1004
+ center (np.array): Array of shape (2,) containing the new bounding box center.
1005
+ scale (float): New bounding box scale.
1006
+ """
1007
+ valid = keypoints_2d[:, -1] > 0
1008
+ valid_keypoints = keypoints_2d[valid][:, :-1]
1009
+ center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0))
1010
+ bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0))
1011
+ # adjust bounding box tightness
1012
+ scale = bbox_size
1013
+ scale *= rescale
1014
+ return center, scale
1015
+
1016
+
1017
+ def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
1018
+ """
1019
+ Perform extreme cropping
1020
+ Args:
1021
+ center_x (float): x coordinate of bounding box center.
1022
+ center_y (float): y coordinate of bounding box center.
1023
+ width (float): bounding box width.
1024
+ height (float): bounding box height.
1025
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
1026
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
1027
+ Returns:
1028
+ center_x (float): x coordinate of bounding box center.
1029
+ center_y (float): y coordinate of bounding box center.
1030
+ width (float): bounding box width.
1031
+ height (float): bounding box height.
1032
+ """
1033
+ p = torch.rand(1).item()
1034
+ if full_body(keypoints_2d):
1035
+ if p < 0.7:
1036
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
1037
+ elif p < 0.9:
1038
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
1039
+ else:
1040
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
1041
+ elif upper_body(keypoints_2d):
1042
+ if p < 0.9:
1043
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
1044
+ else:
1045
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
1046
+
1047
+ return center_x, center_y, max(width, height), max(width, height)
1048
+
1049
+
1050
+ def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float,
1051
+ keypoints_2d: np.array) -> Tuple:
1052
+ """
1053
+ Perform aggressive extreme cropping
1054
+ Args:
1055
+ center_x (float): x coordinate of bounding box center.
1056
+ center_y (float): y coordinate of bounding box center.
1057
+ width (float): bounding box width.
1058
+ height (float): bounding box height.
1059
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
1060
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
1061
+ Returns:
1062
+ center_x (float): x coordinate of bounding box center.
1063
+ center_y (float): y coordinate of bounding box center.
1064
+ width (float): bounding box width.
1065
+ height (float): bounding box height.
1066
+ """
1067
+ p = torch.rand(1).item()
1068
+ if full_body(keypoints_2d):
1069
+ if p < 0.2:
1070
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
1071
+ elif p < 0.3:
1072
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
1073
+ elif p < 0.4:
1074
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
1075
+ elif p < 0.5:
1076
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
1077
+ elif p < 0.6:
1078
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
1079
+ elif p < 0.7:
1080
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
1081
+ elif p < 0.8:
1082
+ center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d)
1083
+ elif p < 0.9:
1084
+ center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d)
1085
+ else:
1086
+ center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d)
1087
+ elif upper_body(keypoints_2d):
1088
+ if p < 0.2:
1089
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
1090
+ elif p < 0.4:
1091
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
1092
+ elif p < 0.6:
1093
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
1094
+ elif p < 0.8:
1095
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
1096
+ else:
1097
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
1098
+ return center_x, center_y, max(width, height), max(width, height)
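For context, a minimal sketch of how the cropping helpers above would be applied as a training-time augmentation; the keypoint layout, joint count, box values and probability are illustrative assumptions, not values taken from this repository:

import numpy as np
import torch

from amr.datasets.utils import extreme_cropping_aggressive

# Hypothetical annotation: (N, 3) keypoints as (x, y, confidence) plus a loose box around the animal.
keypoints_2d = np.zeros((54, 3), dtype=np.float32)
keypoints_2d[:, :2] = np.random.uniform(60.0, 200.0, size=(54, 2))
keypoints_2d[:, 2] = 1.0                                   # mark all keypoints as visible
center_x, center_y, width, height = 130.0, 130.0, 180.0, 180.0

if torch.rand(1).item() < 0.1:                             # augment only a small fraction of samples
    center_x, center_y, width, height = extreme_cropping_aggressive(
        center_x, center_y, width, height, keypoints_2d)
print(center_x, center_y, width, height)                   # square crop: width == height
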
amr/datasets/vitdet_dataset.py ADDED
@@ -0,0 +1,102 @@
1
+ from typing import Dict, List
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from skimage.filters import gaussian
6
+ from yacs.config import CfgNode
7
+ import torch
8
+
9
+ from .utils import (convert_cvimg_to_tensor,
10
+ expand_to_aspect_ratio,
11
+ generate_image_patch_cv2)
12
+
13
+ DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
14
+ DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
15
+
16
+
17
+ class ViTDetDataset(torch.utils.data.Dataset):
18
+
19
+ def __init__(self,
20
+ cfg: CfgNode,
21
+ img_cv2: np.array,
22
+ boxes: np.array,
23
+ category: List[int],
24
+ rescale_factor=1,
25
+ train: bool = False,
26
+ **kwargs):
27
+ super().__init__()
28
+ self.cfg = cfg
29
+ self.img_cv2 = img_cv2
30
+ self.boxes = boxes
31
+ self.category = category
32
+ self.focal_length = []
33
+ for c in category:
34
+ if c == 6:
35
+ self.focal_length.append([cfg.AVES.FOCAL_LENGTH, cfg.AVES.FOCAL_LENGTH])
36
+ else:
37
+ self.focal_length.append([cfg.SMAL.FOCAL_LENGTH, cfg.SMAL.FOCAL_LENGTH])
38
+
39
+ assert train is False, "ViTDetDataset is only for inference"
40
+ self.train = train
41
+ self.img_size = cfg.MODEL.IMAGE_SIZE
42
+ self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
43
+ self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)
44
+
45
+ # Preprocess annotations
46
+ boxes = boxes.astype(np.float32)
47
+ self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0
48
+ self.scale = rescale_factor * (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0
49
+ self.animalid = np.arange(len(boxes), dtype=np.int32)
50
+
51
+ def __len__(self) -> int:
52
+ return len(self.animalid)
53
+
54
+ def __getitem__(self, idx: int) -> Dict[str, np.array]:
55
+
56
+ center = self.center[idx].copy()
57
+ center_x = center[0]
58
+ center_y = center[1]
59
+
60
+ scale = self.scale[idx]
61
+ BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
62
+ bbox_size = expand_to_aspect_ratio(scale * 200, target_aspect_ratio=BBOX_SHAPE).max()
63
+
64
+ patch_width = patch_height = self.img_size
65
+
66
+ flip = False
67
+
68
+ # 3. generate image patch
69
+ # if use_skimage_antialias:
70
+ cvimg = self.img_cv2.copy()
71
+ if True:
72
+ # Blur image to avoid aliasing artifacts
73
+ downsampling_factor = ((bbox_size * 1.0) / patch_width)
74
+ print(f'{downsampling_factor=}')
75
+ downsampling_factor = downsampling_factor / 2.0
76
+ if downsampling_factor > 1.1:
77
+ cvimg = gaussian(cvimg, sigma=(downsampling_factor - 1) / 2, channel_axis=2, preserve_range=True)
78
+
79
+ img_patch_cv, trans, _ = generate_image_patch_cv2(cvimg,
80
+ center_x, center_y,
81
+ bbox_size, bbox_size,
82
+ patch_width, patch_height,
83
+ flip, 1.0, 0.0,
84
+ border_mode=cv2.BORDER_CONSTANT)
85
+ img_patch_cv = img_patch_cv[:, :, ::-1]
86
+ img_patch = convert_cvimg_to_tensor(img_patch_cv)
87
+
88
+ # apply normalization
89
+ for n_c in range(min(self.img_cv2.shape[2], 3)):
90
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - self.mean[n_c]) / self.std[n_c]
91
+
92
+ item = {
93
+ 'img': img_patch,
94
+ 'animalid': int(self.animalid[idx]),
95
+ 'box_center': self.center[idx].copy(),
96
+ 'box_size': bbox_size,
97
+ 'img_size': 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]]),
98
+ 'focal_length': np.array(self.focal_length, dtype=np.float32)[idx],
99
+ 'supercategory': self.category[idx],
100
+ 'has_mask': np.array(0., dtype=np.float32)
101
+ }
102
+ return item
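Together with load_amr (defined in amr/models/__init__.py, next in this diff), this dataset wraps raw detector output for batched inference. A rough usage sketch; the checkpoint path, image file and box coordinates are placeholders:

import cv2
import numpy as np
import torch

from amr.models import load_amr
from amr.datasets.vitdet_dataset import ViTDetDataset

model, model_cfg = load_amr('checkpoints/checkpoint.ckpt')        # hypothetical checkpoint path
img_cv2 = cv2.imread('example.jpg')                               # BGR image; boxes can come from any detector
boxes = np.array([[50., 40., 400., 380.]], dtype=np.float32)      # (x1, y1, x2, y2) per detection
category = [6]                                                    # 6 selects the AVES focal length, others use SMAL

dataset = ViTDetDataset(model_cfg, img_cv2, boxes, category, rescale_factor=1.2, train=False)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)
batch = next(iter(loader))     # keys: 'img', 'box_center', 'box_size', 'img_size', 'focal_length', ...
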
amr/models/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ from .smal_warapper import SMAL
2
+ from .aves_warapper import AVES
3
+ from .animerpp import AniMerPlusPlus
4
+
5
+
6
+ def load_amr(checkpoint_path):
7
+ from pathlib import Path
8
+ from ..configs import get_config
9
+ model_cfg = str(Path(checkpoint_path).parent / 'config.yaml')
10
+ model_cfg = get_config(model_cfg, update_cachedir=True)
11
+
12
+ # Override some config values, to crop bbox correctly
13
+ if (model_cfg.MODEL.BACKBONE.TYPE in ['vith', 'vithmoe']) and ('BBOX_SHAPE' not in model_cfg.MODEL):
14
+ model_cfg.defrost()
15
+ assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
16
+ model_cfg.MODEL.BBOX_SHAPE = [192, 256]
17
+ model_cfg.freeze()
18
+
19
+ # Update config to be compatible with demo
20
+ if ('PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE):
21
+ model_cfg.defrost()
22
+ model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS')
23
+ model_cfg.freeze()
24
+
25
+ model = AniMerPlusPlus.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg, map_location="cpu")
26
+ return model, model_cfg
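load_amr expects the config.yaml written at training time to sit next to the checkpoint; a short sketch with a placeholder path:

from amr.models import load_amr

# The checkpoint directory must also contain config.yaml (read via get_config above).
model, model_cfg = load_amr('logs/train/example_run/checkpoints/last.ckpt')   # hypothetical path
model = model.eval()   # AniMerPlusPlus LightningModule, ready for inference
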
amr/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (999 Bytes). View file
 
amr/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (2.87 kB). View file
 
amr/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (2.01 kB). View file
 
amr/models/__pycache__/amr.cpython-310.pyc ADDED
Binary file (14.3 kB). View file
 
amr/models/__pycache__/amr.cpython-312.pyc ADDED
Binary file (26.7 kB). View file
 
amr/models/__pycache__/amr.cpython-39.pyc ADDED
Binary file (13.8 kB). View file
 
amr/models/__pycache__/animerpp.cpython-310.pyc ADDED
Binary file (16.8 kB). View file
 
amr/models/__pycache__/animerpp.cpython-312.pyc ADDED
Binary file (39.2 kB). View file
 
amr/models/__pycache__/aves_hmr.cpython-310.pyc ADDED
Binary file (15.5 kB). View file
 
amr/models/__pycache__/aves_hmr.cpython-312.pyc ADDED
Binary file (31 kB). View file
 
amr/models/__pycache__/aves_warapper.cpython-310.pyc ADDED
Binary file (4.76 kB). View file
 
amr/models/__pycache__/aves_warapper.cpython-312.pyc ADDED
Binary file (8.42 kB). View file
 
amr/models/__pycache__/discriminator.cpython-310.pyc ADDED
Binary file (2.95 kB). View file
 
amr/models/__pycache__/discriminator.cpython-312.pyc ADDED
Binary file (7.94 kB). View file
 
amr/models/__pycache__/discriminator.cpython-39.pyc ADDED
Binary file (2.6 kB). View file
 
amr/models/__pycache__/dyamr.cpython-310.pyc ADDED
Binary file (18.2 kB). View file
 
amr/models/__pycache__/dyamr.cpython-312.pyc ADDED
Binary file (34.3 kB). View file
 
amr/models/__pycache__/losses.cpython-310.pyc ADDED
Binary file (12.9 kB). View file
 
amr/models/__pycache__/losses.cpython-312.pyc ADDED
Binary file (24.4 kB). View file
 
amr/models/__pycache__/losses.cpython-39.pyc ADDED
Binary file (13 kB). View file
 
amr/models/__pycache__/predictor.cpython-310.pyc ADDED
Binary file (7.3 kB). View file
 
amr/models/__pycache__/smal_warapper.cpython-310.pyc ADDED
Binary file (5.43 kB). View file
 
amr/models/__pycache__/smal_warapper.cpython-312.pyc ADDED
Binary file (8.66 kB). View file
 
amr/models/__pycache__/smal_warapper.cpython-39.pyc ADDED
Binary file (5.53 kB). View file
 
amr/models/__pycache__/smooth_amr.cpython-310.pyc ADDED
Binary file (7.19 kB). View file
 
amr/models/__pycache__/smooth_amr.cpython-312.pyc ADDED
Binary file (14 kB). View file
 
amr/models/__pycache__/smooth_netv2.cpython-310.pyc ADDED
Binary file (9.97 kB). View file
 
amr/models/__pycache__/stamr.cpython-310.pyc ADDED
Binary file (6.27 kB). View file
 
amr/models/__pycache__/stamr.cpython-312.pyc ADDED
Binary file (11.6 kB). View file
 
amr/models/animerpp.py ADDED
@@ -0,0 +1,508 @@
1
+ import torch
2
+ import pickle
3
+ import pytorch_lightning as pl
4
+ from torchvision.utils import make_grid
5
+ from typing import Dict
6
+ from pytorch3d.transforms import matrix_to_axis_angle
7
+ from yacs.config import CfgNode
8
+ from ..utils import MeshRenderer
9
+ from ..utils.geometry import aa_to_rotmat, perspective_projection
10
+ from ..utils.pylogger import get_pylogger
11
+ from ..utils.mesh_renderer import SilhouetteRenderer
12
+ from .backbones import create_backbone
13
+ from .heads.classifier_head import ClassTokenHead
14
+ from .heads import build_aves_head, build_smal_head
15
+ from .losses import (Keypoint3DLoss, Keypoint2DLoss, ParameterLoss, SupConLoss,
16
+ PoseBonePriorLoss, SilhouetteLoss, ShapePriorLoss, PosePriorLoss)
17
+ from .aves_warapper import AVES
18
+ from .smal_warapper import SMAL
19
+
20
+
21
+ log = get_pylogger(__name__)
22
+
23
+
24
+ class AniMerPlusPlus(pl.LightningModule):
25
+ def __init__(self, cfg: CfgNode, init_renderer: bool = True):
26
+ """
27
+ Setup AniMerPlusPlus model
28
+ Args:
29
+ cfg (CfgNode): Config file as a yacs CfgNode
30
+ """
31
+ super().__init__()
32
+
33
+ # Save hyperparameters
34
+ self.save_hyperparameters(logger=False, ignore=['init_renderer'])
35
+
36
+ self.cfg = cfg
37
+ # Create backbone feature extractor
38
+ self.backbone = create_backbone(cfg)
39
+
40
+ # Create AVES head
41
+ self.aves_head = build_aves_head(cfg)
42
+
43
+ # Create SMAL head
44
+ self.smal_head = build_smal_head(cfg)
45
+
46
+ self.class_token_head = ClassTokenHead(**cfg.MODEL.get("CLASS_TOKEN_HEAD", dict()))
47
+
48
+ # Define loss functions
49
+ # common loss
50
+ self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
51
+ self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
52
+ self.supcon_loss = SupConLoss()
53
+ self.parameter_loss = ParameterLoss()
54
+ # aves loss
55
+ self.posebone_prior_loss = PoseBonePriorLoss(path_prior=cfg.AVES.POSE_PRIOR_PATH)
56
+ self.mask_loss = SilhouetteLoss()
57
+ # smal loss
58
+ self.shape_prior_loss = ShapePriorLoss(path_prior=cfg.SMAL.SHAPE_PRIOR_PATH)
59
+ self.pose_prior_loss = PosePriorLoss(path_prior=cfg.SMAL.POSE_PRIOR_PATH)
60
+ # Instantiate AVES model
61
+ aves_model_path = cfg.AVES.MODEL_PATH
62
+ aves_cfg = torch.load(aves_model_path, weights_only=True)
63
+ self.aves = AVES(**aves_cfg)
64
+
65
+ # Instantiate SMAL model
66
+ smal_model_path = cfg.SMAL.MODEL_PATH
67
+ with open(smal_model_path, 'rb') as f:
68
+ smal_cfg = pickle.load(f, encoding="latin1")
69
+ self.smal = SMAL(**smal_cfg)
70
+
71
+ # Buffer that shows whether we need to initialize ActNorm layers
72
+ self.register_buffer('initialized', torch.tensor(False))
73
+ # Setup renderer for visualization
74
+ if init_renderer:
75
+ self.aves_mesh_renderer = MeshRenderer(self.cfg, faces=aves_cfg['F'].numpy())
76
+ self.smal_mesh_renderer = MeshRenderer(self.cfg, faces=self.smal.faces.numpy())
77
+ else:
78
+ self.renderer = None
79
+ self.mesh_renderer = None
80
+
81
+ # Only used for AVES training
82
+ self.aves_silouette_render = SilhouetteRenderer(size=self.cfg.MODEL.IMAGE_SIZE,
83
+ focal=self.cfg.AVES.get("FOCAL_LENGTH", 2167),
84
+ device='cuda')
85
+
86
+ self.automatic_optimization = False
87
+
88
+ def get_parameters(self):
89
+ all_params = list(self.aves_head.parameters())
90
+ all_params += list(self.backbone.parameters())
91
+ all_params += list(self.smal_head.parameters())
92
+ all_params += list(self.class_token_head.parameters())
93
+ return all_params
94
+
95
+ def configure_optimizers(self):
96
+ """
97
+ Setup the model optimizer
98
+ Returns:
99
+ torch.optim.Optimizer: Model optimizer
100
+ """
101
+ param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]
102
+
103
+ if "vit" in self.cfg.MODEL.BACKBONE.TYPE:
104
+ optimizer = torch.optim.AdamW(params=param_groups,
105
+ weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
106
+ else:
107
+ optimizer = torch.optim.Adam(params=param_groups,
108
+ weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
109
+ return optimizer
110
+
111
+ def forward_backbone(self, batch: Dict):
112
+ x = batch['img']
113
+ dataset_source = batch["supercategory"] < 5  # True for SMAL (quadruped) samples
114
+ # Compute conditioning features using the backbone
115
+ if self.cfg.MODEL.BACKBONE.TYPE in ["vith"]:
116
+ conditioning_feats, cls = self.backbone(x[:, :, :, 32:-32]) # [256, 192]
117
+ elif self.cfg.MODEL.BACKBONE.TYPE in ["vithmoe"]:
118
+ conditioning_feats, cls = self.backbone(x[:, :, :, 32:-32], dataset_source=dataset_source.type(torch.long))
119
+ else:
120
+ conditioning_feats = self.backbone(x)
121
+ cls = None
122
+ return conditioning_feats, cls
123
+
124
+ def forward_one_parametric_model(self,
125
+ focal_length: torch.tensor,
126
+ features: torch.tensor,
127
+ head: torch.nn.Module,
128
+ parametric_model: torch.nn.Module,):
129
+ """
130
+ Run a forward step of one parametric model.
131
+ Args:
132
+ batch (Dict): Dictionary containing batch data
133
+ Returns:
134
+ Dict: Dictionary containing the regression output
135
+ """
136
+ batch_size = features.shape[0]
137
+ pred_params, pred_cam, _ = head(features)
138
+ # Store useful regression outputs to the output dict
139
+ output = {}
140
+ output['pred_cam'] = pred_cam
141
+ output['pred_params'] = {k: v.clone() for k, v in pred_params.items()}
142
+
143
+ # Compute camera translation
144
+ pred_cam_t = torch.stack([pred_cam[:, 1],
145
+ pred_cam[:, 2],
146
+ 2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9)], dim=-1)
147
+ output['pred_cam_t'] = pred_cam_t
148
+ output['focal_length'] = focal_length
149
+
150
+ # Compute model vertices, joints and the projected joints
151
+ pred_params['global_orient'] = pred_params['global_orient'].reshape(batch_size, -1, 3, 3)
152
+ pred_params['pose'] = pred_params['pose'].reshape(batch_size, -1, 3, 3)
153
+ pred_params['betas'] = pred_params['betas'].reshape(batch_size, -1)
154
+ pred_params['bone'] = pred_params['bone'].reshape(batch_size, -1) if 'bone' in pred_params else None
155
+ parametric_model_output = parametric_model(**pred_params, pose2rot=False)
156
+
157
+ pred_keypoints_3d = parametric_model_output.joints
158
+ pred_vertices = parametric_model_output.vertices
159
+ output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
160
+ output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
161
+ pred_cam_t = pred_cam_t.reshape(-1, 3)
162
+ focal_length = focal_length.reshape(-1, 2)
163
+ pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
164
+ translation=pred_cam_t,
165
+ focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)
166
+ output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
167
+ return output
168
+
169
+ def forward_step(self, batch: Dict, train: bool = False) -> Dict:
170
+ """
171
+ Run a forward step of the network
172
+ Args:
173
+ batch (Dict): Dictionary containing batch data
174
+ train (bool): Flag indicating whether it is training or validation mode
175
+ Returns:
176
+ Dict: Dictionary containing the regression output
177
+ """
178
+ # Use RGB image as input
179
+ x = batch['img']
180
+ batch_size = x.shape[0]
181
+ device = x.device
182
+ dataset_source = (batch["supercategory"] < 5)  # True for SMAL (quadruped) samples
183
+
184
+ features, cls = self.forward_backbone(batch)
185
+
186
+ output = dict()
187
+ output['cls_feats'] = self.class_token_head(cls) if self.cfg.MODEL.BACKBONE.get("USE_CLS", False) else None
188
+
189
+ num_aves = (batch_size - dataset_source.sum()).item()
190
+ if num_aves:
191
+ output['aves_output'] = self.forward_one_parametric_model(batch['focal_length'][~dataset_source],
192
+ features[~dataset_source],
193
+ self.aves_head,
194
+ self.aves)
195
+ # Only specific to AVES training
196
+ output['aves_output']['pred_mask'] = self.aves_silouette_render(output['aves_output']['pred_vertices']+output['aves_output']['pred_cam_t'].unsqueeze(1),
197
+ faces=self.aves.face.unsqueeze(0).repeat(batch_size-dataset_source.sum().item(), 1, 1).to(device))
198
+
199
+ num_smal = dataset_source.sum().item()
200
+ if num_smal:
201
+ output['smal_output'] = self.forward_one_parametric_model(batch['focal_length'][dataset_source],
202
+ features[dataset_source],
203
+ self.smal_head,
204
+ self.smal)
205
+ return output
206
+
207
+ def compute_aves_loss(self, batch: Dict, output: Dict) -> torch.Tensor:
208
+ """
209
+ Compute AVES losses given the input batch and the regression output
210
+ Args:
211
+ batch (Dict): Dictionary containing batch data
212
+ output (Dict): Dictionary containing the regression output
213
+ train (bool): Flag indicating whether it is training or validation mode
214
+ Returns:
215
+ torch.Tensor : Total loss for current batch
216
+ """
217
+ dataset_source = (batch["supercategory"] > 5)
218
+
219
+ pred_params = output['pred_params']
220
+ pred_mask = output['pred_mask']
221
+ pred_keypoints_2d = output['pred_keypoints_2d']
222
+ pred_keypoints_3d = output['pred_keypoints_3d']
223
+
224
+ batch_size = pred_params['pose'].shape[0]
225
+
226
+ # Get annotations
227
+ gt_keypoints_2d = batch['keypoints_2d'][dataset_source][:, :18]
228
+ gt_keypoints_3d = batch['keypoints_3d'][dataset_source][:, :18]
229
+ gt_mask = batch['mask'][dataset_source]
230
+ gt_params = {k: v[dataset_source] for k,v in batch['smal_params'].items()}
231
+ has_params = {k: v[dataset_source] for k,v in batch['has_smal_params'].items()}
232
+ is_axis_angle = {k: v[dataset_source] for k,v in batch['smal_params_is_axis_angle'].items()}
233
+
234
+ # Compute 3D keypoint loss
235
+ loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
236
+ loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
237
+ loss_mask = self.mask_loss(pred_mask, gt_mask)
238
+
239
+ # Compute loss on AVES parameters
240
+ loss_params = {}
241
+ for k, pred in pred_params.items():
242
+ gt = gt_params[k].view(batch_size, -1)
243
+ if is_axis_angle[k].all():
244
+ gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
245
+ has_gt = has_params[k]
246
+ if k == "betas":
247
+ loss_params[k] = self.parameter_loss(pred.reshape(batch_size, -1),
248
+ gt[:, :15].reshape(batch_size, -1),
249
+ has_gt)
250
+ # v1
251
+ loss_params[k+"_re"] = torch.sum(pred[has_gt.bool()] ** 2) + torch.sum(pred[has_gt.bool()] ** 2) * 0.5
252
+ # v2
253
+ # loss_params[k+"_re"] = torch.sum(pred ** 2)
254
+ elif k == "bone":
255
+ loss_params[k] = self.parameter_loss(pred.reshape(batch_size, -1),
256
+ gt.reshape(batch_size, -1),
257
+ has_gt)
258
+ # v1
259
+ loss_params[k+"_re"] = self.posebone_prior_loss.l2_loss(pred, self.posebone_prior_loss.bone_mean, 1 - has_gt) + \
260
+ self.posebone_prior_loss.l2_loss(pred, self.posebone_prior_loss.bone_mean, has_gt) * 0.02
261
+ # v2
262
+ # loss_params[k+"_re"] = self.posebone_prior_loss.l2_loss(pred, self.posebone_prior_loss.bone_mean, torch.zeros_like(has_gt))
263
+ elif k == "pose":
264
+ loss_params[k] = self.parameter_loss(pred.reshape(batch_size, -1),
265
+ gt[:, :24].reshape(batch_size, -1),
266
+ has_gt)
267
+ pose_axis_angle = matrix_to_axis_angle(pred)
268
+ # v1
269
+ loss_params[k+"_re"] = self.posebone_prior_loss.l2_loss(pose_axis_angle.reshape(batch_size, -1), self.posebone_prior_loss.pose_mean, 1 - has_gt) + \
270
+ self.posebone_prior_loss.l2_loss(pose_axis_angle.reshape(batch_size, -1), self.posebone_prior_loss.pose_mean, has_gt) * 0.02
271
+ # v2
272
+ # loss_params[k+"_re"] = self.posebone_prior_loss.l2_loss(pose_axis_angle.reshape(batch_size, -1), self.posebone_prior_loss.pose_mean, torch.zeros_like(has_gt))
273
+ else:
274
+ loss_params[k] = self.parameter_loss(pred.reshape(batch_size, -1),
275
+ gt.reshape(batch_size, -1),
276
+ has_gt)
277
+
278
+ loss_config = self.cfg.LOSS_WEIGHTS.AVES
279
+ loss = loss_config['KEYPOINTS_3D'] * loss_keypoints_3d + \
280
+ loss_config['KEYPOINTS_2D'] * loss_keypoints_2d + \
281
+ sum([loss_params[k] * loss_config[k.upper()] for k in loss_params]) + \
282
+ loss_config['MASK'] * loss_mask
283
+
284
+ losses = dict(loss_aves=loss.detach(),
285
+ loss_aves_keypoints_2d=loss_keypoints_2d.detach(),
286
+ loss_aves_keypoints_3d=loss_keypoints_3d.detach(),
287
+ loss_aves_mask=loss_mask.detach(),
288
+ )
289
+ for k, v in loss_params.items():
290
+ losses['loss_aves_' + k] = v.detach()
291
+
292
+ return loss, losses
293
+
294
+ def compute_smal_loss(self, batch: Dict, output: Dict) -> torch.Tensor:
295
+ """
296
+ Compute SMAL losses given the input batch and the regression output
297
+ Args:
298
+ batch (Dict): Dictionary containing batch data
299
+ output (Dict): Dictionary containing the regression output
300
+ Returns:
301
+ torch.Tensor : Total loss for current batch
302
+ """
303
+ dataset_source = (batch["supercategory"] < 5)
304
+
305
+ pred_params = output['pred_params']
306
+ pred_keypoints_2d = output['pred_keypoints_2d']
307
+ pred_keypoints_3d = output['pred_keypoints_3d']
308
+
309
+ batch_size = pred_params['pose'].shape[0]
310
+
311
+ # Get annotations
312
+ gt_keypoints_2d = batch['keypoints_2d'][dataset_source]
313
+ gt_keypoints_3d = batch['keypoints_3d'][dataset_source]
314
+ gt_params = {k: v[dataset_source] for k,v in batch['smal_params'].items()}
315
+ has_params = {k: v[dataset_source] for k,v in batch['has_smal_params'].items()}
316
+ is_axis_angle = {k: v[dataset_source] for k,v in batch['smal_params_is_axis_angle'].items()}
317
+
318
+ # Compute 3D keypoint loss
319
+ loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
320
+ loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
321
+
322
+ # Compute loss on SMAL parameters
323
+ loss_smal_params = {}
324
+ for k, pred in pred_params.items():
325
+ gt = gt_params[k].view(batch_size, -1)
326
+ if is_axis_angle[k].all():
327
+ gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
328
+ has_gt = has_params[k]
329
+ if k == "betas":
330
+ loss_smal_params[k] = self.parameter_loss(pred.reshape(batch_size, -1),
331
+ gt.reshape(batch_size, -1),
332
+ has_gt) + \
333
+ self.shape_prior_loss(pred, batch["category"][dataset_source], has_gt)
334
+ elif k == "bone":
335
+ continue
336
+ else:
337
+ loss_smal_params[k] = self.parameter_loss(pred.reshape(batch_size, -1),
338
+ gt.reshape(batch_size, -1),
339
+ has_gt) + \
340
+ self.pose_prior_loss(torch.cat((pred_params["global_orient"],
341
+ pred_params["pose"]),
342
+ dim=1), has_gt) / 2.
343
+
344
+ loss_config = self.cfg.LOSS_WEIGHTS.SMAL
345
+ loss = loss_config['KEYPOINTS_3D'] * loss_keypoints_3d + \
346
+ loss_config['KEYPOINTS_2D'] * loss_keypoints_2d + \
347
+ sum([loss_smal_params[k] * loss_config[k.upper()] for k in loss_smal_params])
348
+
349
+ losses = dict(loss_smal=loss.detach(),
350
+ loss_smal_keypoints_2d=loss_keypoints_2d.detach(),
351
+ loss_smal_keypoints_3d=loss_keypoints_3d.detach(),
352
+ )
353
+ for k, v in loss_smal_params.items():
354
+ losses['loss_smal_' + k] = v.detach()
355
+
356
+ return loss, losses
357
+
358
+ def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
359
+ """
360
+ Compute losses given the input batch and the regression output
361
+ Args:
362
+ batch (Dict): Dictionary containing batch data
363
+ output (Dict): Dictionary containing the regression output
364
+ train (bool): Flag indicating whether it is training or validation mode
365
+ Returns:
366
+ torch.Tensor : Total loss for current batch
367
+ """
368
+ x = batch['img']
369
+ device, dtype = x.device, x.dtype
370
+ if 'aves_output' in output:
371
+ loss_aves, losses_aves = self.compute_aves_loss(batch, output['aves_output'])
372
+ else:
373
+ loss_aves, losses_aves = torch.tensor(0.0, device=device, dtype=dtype), {}
374
+ if 'smal_output' in output:
375
+ loss_smal, losses_smal = self.compute_smal_loss(batch, output['smal_output'])
376
+ else:
377
+ loss_smal, losses_smal = torch.tensor(0.0, device=device, dtype=dtype), {}
378
+ loss_supcon = self.supcon_loss(output['cls_feats'], labels=batch['category']) if self.cfg.MODEL.BACKBONE.get("USE_CLS", False) \
379
+ else torch.tensor(0.0, device=device, dtype=dtype)
380
+ loss = loss_aves + loss_smal + loss_supcon * self.cfg.LOSS_WEIGHTS['SUPCON']
381
+
382
+ # Saving loss
383
+ losses = {}
384
+ losses['loss'] = loss.detach()
385
+ losses['loss_supcon'] = loss_supcon.detach()
386
+ for k, v in losses_aves.items():
387
+ losses[k] = v.detach()
388
+ for k, v in losses_smal.items():
389
+ losses[k] = v.detach()
390
+ output['losses'] = losses
391
+ return loss
392
+
393
+ # Tensorboard logging should run from first rank only
394
+ @pl.utilities.rank_zero.rank_zero_only
395
+ def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True,
396
+ write_to_summary_writer: bool = True) -> None:
397
+ """
398
+ Log results to Tensorboard
399
+ Args:
400
+ batch (Dict): Dictionary containing batch data
401
+ output (Dict): Dictionary containing the regression output
402
+ step_count (int): Global training step count
403
+ train (bool): Flag indicating whether it is training or validation mode
404
+ """
405
+
406
+ mode = 'train' if train else 'val'
407
+ batch_size = batch['keypoints_2d'].shape[0]
408
+ images = batch['img']
409
+ masks = batch['mask']
410
+ # mul std then add mean
411
+ images = (images) * (torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1))
412
+ images = (images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1))
413
+ masks = masks.unsqueeze(1).repeat(1, 3, 1, 1)
414
+
415
+ gt_keypoints_2d = batch['keypoints_2d']
416
+ losses = output['losses']
417
+ if write_to_summary_writer:
418
+ summary_writer = self.logger.experiment
419
+ for loss_name, val in losses.items():
420
+ summary_writer.add_scalar(mode + '/' + loss_name, val.detach().item(), step_count)
421
+ if train is False:
422
+ for metric_name, val in output['metric'].items():
423
+ summary_writer.add_scalar(mode + '/' + metric_name, val, step_count)
424
+
425
+ rend_imgs = []
426
+ num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)
427
+ dataset_source = (batch["supercategory"] < 5)[:num_images]  # True for SMAL (quadruped) samples
428
+
429
+ num_aves = (num_images - dataset_source[:num_images].sum()).item()
430
+ if num_aves:
431
+ rend_imgs_aves = self.aves_mesh_renderer.visualize_tensorboard( output['aves_output']['pred_vertices'][:num_aves].detach().cpu().numpy(),
432
+ output['aves_output']['pred_cam_t'][:num_aves].detach().cpu().numpy(),
433
+ images[:num_images][~dataset_source].cpu().numpy(),
434
+ self.cfg.AVES.get("FOCAL_LENGTH", 2167),
435
+ output['aves_output']['pred_keypoints_2d'][:num_aves].detach().cpu().numpy(),
436
+ gt_keypoints_2d[:num_images][~dataset_source][:, :18].cpu().numpy(),
437
+ )
438
+ rend_imgs.extend(rend_imgs_aves)
439
+
440
+ num_smal = dataset_source[:num_images].sum().item()
441
+ if num_smal:
442
+ rend_imgs_smal = self.smal_mesh_renderer.visualize_tensorboard( output['smal_output']['pred_vertices'][:num_smal].detach().cpu().numpy(),
443
+ output['smal_output']['pred_cam_t'][:num_smal].detach().cpu().numpy(),
444
+ images[:num_images][dataset_source].cpu().numpy(),
445
+ self.cfg.SMAL.get("FOCAL_LENGTH", 1000),
446
+ output['smal_output']['pred_keypoints_2d'][:num_smal].detach().cpu().numpy(),
447
+ gt_keypoints_2d[:num_images][dataset_source].cpu().numpy(),
448
+ )
449
+ rend_imgs.extend(rend_imgs_smal)
450
+
451
+ rend_imgs = make_grid(rend_imgs, nrow=5, padding=2)
452
+ if write_to_summary_writer:
453
+ summary_writer.add_image('%s/predictions' % mode, rend_imgs, step_count)
454
+
455
+ return rend_imgs
456
+
457
+ def forward(self, batch: Dict) -> Dict:
458
+ """
459
+ Run a forward step of the network in val mode
460
+ Args:
461
+ batch (Dict): Dictionary containing batch data
462
+ Returns:
463
+ Dict: Dictionary containing the regression output
464
+ """
465
+ return self.forward_step(batch, train=False)
466
+
467
+ def training_step(self, batch: Dict) -> Dict:
468
+ """
469
+ Run a full training step
470
+ Args:
471
+ batch (Dict): Dictionary containing {'img', 'mask', 'keypoints_2d', 'keypoints_3d', 'orig_keypoints_2d',
472
+ 'smal_params', 'smal_params_is_axis_angle', 'focal_length'}
473
+ Returns:
474
+ Dict: Dictionary containing regression output.
475
+ """
476
+ batch = batch['img']
477
+ optimizer = self.optimizers(use_pl_optimizer=True)
478
+
479
+ batch_size = batch['img'].shape[0]
480
+ output = self.forward_step(batch, train=True)
481
+ if self.cfg.get('UPDATE_GT_SPIN', False):
482
+ self.update_batch_gt_spin(batch, output)
483
+ loss = self.compute_loss(batch, output, train=True)
484
+
485
+ # Error if Nan
486
+ if torch.isnan(loss):
487
+ raise ValueError('Loss is NaN')
488
+
489
+ optimizer.zero_grad()
490
+ self.manual_backward(loss)
491
+ # Clip gradient
492
+ if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
493
+ gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL,
494
+ error_if_nonfinite=True)
495
+ self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
496
+
497
+ optimizer.step()
498
+ if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
499
+ self.tensorboard_logging(batch, output, self.global_step, train=True)
500
+
501
+ self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False,
502
+ batch_size=batch_size, sync_dist=True)
503
+
504
+ return output
505
+
506
+ def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
507
+ pass
508
+
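A hedged inference sketch against this module, assuming `model` comes from load_amr and `batch` from the ViTDetDataset loader shown earlier; device placement and the CUDA-only silhouette renderer are glossed over:

import torch

with torch.no_grad():
    out = model(batch)                                    # runs forward_step(train=False)

if 'smal_output' in out:                                  # quadruped samples (supercategory < 5)
    verts = out['smal_output']['pred_vertices']           # (B, V, 3) posed mesh vertices
    cam_t = out['smal_output']['pred_cam_t']              # (B, 3) predicted camera translation
if 'aves_output' in out:                                  # bird samples
    verts = out['aves_output']['pred_vertices']
    kp2d = out['aves_output']['pred_keypoints_2d']        # (B, K, 2) keypoints in the crop frame
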
amr/models/aves_warapper.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from dataclasses import dataclass
4
+ from smplx.utils import ModelOutput
5
+ from typing import Optional, NewType
6
+
7
+
8
+ Tensor = NewType('Tensor', torch.Tensor)
9
+ @dataclass
10
+ class AVESOutput(ModelOutput):
11
+ betas: Optional[Tensor] = None
12
+ pose: Optional[Tensor] = None
13
+ bone: Optional[Tensor] = None
14
+
15
+
16
+ class LBS(torch.nn.Module):
17
+ '''
18
+ Implementation of linear blend skinning, with additional bone and scale
19
+ Input:
20
+ V (BN, V, 3): vertices to pose and shape
21
+ pose (BN, J, 3, 3) or (BN, J, 3): pose in rot or axis-angle
22
+ bone (BN, K): allow for direct change of relative joint distances
23
+ scale (1): scale the whole kinematic tree
24
+ '''
25
+ def __init__(self, J, parents, weights):
26
+ super(LBS, self).__init__()
27
+ self.n_joints = J.shape[1]
28
+ self.register_buffer('h_joints', F.pad(J.unsqueeze(-1), [0,0,0,1], value=0))
29
+ self.register_buffer('kin_tree', torch.cat([J[:,[0], :], J[:, 1:]-J[:, parents[1:]]], dim=1).unsqueeze(-1))
30
+
31
+ self.register_buffer('parents', parents)
32
+ self.register_buffer('weights', weights[None].float())
33
+
34
+ def __call__(self, V, pose, bone, scale, to_rotmats=False):
35
+ batch_size = len(V)
36
+ device = pose.device
37
+ V = F.pad(V.unsqueeze(-1), [0,0,0,1], value=1)
38
+ kin_tree = (scale*self.kin_tree) * bone[:, :, None, None]
39
+ pose = pose.view([batch_size, -1, 3, 3])
40
+ T = torch.zeros([batch_size, self.n_joints, 4, 4]).float().to(device)
41
+ T[:, :, -1, -1] = 1
42
+ T[:, :, :3, :] = torch.cat([pose, kin_tree], dim=-1)
43
+ T_rel = [T[:, 0]]
44
+ for i in range(1, self.n_joints):
45
+ T_rel.append(T_rel[self.parents[i]] @ T[:, i])
46
+ T_rel = torch.stack(T_rel, dim=1)
47
+ T_rel[:,:,:,[-1]] -= T_rel.clone() @ (self.h_joints*scale)
48
+ T_ = self.weights @ T_rel.view(batch_size, self.n_joints, -1)
49
+ T_ = T_.view(batch_size, -1, 4, 4)
50
+ V = T_ @ V
51
+ return V[:, :, :3, 0]
52
+
53
+
54
+ class AVES(torch.nn.Module):
55
+ def __init__(self, **kwargs):
56
+ super(AVES, self).__init__()
57
+ # kinematic tree, and map to keypoints from vertices
58
+ self.register_buffer('kintree_table', kwargs['kintree_table'])
59
+ self.register_buffer('parents', kwargs['kintree_table'][0])
60
+ self.register_buffer('weights', kwargs['weights'])
61
+ self.register_buffer('vert2kpt', kwargs['vert2kpt'])
62
+ self.register_buffer('face', kwargs['F'])
63
+
64
+ # mean shape and default joints
65
+ rot = torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], dtype=torch.float32)
66
+ rot = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=torch.float32) @ rot
67
+ # rot = torch.eye(3, dtype=torch.float32)
68
+ V = (rot @ kwargs['V'].T).T.unsqueeze(0)
69
+ J = (rot @ kwargs['J'].T).T.unsqueeze(0)
70
+ self.register_buffer('V', V)
71
+ self.register_buffer('J', J)
72
+ self.LBS = LBS(self.J, self.parents, self.weights)
73
+
74
+ # pose and bone prior
75
+ self.register_buffer('p_m', kwargs['pose_mean'])
76
+ self.register_buffer('b_m', kwargs['bone_mean'])
77
+ self.register_buffer('p_cov', kwargs['pose_cov'])
78
+ self.register_buffer('b_cov', kwargs['bone_cov'])
79
+
80
+ # standardized blend shape basis
81
+ B = kwargs['Beta']
82
+ sigma = kwargs['Beta_sigma']
83
+ B = B * sigma[:,None,None]
84
+ self.register_buffer('B', B)
85
+
86
+ # PCA coefficient that is optimized to match the original template shape
87
+ ### so in the __call__ function, if beta is set to self.beta_original,
88
+ ### it will return the template shape from ECCV2020 (marcbadger/avian-mesh).
89
+ self.register_buffer('beta_original', kwargs['beta_original'])
90
+
91
+ def __call__(self, global_orient, pose, bone, transl=None,
92
+ scale=1, betas=None, pose2rot=False, **kwargs):
93
+ '''
94
+ Input:
95
+ global_orient [bn, 3] tensor of batched global orientation for the root joint
96
+ pose [bn, 72] tensor of batched body pose
97
+ bone [bn, 24] tensor of bone lengths; the bone variable
98
+ captures non-rigid joint articulation in this model
99
+
100
+ betas [bn, 15] shape PCA coefficients
101
+ If betas is None, the mean shape is returned
102
+ If betas is self.beta_original, the original template shape is returned
103
+
104
+ '''
105
+ device = global_orient.device
106
+ batch_size = global_orient.shape[0]
107
+ V = self.V.repeat([batch_size, 1, 1]) * scale
108
+ J = self.J.repeat([batch_size, 1, 1]) * scale
109
+
110
+ # multi-bird shape space
111
+ if betas is not None:
112
+ V = V + torch.einsum('bk, kmn->bmn', betas, self.B)
113
+
114
+ # concatenate bone and pose
115
+ bone = torch.cat([torch.ones([batch_size, 1]).to(device), bone], dim=1)
116
+ pose = torch.cat([global_orient, pose], dim=1)
117
+
118
+ # LBS
119
+ verts = self.LBS(V, pose, bone, scale, to_rotmats=pose2rot)
120
+ if transl is not None:
121
+ verts = verts + transl[:, None, :]
122
+
123
+ # Calculate 3d keypoint from new vertices resulted from pose
124
+ keypoints = torch.einsum('bni,kn->bki', verts, self.vert2kpt)
125
+
126
+ output = AVESOutput(
127
+ vertices=verts,
128
+ joints=keypoints,
129
+ betas=betas,
130
+ global_orient=global_orient,
131
+ pose=pose,
132
+ bone=bone,
133
+ transl=transl,
134
+ full_pose=None,
135
+ )
136
+ return output
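To make the calling convention concrete, a rest-pose sketch, assuming an `aves` instance built from the dictionary stored at cfg.AVES.MODEL_PATH (as in AniMerPlusPlus.__init__); rotations are passed as matrices because that is how the regression heads drive the model:

import torch

n_joints = aves.J.shape[1]                                # joints in the template kinematic tree
eye = torch.eye(3).reshape(1, 1, 3, 3)

global_orient = eye.clone()                               # (1, 1, 3, 3) identity root rotation
pose = eye.repeat(1, n_joints - 1, 1, 1)                  # identity rotations for the remaining joints
bone = torch.ones(1, n_joints - 1)                        # unit relative bone lengths

out = aves(global_orient=global_orient, pose=pose, bone=bone, betas=None, pose2rot=False)
verts, joints = out.vertices, out.joints                  # (1, V, 3) mesh and (1, K, 3) keypoints
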
amr/models/backbones/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .vit_moe import vithmoe
2
+ from torch import nn
3
+ import torchvision
4
+
5
+ def create_backbone(cfg):
6
+ if cfg.MODEL.BACKBONE.TYPE == 'vithmoe':
7
+ return vithmoe(cfg)
8
+ else:
9
+ raise NotImplementedError('Backbone type is not implemented')
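Backbone construction is driven entirely by the config; a one-line sketch (any other TYPE raises NotImplementedError):

from amr.models.backbones import create_backbone

backbone = create_backbone(model_cfg)    # requires model_cfg.MODEL.BACKBONE.TYPE == 'vithmoe'
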
amr/models/backbones/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (489 Bytes). View file
 
amr/models/backbones/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.71 kB). View file
 
amr/models/backbones/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (401 Bytes). View file
 
amr/models/backbones/__pycache__/rope_deit.cpython-310.pyc ADDED
Binary file (2.47 kB). View file
 
amr/models/backbones/__pycache__/vit.cpython-310.pyc ADDED
Binary file (11.9 kB). View file