Spaces: AniMerPlus
Build error
Commit · 1966925
Parent(s): 7a90a56
Init Commit
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
- .gitattributes +4 -0
- README.md +11 -0
- amr/__init__.py +0 -0
- amr/__pycache__/__init__.cpython-310.pyc +0 -0
- amr/configs/__init__.py +112 -0
- amr/configs/__pycache__/__init__.cpython-310.pyc +0 -0
- amr/datasets/__init__.py +0 -0
- amr/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- amr/datasets/__pycache__/utils.cpython-310.pyc +0 -0
- amr/datasets/__pycache__/vitdet_dataset.cpython-310.pyc +0 -0
- amr/datasets/utils.py +1098 -0
- amr/datasets/vitdet_dataset.py +102 -0
- amr/models/__init__.py +26 -0
- amr/models/__pycache__/__init__.cpython-310.pyc +0 -0
- amr/models/__pycache__/__init__.cpython-312.pyc +0 -0
- amr/models/__pycache__/__init__.cpython-39.pyc +0 -0
- amr/models/__pycache__/amr.cpython-310.pyc +0 -0
- amr/models/__pycache__/amr.cpython-312.pyc +0 -0
- amr/models/__pycache__/amr.cpython-39.pyc +0 -0
- amr/models/__pycache__/animerpp.cpython-310.pyc +0 -0
- amr/models/__pycache__/animerpp.cpython-312.pyc +0 -0
- amr/models/__pycache__/aves_hmr.cpython-310.pyc +0 -0
- amr/models/__pycache__/aves_hmr.cpython-312.pyc +0 -0
- amr/models/__pycache__/aves_warapper.cpython-310.pyc +0 -0
- amr/models/__pycache__/aves_warapper.cpython-312.pyc +0 -0
- amr/models/__pycache__/discriminator.cpython-310.pyc +0 -0
- amr/models/__pycache__/discriminator.cpython-312.pyc +0 -0
- amr/models/__pycache__/discriminator.cpython-39.pyc +0 -0
- amr/models/__pycache__/dyamr.cpython-310.pyc +0 -0
- amr/models/__pycache__/dyamr.cpython-312.pyc +0 -0
- amr/models/__pycache__/losses.cpython-310.pyc +0 -0
- amr/models/__pycache__/losses.cpython-312.pyc +0 -0
- amr/models/__pycache__/losses.cpython-39.pyc +0 -0
- amr/models/__pycache__/predictor.cpython-310.pyc +0 -0
- amr/models/__pycache__/smal_warapper.cpython-310.pyc +0 -0
- amr/models/__pycache__/smal_warapper.cpython-312.pyc +0 -0
- amr/models/__pycache__/smal_warapper.cpython-39.pyc +0 -0
- amr/models/__pycache__/smooth_amr.cpython-310.pyc +0 -0
- amr/models/__pycache__/smooth_amr.cpython-312.pyc +0 -0
- amr/models/__pycache__/smooth_netv2.cpython-310.pyc +0 -0
- amr/models/__pycache__/stamr.cpython-310.pyc +0 -0
- amr/models/__pycache__/stamr.cpython-312.pyc +0 -0
- amr/models/animerpp.py +508 -0
- amr/models/aves_warapper.py +136 -0
- amr/models/backbones/__init__.py +9 -0
- amr/models/backbones/__pycache__/__init__.cpython-310.pyc +0 -0
- amr/models/backbones/__pycache__/__init__.cpython-312.pyc +0 -0
- amr/models/backbones/__pycache__/__init__.cpython-39.pyc +0 -0
- amr/models/backbones/__pycache__/rope_deit.cpython-310.pyc +0 -0
- amr/models/backbones/__pycache__/vit.cpython-310.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.JPEG filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,5 +1,6 @@
 ---
 title: AniMerPlus
+<<<<<<< HEAD
 emoji: 📊
 colorFrom: gray
 colorTo: indigo
@@ -7,6 +8,16 @@ sdk: gradio
 sdk_version: 5.39.0
 app_file: app.py
 pinned: false
+=======
+emoji: 🔥
+colorFrom: pink
+colorTo: blue
+sdk: gradio
+sdk_version: 5.1.0
+app_file: app.py
+pinned: false
+license: mit
+>>>>>>> Init Commit
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
amr/__init__.py
ADDED
File without changes

amr/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (146 Bytes).
amr/configs/__init__.py
ADDED
@@ -0,0 +1,112 @@
import os
from typing import Dict
from yacs.config import CfgNode as CN

CACHE_DIR_AniMer = "./_DATA"


def to_lower(x: Dict) -> Dict:
    """
    Convert all dictionary keys to lowercase
    Args:
        x (dict): Input dictionary
    Returns:
        dict: Output dictionary with all keys converted to lowercase
    """
    return {k.lower(): v for k, v in x.items()}


_C = CN(new_allowed=True)

_C.GENERAL = CN(new_allowed=True)
_C.GENERAL.RESUME = True
_C.GENERAL.TIME_TO_RUN = 3300
_C.GENERAL.VAL_STEPS = 100
_C.GENERAL.LOG_STEPS = 100
_C.GENERAL.CHECKPOINT_STEPS = 20000
_C.GENERAL.CHECKPOINT_DIR = "checkpoints"
_C.GENERAL.SUMMARY_DIR = "tensorboard"
_C.GENERAL.NUM_GPUS = 1
_C.GENERAL.NUM_WORKERS = 4
_C.GENERAL.MIXED_PRECISION = True
_C.GENERAL.ALLOW_CUDA = True
_C.GENERAL.PIN_MEMORY = False
_C.GENERAL.DISTRIBUTED = False
_C.GENERAL.LOCAL_RANK = 0
_C.GENERAL.USE_SYNCBN = False
_C.GENERAL.WORLD_SIZE = 1

_C.TRAIN = CN(new_allowed=True)
_C.TRAIN.NUM_EPOCHS = 100
_C.TRAIN.SHUFFLE = True
_C.TRAIN.WARMUP = False
_C.TRAIN.NORMALIZE_PER_IMAGE = False
_C.TRAIN.CLIP_GRAD = False
_C.TRAIN.CLIP_GRAD_VALUE = 1.0
_C.LOSS_WEIGHTS = CN(new_allowed=True)

_C.DATASETS = CN(new_allowed=True)

_C.MODEL = CN(new_allowed=True)
_C.MODEL.IMAGE_SIZE = 224

_C.EXTRA = CN(new_allowed=True)
_C.EXTRA.FOCAL_LENGTH = 5000

_C.DATASETS.CONFIG = CN(new_allowed=True)
_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
_C.DATASETS.CONFIG.ROT_FACTOR = 30
_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
_C.DATASETS.CONFIG.COLOR_SCALE = 0.2
_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
_C.DATASETS.CONFIG.DO_FLIP = False
_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10


def default_config() -> CN:
    """
    Get a yacs CfgNode object with the default config values.
    """
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return _C.clone()


def dataset_config() -> CN:
    """
    Get dataset config file
    Returns:
        CfgNode: Dataset config as a yacs CfgNode object.
    """
    cfg = CN(new_allowed=True)
    config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets_tar.yaml')
    cfg.merge_from_file(config_file)
    cfg.freeze()
    return cfg


def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN:
    """
    Read a config file and optionally merge it with the default config file.
    Args:
        config_file (str): Path to config file.
        merge (bool): Whether to merge with the default config or not.
    Returns:
        CfgNode: Config as a yacs CfgNode object.
    """
    if merge:
        cfg = default_config()
    else:
        cfg = CN(new_allowed=True)
    cfg.merge_from_file(config_file)

    if update_cachedir:
        def update_path(path: str) -> str:
            if os.path.isabs(path):
                return path
            return os.path.join(CACHE_DIR_AniMer, path)

    cfg.freeze()
    return cfg
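Note (illustration only, not part of the commit): a minimal sketch of how this config module is typically consumed elsewhere in the repo; the YAML path below is hypothetical and only the functions defined above are assumed.

from amr.configs import default_config, get_config

cfg = default_config()                      # fresh clone of the defaults defined above
print(cfg.MODEL.IMAGE_SIZE)                 # 224
# Merge a (hypothetical) experiment YAML on top of the defaults.
exp_cfg = get_config("path/to/experiment.yaml", merge=True)
print(exp_cfg.EXTRA.FOCAL_LENGTH)           # 5000 unless overridden by the YAML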
amr/configs/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (3.21 kB).

amr/datasets/__init__.py
ADDED
File without changes

amr/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (155 Bytes).

amr/datasets/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (36 kB).

amr/datasets/__pycache__/vitdet_dataset.cpython-310.pyc
ADDED
Binary file (3.19 kB).
amr/datasets/utils.py
ADDED
@@ -0,0 +1,1098 @@
"""
Parts of the code are taken or adapted from
https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py
"""
import torch
import numpy as np
from skimage.transform import rotate, resize
from skimage.filters import gaussian
import random
import cv2
from typing import List, Dict, Tuple
from yacs.config import CfgNode
from typing import Union


def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
    """Increase the size of the bounding box to match the target shape."""
    if target_aspect_ratio is None:
        return input_shape

    try:
        w, h = input_shape
    except (ValueError, TypeError):
        return input_shape

    w_t, h_t = target_aspect_ratio
    if h / w < h_t / w_t:
        h_new = max(w * h_t / w_t, h)
        w_new = w
    else:
        h_new = h
        w_new = max(h * w_t / h_t, w)
    if h_new < h or w_new < w:
        breakpoint()
    return np.array([w_new, h_new])


def do_augmentation(aug_config: CfgNode) -> Tuple:
    """
    Compute random augmentation parameters.
    Args:
        aug_config (CfgNode): Config containing augmentation parameters.
    Returns:
        scale (float): Box rescaling factor.
        rot (float): Random image rotation.
        do_flip (bool): Whether to flip image or not.
        do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT).
        color_scale (List): Color rescaling factor
        tx (float): Random translation along the x axis.
        ty (float): Random translation along the y axis.
    """

    tx = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
    ty = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
    scale = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.SCALE_FACTOR + 1.0
    rot = np.clip(np.random.randn(), -2.0,
                  2.0) * aug_config.ROT_FACTOR if random.random() <= aug_config.ROT_AUG_RATE else 0
    do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE
    do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE
    extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0)
    # extreme_crop_lvl = 0
    c_up = 1.0 + aug_config.COLOR_SCALE
    c_low = 1.0 - aug_config.COLOR_SCALE
    color_scale = [random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)]
    return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty
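Note (illustration only, not part of the committed file): do_augmentation only reads the augmentation keys, so it can be exercised with an ad-hoc CfgNode whose keys mirror the DATASETS.CONFIG defaults from amr/configs/__init__.py.

from yacs.config import CfgNode as CN

aug_cfg = CN()
aug_cfg.TRANS_FACTOR = 0.02
aug_cfg.SCALE_FACTOR = 0.3
aug_cfg.ROT_FACTOR = 30
aug_cfg.ROT_AUG_RATE = 0.6
aug_cfg.DO_FLIP = False
aug_cfg.FLIP_AUG_RATE = 0.5
aug_cfg.EXTREME_CROP_AUG_RATE = 0.10
aug_cfg.COLOR_SCALE = 0.2
scale, rot, do_flip, do_extreme_crop, lvl, color_scale, tx, ty = do_augmentation(aug_cfg)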
def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array:
    """
    Rotate a 2D point on the x-y plane.
    Args:
        pt_2d (np.array): Input 2D point with shape (2,).
        rot_rad (float): Rotation angle
    Returns:
        np.array: Rotated 2D point.
    """
    x = pt_2d[0]
    y = pt_2d[1]
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
    xx = x * cs - y * sn
    yy = x * sn + y * cs
    return np.array([xx, yy], dtype=np.float32)


def gen_trans_from_patch_cv(c_x: float, c_y: float,
                            src_width: float, src_height: float,
                            dst_width: float, dst_height: float,
                            scale: float, rot: float) -> np.array:
    """
    Create transformation matrix for the bounding box crop.
    Args:
        c_x (float): Bounding box center x coordinate in the original image.
        c_y (float): Bounding box center y coordinate in the original image.
        src_width (float): Bounding box width.
        src_height (float): Bounding box height.
        dst_width (float): Output box width.
        dst_height (float): Output box height.
        scale (float): Rescaling factor for the bounding box (augmentation).
        rot (float): Random rotation applied to the box.
    Returns:
        trans (np.array): Target geometric transformation.
    """
    # augment size with scale
    src_w = src_width * scale
    src_h = src_height * scale
    src_center = np.zeros(2)
    src_center[0] = c_x
    src_center[1] = c_y
    # augment rotation
    rot_rad = np.pi * rot / 180
    src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
    src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)

    dst_w = dst_width
    dst_h = dst_height
    dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
    dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
    dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = src_center
    src[1, :] = src_center + src_downdir
    src[2, :] = src_center + src_rightdir

    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = dst_center
    dst[1, :] = dst_center + dst_downdir
    dst[2, :] = dst_center + dst_rightdir

    trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans


def trans_point2d(pt_2d: np.array, trans: np.array):
    """
    Transform a 2D point using translation matrix trans.
    Args:
        pt_2d (np.array): Input 2D point with shape (2,).
        trans (np.array): Transformation matrix.
    Returns:
        np.array: Transformed 2D point.
    """
    src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T
    dst_pt = np.dot(trans, src_pt)
    return dst_pt[0:2]


def get_transform(center, scale, res, rot=0):
    """Generate transformation matrix."""
    """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + .5)
    t[1, 2] = res[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1
    if not rot == 0:
        rot = -rot  # To match direction of rotation from cropping
        rot_mat = np.zeros((3, 3))
        rot_rad = rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
        rot_mat[2, 2] = 1
        # Need to rotate around center
        t_mat = np.eye(3)
        t_mat[0, 2] = -res[1] / 2
        t_mat[1, 2] = -res[0] / 2
        t_inv = t_mat.copy()
        t_inv[:2, 2] *= -1
        t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
    return t


def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
    """Transform pixel location to different reference."""
    """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
    t = get_transform(center, scale, res, rot=rot)
    if invert:
        t = np.linalg.inv(t)
    new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
    new_pt = np.dot(t, new_pt)
    if as_int:
        new_pt = new_pt.astype(int)
    return new_pt[:2] + 1


def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0):
    c_x = (ul[0] + br[0]) / 2
    c_y = (ul[1] + br[1]) / 2
    bb_width = patch_width = br[0] - ul[0]
    bb_height = patch_height = br[1] - ul[1]
    trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, 1.0, 0)
    img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
                               flags=cv2.INTER_LINEAR,
                               borderMode=border_mode,
                               borderValue=border_value
                               )

    # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
    if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
        img_patch[:, :, 3] = cv2.warpAffine(img[:, :, 3], trans, (int(patch_width), int(patch_height)),
                                            flags=cv2.INTER_LINEAR,
                                            borderMode=cv2.BORDER_CONSTANT,
                                            )

    return img_patch


def generate_image_patch_skimage(img: np.array, c_x: float, c_y: float,
                                 bb_width: float, bb_height: float,
                                 patch_width: float, patch_height: float,
                                 do_flip: bool, scale: float, rot: float,
                                 border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
    """
    Crop image according to the supplied bounding box.
    Args:
        img (np.array): Input image of shape (H, W, 3)
        c_x (float): Bounding box center x coordinate in the original image.
        c_y (float): Bounding box center y coordinate in the original image.
        bb_width (float): Bounding box width.
        bb_height (float): Bounding box height.
        patch_width (float): Output box width.
        patch_height (float): Output box height.
        do_flip (bool): Whether to flip image or not.
        scale (float): Rescaling factor for the bounding box (augmentation).
        rot (float): Random rotation applied to the box.
    Returns:
        img_patch (np.array): Cropped image patch of shape (patch_height, patch_height, 3)
        trans (np.array): Transformation matrix.
    """

    img_height, img_width, img_channels = img.shape
    if do_flip:
        img = img[:, ::-1, :]
        c_x = img_width - c_x - 1

    trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)

    # img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), flags=cv2.INTER_LINEAR)

    # skimage
    center = np.zeros(2)
    center[0] = c_x
    center[1] = c_y
    res = np.zeros(2)
    res[0] = patch_width
    res[1] = patch_height
    # assumes bb_width = bb_height
    # assumes patch_width = patch_height
    assert bb_width == bb_height, f'{bb_width=} != {bb_height=}'
    assert patch_width == patch_height, f'{patch_width=} != {patch_height=}'
    scale1 = scale * bb_width / 200.

    # Upper left point
    ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1
    # Bottom right point
    br = np.array(transform([res[0] + 1,
                             res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1

    # Padding so that when rotated proper amount of context is included
    try:
        pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1
    except:
        breakpoint()
    if not rot == 0:
        ul -= pad
        br += pad

    if False:
        # Old way of cropping image
        ul_int = ul.astype(int)
        br_int = br.astype(int)
        new_shape = [br_int[1] - ul_int[1], br_int[0] - ul_int[0]]
        if len(img.shape) > 2:
            new_shape += [img.shape[2]]
        new_img = np.zeros(new_shape)

        # Range to fill new array
        new_x = max(0, -ul_int[0]), min(br_int[0], len(img[0])) - ul_int[0]
        new_y = max(0, -ul_int[1]), min(br_int[1], len(img)) - ul_int[1]
        # Range to sample from original image
        old_x = max(0, ul_int[0]), min(len(img[0]), br_int[0])
        old_y = max(0, ul_int[1]), min(len(img), br_int[1])
        new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
                                                            old_x[0]:old_x[1]]

    # New way of cropping image
    new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32)

    # print(f'{new_img.shape=}')
    # print(f'{new_img1.shape=}')
    # print(f'{np.allclose(new_img, new_img1)=}')
    # print(f'{img.dtype=}')

    if not rot == 0:
        # Remove padding

        new_img = rotate(new_img, rot)  # scipy.misc.imrotate(new_img, rot)
        new_img = new_img[pad:-pad, pad:-pad]

    if new_img.shape[0] < 1 or new_img.shape[1] < 1:
        print(f'{img.shape=}')
        print(f'{new_img.shape=}')
        print(f'{ul=}')
        print(f'{br=}')
        print(f'{pad=}')
        print(f'{rot=}')

        breakpoint()

    # resize image
    new_img = resize(new_img, res)  # scipy.misc.imresize(new_img, res)

    new_img = np.clip(new_img, 0, 255).astype(np.uint8)

    return new_img, trans
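Note (illustration only, not part of the committed file): the 2x3 matrix returned by gen_trans_from_patch_cv maps original-image pixel coordinates into patch coordinates, so the box centre always lands at the patch centre. With a 200x200 box centred at (320, 240) and a 256x256 output patch, no scaling or rotation:

trans = gen_trans_from_patch_cv(c_x=320, c_y=240, src_width=200, src_height=200,
                                dst_width=256, dst_height=256, scale=1.0, rot=0.0)
print(trans_point2d(np.array([320, 240]), trans))   # -> [128. 128.]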
def generate_image_patch_cv2(img: np.array, c_x: float, c_y: float,
                             bb_width: float, bb_height: float,
                             patch_width: float, patch_height: float,
                             do_flip: bool, scale: float, rot: float,
                             border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
    """
    Crop the input image and return the crop and the corresponding transformation matrix.
    Args:
        img (np.array): Input image of shape (H, W, 3)
        c_x (float): Bounding box center x coordinate in the original image.
        c_y (float): Bounding box center y coordinate in the original image.
        bb_width (float): Bounding box width.
        bb_height (float): Bounding box height.
        patch_width (float): Output box width.
        patch_height (float): Output box height.
        do_flip (bool): Whether to flip image or not.
        scale (float): Rescaling factor for the bounding box (augmentation).
        rot (float): Random rotation applied to the box.
    Returns:
        img_patch (np.array): Cropped image patch of shape (patch_height, patch_height, 3)
        trans (np.array): Transformation matrix.
    """

    img_height, img_width, img_channels = img.shape
    if do_flip:
        img = img[:, ::-1, :]
        c_x = img_width - c_x - 1

    trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)

    img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
                               flags=cv2.INTER_LINEAR,
                               borderMode=border_mode,
                               borderValue=border_value,
                               )
    # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
    if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
        img_patch[:, :, 3] = cv2.warpAffine(img[:, :, 3], trans, (int(patch_width), int(patch_height)),
                                            flags=cv2.INTER_LINEAR,
                                            borderMode=cv2.BORDER_CONSTANT,
                                            )

    is_border = np.all(img_patch[:, :, :-1] == border_value, axis=2) if img_patch.shape[2] == 4 else np.all(img_patch == 0, axis=2)
    img_border_mask = ~is_border
    return img_patch, trans, img_border_mask


def convert_cvimg_to_tensor(cvimg: np.array):
    """
    Convert image from HWC to CHW format.
    Args:
        cvimg (np.array): Image of shape (H, W, 3) as loaded by OpenCV.
    Returns:
        np.array: Output image of shape (3, H, W).
    """
    # from h,w,c(OpenCV) to c,h,w
    img = cvimg.copy()
    img = np.transpose(img, (2, 0, 1))
    # from int to float
    img = img.astype(np.float32)
    return img


def fliplr_params(smal_params: Dict, has_smal_params: Dict) -> Tuple[Dict, Dict]:
    """
    Flip SMAL parameters when flipping the image.
    Args:
        smal_params (Dict): SMAL parameter annotations.
        has_smal_params (Dict): Whether SMAL annotations are valid.
    Returns:
        Dict, Dict: Flipped SMAL parameters and valid flags.
    """
    global_orient = smal_params['global_orient'].copy()
    pose = smal_params['pose'].copy()
    betas = smal_params['betas'].copy()
    translation = smal_params['translation'].copy()
    has_global_orient = has_smal_params['global_orient'].copy()
    has_pose = has_smal_params['pose'].copy()
    has_betas = has_smal_params['betas'].copy()
    has_translation = has_smal_params['translation'].copy()

    global_orient[1::3] *= -1
    global_orient[2::3] *= -1
    pose[1::3] *= -1
    pose[2::3] *= -1
    translation[1::3] *= -1
    translation[2::3] *= -1

    smal_params = {'global_orient': global_orient.astype(np.float32),
                   'pose': pose.astype(np.float32),
                   'betas': betas.astype(np.float32),
                   'translation': translation.astype(np.float32)
                   }

    has_smal_params = {'global_orient': has_global_orient,
                       'pose': has_pose,
                       'betas': has_betas,
                       'translation': has_translation
                       }

    return smal_params, has_smal_params


def fliplr_keypoints(joints: np.array, width: float, flip_permutation: List[int]) -> np.array:
    """
    Flip 2D or 3D keypoints.
    Args:
        joints (np.array): Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence.
        flip_permutation (List): Permutation to apply after flipping.
    Returns:
        np.array: Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively.
    """
    joints = joints.copy()
    # Flip horizontal
    joints[:, 0] = width - joints[:, 0] - 1
    joints = joints[flip_permutation, :]

    return joints


def keypoint_3d_processing(keypoints_3d: np.array, rot: float, filp: bool) -> np.array:
    """
    Process 3D keypoints (rotation/flipping).
    Args:
        keypoints_3d (np.array): Input array of shape (N, 4) containing the 3D keypoints and confidence.
        rot (float): Random rotation applied to the keypoints.
    Returns:
        np.array: Transformed 3D keypoints with shape (N, 4).
    """
    # in-plane rotation
    rot_mat = np.eye(3, dtype=np.float32)
    if not rot == 0:
        rot_rad = -rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
    keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1])
    # flip the x coordinates
    if filp:
        keypoints_3d = fliplr_keypoints(keypoints_3d, list(range(len(keypoints_3d))))
    keypoints_3d = keypoints_3d.astype('float32')
    return keypoints_3d


def rot_aa(aa: np.array, rot: float) -> np.array:
    """
    Rotate axis angle parameters.
    Args:
        aa (np.array): Axis-angle vector of shape (3,).
        rot (np.array): Rotation angle in degrees.
    Returns:
        np.array: Rotated axis-angle vector.
    """
    # pose parameters
    R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
                  [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
                  [0, 0, 1]])
    # find the rotation of the hand in camera frame
    per_rdg, _ = cv2.Rodrigues(aa)
    # apply the global rotation to the global orientation
    resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
    aa = (resrot.T)[0]
    return aa.astype(np.float32)


def smal_param_processing(smal_params: Dict, has_smal_params: Dict, rot: float, do_flip: bool) -> Tuple[Dict, Dict]:
    """
    Apply random augmentations to the SMAL parameters.
    Args:
        smal_params (Dict): SMAL parameter annotations.
        has_smal_params (Dict): Whether SMAL annotations are valid.
        rot (float): Random rotation applied to the keypoints.
        do_flip (bool): Whether to flip keypoints or not.
    Returns:
        Dict, Dict: Transformed SMAL parameters and valid flags.
    """
    if do_flip:
        smal_params, has_smal_params = fliplr_params(smal_params, has_smal_params)
    smal_params['global_orient'] = rot_aa(smal_params['global_orient'], rot)
    # camera location is not change, so the translation is not change too.
    # smal_params['transl'] = np.dot(np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
    #                                          [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
    #                                          [0, 0, 1]], dtype=np.float32), smal_params['transl'])
    return smal_params, has_smal_params
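Note (illustration only, not part of the committed file): smal_param_processing rotates only global_orient, via rot_aa. For instance, rotating a zero orientation by 90 degrees yields a pure rotation about the camera z-axis:

aa = rot_aa(np.zeros(3, dtype=np.float32), rot=90)
print(aa)   # approximately [0., 0., -1.5708]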
def get_example(img_path: Union[str, np.ndarray], center_x: float, center_y: float,
                width: float, height: float,
                keypoints_2d: np.array, keypoints_3d: np.array,
                smal_params: Dict, has_smal_params: Dict,
                patch_width: int, patch_height: int,
                mean: np.array, std: np.array,
                do_augment: bool, augm_config: CfgNode,
                is_bgr: bool = True,
                use_skimage_antialias: bool = False,
                border_mode: int = cv2.BORDER_CONSTANT,
                return_trans: bool = False,) -> Tuple:
    """
    Get an example from the dataset and (possibly) apply random augmentations.
    Args:
        img_path (str): Image filename
        center_x (float): Bounding box center x coordinate in the original image.
        center_y (float): Bounding box center y coordinate in the original image.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
        keypoints_3d (np.array): Array with shape (N,4) containing the 3D keypoints.
        smal_params (Dict): SMAL parameter annotations.
        has_smal_params (Dict): Whether SMAL annotations are valid.
        patch_width (float): Output box width.
        patch_height (float): Output box height.
        mean (np.array): Array of shape (3,) containing the mean for normalizing the input image.
        std (np.array): Array of shape (3,) containing the std for normalizing the input image.
        do_augment (bool): Whether to apply data augmentation or not.
        aug_config (CfgNode): Config containing augmentation parameters.
    Returns:
        return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size
        img_patch (np.array): Cropped image patch of shape (3, patch_height, patch_height)
        keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints.
        keypoints_3d (np.array): Array with shape (N,4) containing the transformed 3D keypoints.
        smal_params (Dict): Transformed SMAL parameters.
        has_smal_params (Dict): Valid flag for transformed SMAL parameters.
        img_size (np.array): Image size of the original image.
    """
    if isinstance(img_path, str):
        # 1. load image
        cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        if not isinstance(cvimg, np.ndarray):
            raise IOError("Fail to read %s" % img_path)
    elif isinstance(img_path, np.ndarray):
        cvimg = img_path
    else:
        raise TypeError('img_path must be either a string or a numpy array')
    img_height, img_width, img_channels = cvimg.shape

    img_size = np.array([img_height, img_width], dtype=np.int32)

    # 2. get augmentation params
    if do_augment:
        # box rescale factor, rotation angle, flip or not flip, crop or not crop, ..., color scale, translation x, ...
        scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
    else:
        scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0,
                                                                                                                1.0,
                                                                                                                1.0], 0., 0.
    if width < 1 or height < 1:
        breakpoint()

    if do_extreme_crop:
        if extreme_crop_lvl == 0:
            center_x1, center_y1, width1, height1 = extreme_cropping(center_x, center_y, width, height, keypoints_2d)
        elif extreme_crop_lvl == 1:
            center_x1, center_y1, width1, height1 = extreme_cropping_aggressive(center_x, center_y, width, height,
                                                                                keypoints_2d)

        THRESH = 4
        if width1 < THRESH or height1 < THRESH:
            pass
        else:
            center_x, center_y, width, height = center_x1, center_y1, width1, height1

    center_x += width * tx
    center_y += height * ty

    # Process 3D keypoints
    keypoints_3d = keypoint_3d_processing(keypoints_3d, rot, do_flip)

    # 3. generate image patch
    if use_skimage_antialias:
        # Blur image to avoid aliasing artifacts
        downsampling_factor = (patch_width / (width * scale))
        if downsampling_factor > 1.1:
            cvimg = gaussian(cvimg, sigma=(downsampling_factor - 1) / 2, channel_axis=2, preserve_range=True,
                             truncate=3.0)
    # augmentation image, translation matrix
    img_patch_cv, trans, img_border_mask = generate_image_patch_cv2(cvimg,
                                                                    center_x, center_y,
                                                                    width, height,
                                                                    patch_width, patch_height,
                                                                    do_flip, scale, rot,
                                                                    border_mode=border_mode)

    image = img_patch_cv.copy()
    if is_bgr:
        image = image[:, :, ::-1]
    img_patch_cv = image.copy()
    img_patch = convert_cvimg_to_tensor(image)  # [h, w, 4] -> [4, h, w]

    smal_params, has_smal_params = smal_param_processing(smal_params, has_smal_params, rot, do_flip)

    # apply normalization
    for n_c in range(min(img_channels, 3)):
        img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
        if mean is not None and std is not None:
            img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]

    if do_flip:
        keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, list(range(len(keypoints_2d))))

    for n_jt in range(len(keypoints_2d)):
        keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
    keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5

    if not return_trans:
        return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, img_border_mask
    else:
        return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, trans, img_border_mask
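Note (illustration only, not part of the committed file): a self-contained sketch of exercising get_example without augmentation; every shape and value below is a placeholder, not the repo's actual SMAL dimensions or dataset statistics.

img = np.zeros((480, 640, 3), dtype=np.uint8)            # dummy BGR frame
kpts_2d = np.zeros((26, 3), dtype=np.float32)             # (x, y, confidence), placeholder count
kpts_3d = np.zeros((26, 4), dtype=np.float32)             # (x, y, z, confidence)
smal = {'global_orient': np.zeros(3, dtype=np.float32),
        'pose': np.zeros(99, dtype=np.float32),            # placeholder size
        'betas': np.zeros(20, dtype=np.float32),           # placeholder size
        'translation': np.zeros(3, dtype=np.float32)}
has_smal = {k: np.zeros(1, dtype=np.float32) for k in smal}
patch, kpts_2d, kpts_3d, smal, has_smal, size, border_mask = get_example(
    img, 320, 240, 200, 200, kpts_2d, kpts_3d, smal, has_smal,
    256, 256, mean=None, std=None, do_augment=False, augm_config=None)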
def get_cub17_example(cvimg: np.array,
                      keypoints_2d: np.array,
                      center_x: float, center_y: float,
                      width: float, height: float,
                      patch_width: int, patch_height: int,
                      mean: np.array, std: np.array,
                      do_augment: bool, augm_config: CfgNode,
                      return_trans=True) -> Tuple:
    """
    Get an example from the dataset and (possibly) apply random augmentations.
    Args:
        cvimg (np.ndarray): Image
        keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
        center_x (float): Bounding box center x coordinate in the original image.
        center_y (float): Bounding box center y coordinate in the original image.
        width (float): Bounding box width.
        height (float): Bounding box height.
        patch_width (int): Output box width.
        patch_height (int): Output box height.
        mean (np.array): Array of shape (3,) containing the mean for normalizing the input image.
        std (np.array): Array of shape (3,) containing the std for normalizing the input image.
        do_augment (bool): Whether to apply data augmentation or not.
        aug_config (CfgNode): Config containing augmentation parameters.
    Returns:
        return img_patch, keypoints_2d
        img_patch (np.array): Cropped image patch of shape (3, patch_height, patch_height)
        keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints.
    """
    img_height, img_width, img_channels = cvimg.shape

    img_size = np.array([img_height, img_width], dtype=np.int32)

    # 2. get augmentation params
    if do_augment:
        # box rescale factor, rotation angle, flip or not flip, crop or not crop, ..., color scale, translation x, ...
        scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
    else:
        scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0,
                                                                                                                1.0,
                                                                                                                1.0], 0., 0.
    # bounding box height and width
    center_x += width * tx
    center_y += height * ty
    # augmentation image, translation matrix
    img_patch_cv, trans, img_border_mask = generate_image_patch_cv2(cvimg,
                                                                    center_x, center_y,
                                                                    width, height,
                                                                    patch_width, patch_height,
                                                                    do_flip, scale, rot,
                                                                    border_mode=cv2.BORDER_CONSTANT)

    image = img_patch_cv.copy()
    img_patch = convert_cvimg_to_tensor(image)  # [h, w, 4] -> [4, h, w]

    # apply normalization
    for n_c in range(min(img_channels, 3)):
        img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
        if mean is not None and std is not None:
            img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]

    if do_flip:
        keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, list(range(len(keypoints_2d))))

    for n_jt in range(len(keypoints_2d)):
        keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
    keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5

    if return_trans:
        return img_patch, keypoints_2d, img_size, trans, img_border_mask
    else:
        return img_patch, keypoints_2d, img_size, img_border_mask
def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
    """
    Extreme cropping: Crop the box up to the hip locations.
    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    keypoints_2d = keypoints_2d.copy()
    lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25 + 0, 25 + 1, 25 + 4, 25 + 5]
    keypoints_2d[lower_body_keypoints, :] = 0
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x = center[0]
        center_y = center[1]
        width = 1.1 * scale[0]
        height = 1.1 * scale[1]
    return center_x, center_y, width, height


def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box up to the shoulder locations.
    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    keypoints_2d = keypoints_2d.copy()
    lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + \
                           [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16]]
    keypoints_2d[lower_body_keypoints, :] = 0
    center, scale = get_bbox(keypoints_2d)
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x = center[0]
        center_y = center[1]
        width = 1.2 * scale[0]
        height = 1.2 * scale[1]
    return center_x, center_y, width, height


def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the head.
    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    keypoints_2d = keypoints_2d.copy()
    lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + \
                           [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16]]
    keypoints_2d[lower_body_keypoints, :] = 0
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x = center[0]
        center_y = center[1]
        width = 1.3 * scale[0]
        height = 1.3 * scale[1]
    return center_x, center_y, width, height


def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the torso.
    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    keypoints_2d = keypoints_2d.copy()
    nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + \
                              [25 + i for i in [0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18]]
    keypoints_2d[nontorso_body_keypoints, :] = 0
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x = center[0]
        center_y = center[1]
        width = 1.1 * scale[0]
        height = 1.1 * scale[1]
    return center_x, center_y, width, height


def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the right arm.
    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    keypoints_2d = keypoints_2d.copy()
    nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + \
                                 [25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
    keypoints_2d[nonrightarm_body_keypoints, :] = 0
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x = center[0]
        center_y = center[1]
        width = 1.1 * scale[0]
        height = 1.1 * scale[1]
    return center_x, center_y, width, height


def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the left arm.
    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    keypoints_2d = keypoints_2d.copy()
    nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + \
                                [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]]
    keypoints_2d[nonleftarm_body_keypoints, :] = 0
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x = center[0]
        center_y = center[1]
        width = 1.1 * scale[0]
        height = 1.1 * scale[1]
    return center_x, center_y, width, height


def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the legs.
    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    keypoints_2d = keypoints_2d.copy()
    nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + \
                             [25 + i for i in [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]]
    keypoints_2d[nonlegs_body_keypoints, :] = 0
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x = center[0]
        center_y = center[1]
        width = 1.1 * scale[0]
        height = 1.1 * scale[1]
    return center_x, center_y, width, height


def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the right leg.
    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    keypoints_2d = keypoints_2d.copy()
    nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + \
                                 [25 + i for i in [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
    keypoints_2d[nonrightleg_body_keypoints, :] = 0
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x = center[0]
        center_y = center[1]
        width = 1.1 * scale[0]
|
| 931 |
+
height = 1.1 * scale[1]
|
| 932 |
+
return center_x, center_y, width, height
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
|
| 936 |
+
"""
|
| 937 |
+
Extreme cropping: Crop the box and keep on only the left leg.
|
| 938 |
+
Args:
|
| 939 |
+
center_x (float): x coordinate of the bounding box center.
|
| 940 |
+
center_y (float): y coordinate of the bounding box center.
|
| 941 |
+
width (float): Bounding box width.
|
| 942 |
+
height (float): Bounding box height.
|
| 943 |
+
keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
|
| 944 |
+
Returns:
|
| 945 |
+
center_x (float): x coordinate of the new bounding box center.
|
| 946 |
+
center_y (float): y coordinate of the new bounding box center.
|
| 947 |
+
width (float): New bounding box width.
|
| 948 |
+
height (float): New bounding box height.
|
| 949 |
+
"""
|
| 950 |
+
keypoints_2d = keypoints_2d.copy()
|
| 951 |
+
nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in
|
| 952 |
+
[0, 1, 2, 6, 7, 8,
|
| 953 |
+
9, 10, 11, 12,
|
| 954 |
+
13, 14, 15, 16,
|
| 955 |
+
17, 18]]
|
| 956 |
+
keypoints_2d[nonleftleg_body_keypoints, :] = 0
|
| 957 |
+
if keypoints_2d[:, -1].sum() > 1:
|
| 958 |
+
center, scale = get_bbox(keypoints_2d)
|
| 959 |
+
center_x = center[0]
|
| 960 |
+
center_y = center[1]
|
| 961 |
+
width = 1.1 * scale[0]
|
| 962 |
+
height = 1.1 * scale[1]
|
| 963 |
+
return center_x, center_y, width, height
|
| 964 |
+
|
| 965 |
+
|
| 966 |
+
def full_body(keypoints_2d: np.array) -> bool:
|
| 967 |
+
"""
|
| 968 |
+
Check if all main body joints are visible.
|
| 969 |
+
Args:
|
| 970 |
+
keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
|
| 971 |
+
Returns:
|
| 972 |
+
bool: True if all main body joints are visible.
|
| 973 |
+
"""
|
| 974 |
+
|
| 975 |
+
body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
|
| 976 |
+
body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]]
|
| 977 |
+
return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(
|
| 978 |
+
body_keypoints)
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
def upper_body(keypoints_2d: np.array):
|
| 982 |
+
"""
|
| 983 |
+
Check if all upper body joints are visible.
|
| 984 |
+
Args:
|
| 985 |
+
keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
|
| 986 |
+
Returns:
|
| 987 |
+
bool: True if all main body joints are visible.
|
| 988 |
+
"""
|
| 989 |
+
lower_body_keypoints_openpose = [10, 11, 13, 14]
|
| 990 |
+
lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]]
|
| 991 |
+
upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18]
|
| 992 |
+
upper_body_keypoints = [25 + 8, 25 + 9, 25 + 12, 25 + 13, 25 + 17, 25 + 18]
|
| 993 |
+
return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0) \
|
| 994 |
+
and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2)
|
| 995 |
+
|
| 996 |
+
|
| 997 |
+
def get_bbox(keypoints_2d: np.array, rescale: float = 1.2) -> Tuple:
|
| 998 |
+
"""
|
| 999 |
+
Get center and scale for bounding box from openpose detections.
|
| 1000 |
+
Args:
|
| 1001 |
+
keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
|
| 1002 |
+
rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
|
| 1003 |
+
Returns:
|
| 1004 |
+
center (np.array): Array of shape (2,) containing the new bounding box center.
|
| 1005 |
+
scale (float): New bounding box scale.
|
| 1006 |
+
"""
|
| 1007 |
+
valid = keypoints_2d[:, -1] > 0
|
| 1008 |
+
valid_keypoints = keypoints_2d[valid][:, :-1]
|
| 1009 |
+
center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0))
|
| 1010 |
+
bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0))
|
| 1011 |
+
# adjust bounding box tightness
|
| 1012 |
+
scale = bbox_size
|
| 1013 |
+
scale *= rescale
|
| 1014 |
+
return center, scale
|
| 1015 |
+
|
| 1016 |
+
|
| 1017 |
+
def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
|
| 1018 |
+
"""
|
| 1019 |
+
Perform extreme cropping
|
| 1020 |
+
Args:
|
| 1021 |
+
center_x (float): x coordinate of bounding box center.
|
| 1022 |
+
center_y (float): y coordinate of bounding box center.
|
| 1023 |
+
width (float): bounding box width.
|
| 1024 |
+
height (float): bounding box height.
|
| 1025 |
+
keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
|
| 1026 |
+
rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
|
| 1027 |
+
Returns:
|
| 1028 |
+
center_x (float): x coordinate of bounding box center.
|
| 1029 |
+
center_y (float): y coordinate of bounding box center.
|
| 1030 |
+
width (float): bounding box width.
|
| 1031 |
+
height (float): bounding box height.
|
| 1032 |
+
"""
|
| 1033 |
+
p = torch.rand(1).item()
|
| 1034 |
+
if full_body(keypoints_2d):
|
| 1035 |
+
if p < 0.7:
|
| 1036 |
+
center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
|
| 1037 |
+
elif p < 0.9:
|
| 1038 |
+
center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
|
| 1039 |
+
else:
|
| 1040 |
+
center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
|
| 1041 |
+
elif upper_body(keypoints_2d):
|
| 1042 |
+
if p < 0.9:
|
| 1043 |
+
center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
|
| 1044 |
+
else:
|
| 1045 |
+
center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
|
| 1046 |
+
|
| 1047 |
+
return center_x, center_y, max(width, height), max(width, height)
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float,
|
| 1051 |
+
keypoints_2d: np.array) -> Tuple:
|
| 1052 |
+
"""
|
| 1053 |
+
Perform aggressive extreme cropping
|
| 1054 |
+
Args:
|
| 1055 |
+
center_x (float): x coordinate of bounding box center.
|
| 1056 |
+
center_y (float): y coordinate of bounding box center.
|
| 1057 |
+
width (float): bounding box width.
|
| 1058 |
+
height (float): bounding box height.
|
| 1059 |
+
keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
|
| 1060 |
+
rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
|
| 1061 |
+
Returns:
|
| 1062 |
+
center_x (float): x coordinate of bounding box center.
|
| 1063 |
+
center_y (float): y coordinate of bounding box center.
|
| 1064 |
+
width (float): bounding box width.
|
| 1065 |
+
height (float): bounding box height.
|
| 1066 |
+
"""
|
| 1067 |
+
p = torch.rand(1).item()
|
| 1068 |
+
if full_body(keypoints_2d):
|
| 1069 |
+
if p < 0.2:
|
| 1070 |
+
center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
|
| 1071 |
+
elif p < 0.3:
|
| 1072 |
+
center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
|
| 1073 |
+
elif p < 0.4:
|
| 1074 |
+
center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
|
| 1075 |
+
elif p < 0.5:
|
| 1076 |
+
center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
|
| 1077 |
+
elif p < 0.6:
|
| 1078 |
+
center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
|
| 1079 |
+
elif p < 0.7:
|
| 1080 |
+
center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
|
| 1081 |
+
elif p < 0.8:
|
| 1082 |
+
center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d)
|
| 1083 |
+
elif p < 0.9:
|
| 1084 |
+
center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d)
|
| 1085 |
+
else:
|
| 1086 |
+
center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d)
|
| 1087 |
+
elif upper_body(keypoints_2d):
|
| 1088 |
+
if p < 0.2:
|
| 1089 |
+
center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
|
| 1090 |
+
elif p < 0.4:
|
| 1091 |
+
center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
|
| 1092 |
+
elif p < 0.6:
|
| 1093 |
+
center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
|
| 1094 |
+
elif p < 0.8:
|
| 1095 |
+
center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
|
| 1096 |
+
else:
|
| 1097 |
+
center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
|
| 1098 |
+
return center_x, center_y, max(width, height), max(width, height)
|
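All of the cropping helpers above share one recipe: zero out the confidences of every keypoint that does not belong to the target part, then let get_bbox fit a fresh box around whatever stays visible. Below is a minimal sketch of exercising get_bbox and extreme_cropping on dummy detections; the 44-row keypoint layout (25 OpenPose-style joints followed by 19 extra joints, confidence in the last column) is read off the indices used above, while the import path and the coordinate values are illustrative assumptions, not part of the commit.

import numpy as np
from amr.datasets.utils import get_bbox, extreme_cropping  # assumed import path

# Fake detections: only the 19 extra joints are marked visible.
keypoints_2d = np.zeros((44, 3), dtype=np.float32)
keypoints_2d[25:, :2] = np.random.uniform(100, 400, size=(19, 2))
keypoints_2d[25:, 2] = 1.0

# Fit a box around the visible joints, then randomly crop down to a body part.
center, scale = get_bbox(keypoints_2d)
cx, cy, w, h = extreme_cropping(center[0], center[1], scale[0], scale[1], keypoints_2d)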
amr/datasets/vitdet_dataset.py
ADDED
@@ -0,0 +1,102 @@
+from typing import Dict, List
+
+import cv2
+import numpy as np
+from skimage.filters import gaussian
+from yacs.config import CfgNode
+import torch
+
+from .utils import (convert_cvimg_to_tensor,
+                    expand_to_aspect_ratio,
+                    generate_image_patch_cv2)
+
+DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
+DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
+
+
+class ViTDetDataset(torch.utils.data.Dataset):
+
+    def __init__(self,
+                 cfg: CfgNode,
+                 img_cv2: np.array,
+                 boxes: np.array,
+                 category: List[int],
+                 rescale_factor=1,
+                 train: bool = False,
+                 **kwargs):
+        super().__init__()
+        self.cfg = cfg
+        self.img_cv2 = img_cv2
+        self.boxes = boxes
+        self.category = category
+        self.focal_length = []
+        for c in category:
+            if c == 6:
+                self.focal_length.append([cfg.AVES.FOCAL_LENGTH, cfg.AVES.FOCAL_LENGTH])
+            else:
+                self.focal_length.append([cfg.SMAL.FOCAL_LENGTH, cfg.SMAL.FOCAL_LENGTH])
+
+        assert train is False, "ViTDetDataset is only for inference"
+        self.train = train
+        self.img_size = cfg.MODEL.IMAGE_SIZE
+        self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
+        self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)
+
+        # Preprocess annotations
+        boxes = boxes.astype(np.float32)
+        self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0
+        self.scale = rescale_factor * (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0
+        self.animalid = np.arange(len(boxes), dtype=np.int32)
+
+    def __len__(self) -> int:
+        return len(self.animalid)
+
+    def __getitem__(self, idx: int) -> Dict[str, np.array]:
+
+        center = self.center[idx].copy()
+        center_x = center[0]
+        center_y = center[1]
+
+        scale = self.scale[idx]
+        BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
+        bbox_size = expand_to_aspect_ratio(scale * 200, target_aspect_ratio=BBOX_SHAPE).max()
+
+        patch_width = patch_height = self.img_size
+
+        flip = False
+
+        # 3. generate image patch
+        # if use_skimage_antialias:
+        cvimg = self.img_cv2.copy()
+        if True:
+            # Blur image to avoid aliasing artifacts
+            downsampling_factor = ((bbox_size * 1.0) / patch_width)
+            print(f'{downsampling_factor=}')
+            downsampling_factor = downsampling_factor / 2.0
+            if downsampling_factor > 1.1:
+                cvimg = gaussian(cvimg, sigma=(downsampling_factor - 1) / 2, channel_axis=2, preserve_range=True)
+
+        img_patch_cv, trans, _ = generate_image_patch_cv2(cvimg,
+                                                          center_x, center_y,
+                                                          bbox_size, bbox_size,
+                                                          patch_width, patch_height,
+                                                          flip, 1.0, 0.0,
+                                                          border_mode=cv2.BORDER_CONSTANT)
+        img_patch_cv = img_patch_cv[:, :, ::-1]
+        img_patch = convert_cvimg_to_tensor(img_patch_cv)
+
+        # apply normalization
+        for n_c in range(min(self.img_cv2.shape[2], 3)):
+            img_patch[n_c, :, :] = (img_patch[n_c, :, :] - self.mean[n_c]) / self.std[n_c]
+
+        item = {
+            'img': img_patch,
+            'animalid': int(self.animalid[idx]),
+            'box_center': self.center[idx].copy(),
+            'box_size': bbox_size,
+            'img_size': 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]]),
+            'focal_length': np.array(self.focal_length, dtype=np.float32)[idx],
+            'supercategory': self.category[idx],
+            'has_mask': np.array(0., dtype=np.float32)
+        }
+        return item
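ViTDetDataset turns one image plus its detector boxes into normalized square crops of side cfg.MODEL.IMAGE_SIZE, one per detection, so it can be fed straight to a DataLoader. A small usage sketch under stated assumptions: model_cfg is the yacs config returned by load_amr (see amr/models/__init__.py below), the image path and box coordinates are placeholders, and category id 6 selects the AVES (bird) focal length as in the constructor above.

import cv2
import numpy as np
import torch
from amr.datasets.vitdet_dataset import ViTDetDataset

img_cv2 = cv2.imread('example.jpg')              # BGR image of any resolution (placeholder path)
boxes = np.array([[50., 30., 420., 380.]])       # one (x1, y1, x2, y2) detection
category = [6]                                   # 6 -> bird / AVES, other ids -> SMAL quadrupeds

dataset = ViTDetDataset(model_cfg, img_cv2, boxes, category, rescale_factor=1.2)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)
batch = next(iter(loader))                       # dict with 'img', 'box_center', 'box_size', 'focal_length', ...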
amr/models/__init__.py
ADDED
@@ -0,0 +1,26 @@
+from .smal_warapper import SMAL
+from .aves_warapper import AVES
+from .animerpp import AniMerPlusPlus
+
+
+def load_amr(checkpoint_path):
+    from pathlib import Path
+    from ..configs import get_config
+    model_cfg = str(Path(checkpoint_path).parent / 'config.yaml')
+    model_cfg = get_config(model_cfg, update_cachedir=True)
+
+    # Override some config values, to crop bbox correctly
+    if (model_cfg.MODEL.BACKBONE.TYPE in ['vith', 'vithmoe']) and ('BBOX_SHAPE' not in model_cfg.MODEL):
+        model_cfg.defrost()
+        assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
+        model_cfg.MODEL.BBOX_SHAPE = [192, 256]
+        model_cfg.freeze()
+
+    # Update config to be compatible with demo
+    if ('PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE):
+        model_cfg.defrost()
+        model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS')
+        model_cfg.freeze()
+
+    model = AniMerPlusPlus.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg, map_location="cpu")
+    return model, model_cfg
amr/models/__pycache__/__init__.cpython-310.pyc ADDED (binary file, 999 Bytes)
amr/models/__pycache__/__init__.cpython-312.pyc ADDED (binary file, 2.87 kB)
amr/models/__pycache__/__init__.cpython-39.pyc ADDED (binary file, 2.01 kB)
amr/models/__pycache__/amr.cpython-310.pyc ADDED (binary file, 14.3 kB)
amr/models/__pycache__/amr.cpython-312.pyc ADDED (binary file, 26.7 kB)
amr/models/__pycache__/amr.cpython-39.pyc ADDED (binary file, 13.8 kB)
amr/models/__pycache__/animerpp.cpython-310.pyc ADDED (binary file, 16.8 kB)
amr/models/__pycache__/animerpp.cpython-312.pyc ADDED (binary file, 39.2 kB)
amr/models/__pycache__/aves_hmr.cpython-310.pyc ADDED (binary file, 15.5 kB)
amr/models/__pycache__/aves_hmr.cpython-312.pyc ADDED (binary file, 31 kB)
amr/models/__pycache__/aves_warapper.cpython-310.pyc ADDED (binary file, 4.76 kB)
amr/models/__pycache__/aves_warapper.cpython-312.pyc ADDED (binary file, 8.42 kB)
amr/models/__pycache__/discriminator.cpython-310.pyc ADDED (binary file, 2.95 kB)
amr/models/__pycache__/discriminator.cpython-312.pyc ADDED (binary file, 7.94 kB)
amr/models/__pycache__/discriminator.cpython-39.pyc ADDED (binary file, 2.6 kB)
amr/models/__pycache__/dyamr.cpython-310.pyc ADDED (binary file, 18.2 kB)
amr/models/__pycache__/dyamr.cpython-312.pyc ADDED (binary file, 34.3 kB)
amr/models/__pycache__/losses.cpython-310.pyc ADDED (binary file, 12.9 kB)
amr/models/__pycache__/losses.cpython-312.pyc ADDED (binary file, 24.4 kB)
amr/models/__pycache__/losses.cpython-39.pyc ADDED (binary file, 13 kB)
amr/models/__pycache__/predictor.cpython-310.pyc ADDED (binary file, 7.3 kB)
amr/models/__pycache__/smal_warapper.cpython-310.pyc ADDED (binary file, 5.43 kB)
amr/models/__pycache__/smal_warapper.cpython-312.pyc ADDED (binary file, 8.66 kB)
amr/models/__pycache__/smal_warapper.cpython-39.pyc ADDED (binary file, 5.53 kB)
amr/models/__pycache__/smooth_amr.cpython-310.pyc ADDED (binary file, 7.19 kB)
amr/models/__pycache__/smooth_amr.cpython-312.pyc ADDED (binary file, 14 kB)
amr/models/__pycache__/smooth_netv2.cpython-310.pyc ADDED (binary file, 9.97 kB)
amr/models/__pycache__/stamr.cpython-310.pyc ADDED (binary file, 6.27 kB)
amr/models/__pycache__/stamr.cpython-312.pyc ADDED (binary file, 11.6 kB)
amr/models/animerpp.py
ADDED
@@ -0,0 +1,508 @@
+import torch
+import pickle
+import pytorch_lightning as pl
+from torchvision.utils import make_grid
+from typing import Dict
+from pytorch3d.transforms import matrix_to_axis_angle
+from yacs.config import CfgNode
+from ..utils import MeshRenderer
+from ..utils.geometry import aa_to_rotmat, perspective_projection
+from ..utils.pylogger import get_pylogger
+from ..utils.mesh_renderer import SilhouetteRenderer
+from .backbones import create_backbone
+from .heads.classifier_head import ClassTokenHead
+from .heads import build_aves_head, build_smal_head
+from .losses import (Keypoint3DLoss, Keypoint2DLoss, ParameterLoss, SupConLoss,
+                     PoseBonePriorLoss, SilhouetteLoss, ShapePriorLoss, PosePriorLoss)
+from .aves_warapper import AVES
+from .smal_warapper import SMAL
+
+
+log = get_pylogger(__name__)
+
+
+class AniMerPlusPlus(pl.LightningModule):
+    def __init__(self, cfg: CfgNode, init_renderer: bool = True):
+        """
+        Setup AniMer++ model
+        Args:
+            cfg (CfgNode): Config file as a yacs CfgNode
+        """
+        super().__init__()
+
+        # Save hyperparameters
+        self.save_hyperparameters(logger=False, ignore=['init_renderer'])
+
+        self.cfg = cfg
+        # Create backbone feature extractor
+        self.backbone = create_backbone(cfg)
+
+        # Create AVES head
+        self.aves_head = build_aves_head(cfg)
+
+        # Create SMAL head
+        self.smal_head = build_smal_head(cfg)
+
+        self.class_token_head = ClassTokenHead(**cfg.MODEL.get("CLASS_TOKEN_HEAD", dict()))
+
+        # Define loss functions
+        # common loss
+        self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
+        self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
+        self.supcon_loss = SupConLoss()
+        self.parameter_loss = ParameterLoss()
+        # aves loss
+        self.posebone_prior_loss = PoseBonePriorLoss(path_prior=cfg.AVES.POSE_PRIOR_PATH)
+        self.mask_loss = SilhouetteLoss()
+        # smal loss
+        self.shape_prior_loss = ShapePriorLoss(path_prior=cfg.SMAL.SHAPE_PRIOR_PATH)
+        self.pose_prior_loss = PosePriorLoss(path_prior=cfg.SMAL.POSE_PRIOR_PATH)
+        # Instantiate AVES model
+        aves_model_path = cfg.AVES.MODEL_PATH
+        aves_cfg = torch.load(aves_model_path, weights_only=True)
+        self.aves = AVES(**aves_cfg)
+
+        # Instantiate SMAL model
+        smal_model_path = cfg.SMAL.MODEL_PATH
+        with open(smal_model_path, 'rb') as f:
+            smal_cfg = pickle.load(f, encoding="latin1")
+        self.smal = SMAL(**smal_cfg)
+
+        # Buffer that shows whether we need to initialize ActNorm layers
+        self.register_buffer('initialized', torch.tensor(False))
+        # Setup renderer for visualization
+        if init_renderer:
+            self.aves_mesh_renderer = MeshRenderer(self.cfg, faces=aves_cfg['F'].numpy())
+            self.smal_mesh_renderer = MeshRenderer(self.cfg, faces=self.smal.faces.numpy())
+        else:
+            self.renderer = None
+            self.mesh_renderer = None
+
+        # Only applied for AVES training
+        self.aves_silouette_render = SilhouetteRenderer(size=self.cfg.MODEL.IMAGE_SIZE,
+                                                        focal=self.cfg.AVES.get("FOCAL_LENGTH", 2167),
+                                                        device='cuda')
+
+        self.automatic_optimization = False
+
+    def get_parameters(self):
+        all_params = list(self.aves_head.parameters())
+        all_params += list(self.backbone.parameters())
+        all_params += list(self.smal_head.parameters())
+        all_params += list(self.class_token_head.parameters())
+        return all_params
+
+    def configure_optimizers(self):
+        """
+        Setup model and discriminator optimizers
+        Returns:
+            Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers
+        """
+        param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]
+
+        if "vit" in self.cfg.MODEL.BACKBONE.TYPE:
+            optimizer = torch.optim.AdamW(params=param_groups,
+                                          weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
+        else:
+            optimizer = torch.optim.Adam(params=param_groups,
+                                         weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
+        return optimizer
+
+    def forward_backbone(self, batch: Dict):
+        x = batch['img']
+        dataset_source = batch["supercategory"] < 5  # bird for index 0
+        # Compute conditioning features using the backbone
+        if self.cfg.MODEL.BACKBONE.TYPE in ["vith"]:
+            conditioning_feats, cls = self.backbone(x[:, :, :, 32:-32])  # [256, 192]
+        elif self.cfg.MODEL.BACKBONE.TYPE in ["vithmoe"]:
+            conditioning_feats, cls = self.backbone(x[:, :, :, 32:-32], dataset_source=dataset_source.type(torch.long))
+        else:
+            conditioning_feats = self.backbone(x)
+            cls = None
+        return conditioning_feats, cls
+
+    def forward_one_parametric_model(self,
+                                     focal_length: torch.tensor,
+                                     features: torch.tensor,
+                                     head: torch.nn.Module,
+                                     parametric_model: torch.nn.Module,):
+        """
+        Run a forward step of one parametric model.
+        Args:
+            focal_length (torch.Tensor): Per-sample focal lengths.
+            features (torch.Tensor): Conditioning features from the backbone.
+            head (torch.nn.Module): Regression head for this parametric model.
+            parametric_model (torch.nn.Module): AVES or SMAL wrapper.
+        Returns:
+            Dict: Dictionary containing the regression output
+        """
+        batch_size = features.shape[0]
+        pred_params, pred_cam, _ = head(features)
+        # Store useful regression outputs to the output dict
+        output = {}
+        output['pred_cam'] = pred_cam
+        output['pred_params'] = {k: v.clone() for k, v in pred_params.items()}
+
+        # Compute camera translation
+        pred_cam_t = torch.stack([pred_cam[:, 1],
+                                  pred_cam[:, 2],
+                                  2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9)], dim=-1)
+        output['pred_cam_t'] = pred_cam_t
+        output['focal_length'] = focal_length
+
+        # Compute model vertices, joints and the projected joints
+        pred_params['global_orient'] = pred_params['global_orient'].reshape(batch_size, -1, 3, 3)
+        pred_params['pose'] = pred_params['pose'].reshape(batch_size, -1, 3, 3)
+        pred_params['betas'] = pred_params['betas'].reshape(batch_size, -1)
+        pred_params['bone'] = pred_params['bone'].reshape(batch_size, -1) if 'bone' in pred_params else None
+        parametric_model_output = parametric_model(**pred_params, pose2rot=False)
+
+        pred_keypoints_3d = parametric_model_output.joints
+        pred_vertices = parametric_model_output.vertices
+        output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
+        output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
+        pred_cam_t = pred_cam_t.reshape(-1, 3)
+        focal_length = focal_length.reshape(-1, 2)
+        pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
+                                                   translation=pred_cam_t,
+                                                   focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)
+        output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
+        return output
+
+    def forward_step(self, batch: Dict, train: bool = False) -> Dict:
+        """
+        Run a forward step of the network
+        Args:
+            batch (Dict): Dictionary containing batch data
+            train (bool): Flag indicating whether it is training or validation mode
+        Returns:
+            Dict: Dictionary containing the regression output
+        """
+        # Use RGB image as input
+        x = batch['img']
+        batch_size = x.shape[0]
+        device = x.device
+        dataset_source = (batch["supercategory"] < 5)  # bird for index 0
+
+        features, cls = self.forward_backbone(batch)
+
+        output = dict()
+        output['cls_feats'] = self.class_token_head(cls) if self.cfg.MODEL.BACKBONE.get("USE_CLS", False) else None
+
+        num_aves = (batch_size - dataset_source.sum()).item()
+        if num_aves:
+            output['aves_output'] = self.forward_one_parametric_model(batch['focal_length'][~dataset_source],
+                                                                      features[~dataset_source],
+                                                                      self.aves_head,
+                                                                      self.aves)
+            # Only specific to AVES training
+            output['aves_output']['pred_mask'] = self.aves_silouette_render(output['aves_output']['pred_vertices'] + output['aves_output']['pred_cam_t'].unsqueeze(1),
+                                                                            faces=self.aves.face.unsqueeze(0).repeat(batch_size - dataset_source.sum().item(), 1, 1).to(device))
+
+        num_smal = dataset_source.sum().item()
+        if num_smal:
+            output['smal_output'] = self.forward_one_parametric_model(batch['focal_length'][dataset_source],
+                                                                      features[dataset_source],
+                                                                      self.smal_head,
+                                                                      self.smal)
+        return output
+
+    def compute_aves_loss(self, batch: Dict, output: Dict) -> torch.Tensor:
+        """
+        Compute AVES losses given the input batch and the regression output
+        Args:
+            batch (Dict): Dictionary containing batch data
+            output (Dict): Dictionary containing the regression output
+        Returns:
+            torch.Tensor : Total loss for current batch
+        """
+        dataset_source = (batch["supercategory"] > 5)
+
+        pred_params = output['pred_params']
+        pred_mask = output['pred_mask']
+        pred_keypoints_2d = output['pred_keypoints_2d']
+        pred_keypoints_3d = output['pred_keypoints_3d']
+
+        batch_size = pred_params['pose'].shape[0]
+
+        # Get annotations
+        gt_keypoints_2d = batch['keypoints_2d'][dataset_source][:, :18]
+        gt_keypoints_3d = batch['keypoints_3d'][dataset_source][:, :18]
+        gt_mask = batch['mask'][dataset_source]
+        gt_params = {k: v[dataset_source] for k, v in batch['smal_params'].items()}
+        has_params = {k: v[dataset_source] for k, v in batch['has_smal_params'].items()}
+        is_axis_angle = {k: v[dataset_source] for k, v in batch['smal_params_is_axis_angle'].items()}
+
+        # Compute 2D/3D keypoint losses and silhouette loss
+        loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
+        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
+        loss_mask = self.mask_loss(pred_mask, gt_mask)
+
+        # Compute loss on AVES parameters
+        loss_params = {}
+        for k, pred in pred_params.items():
+            gt = gt_params[k].view(batch_size, -1)
+            if is_axis_angle[k].all():
+                gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
+            has_gt = has_params[k]
+            if k == "betas":
+                loss_params[k] = self.parameter_loss(pred.reshape(batch_size, -1),
+                                                     gt[:, :15].reshape(batch_size, -1),
+                                                     has_gt)
+                # v1
+                loss_params[k + "_re"] = torch.sum(pred[has_gt.bool()] ** 2) + torch.sum(pred[has_gt.bool()] ** 2) * 0.5
+                # v2
+                # loss_params[k + "_re"] = torch.sum(pred ** 2)
+            elif k == "bone":
+                loss_params[k] = self.parameter_loss(pred.reshape(batch_size, -1),
+                                                     gt.reshape(batch_size, -1),
+                                                     has_gt)
+                # v1
+                loss_params[k + "_re"] = self.posebone_prior_loss.l2_loss(pred, self.posebone_prior_loss.bone_mean, 1 - has_gt) + \
+                                         self.posebone_prior_loss.l2_loss(pred, self.posebone_prior_loss.bone_mean, has_gt) * 0.02
+                # v2
+                # loss_params[k + "_re"] = self.posebone_prior_loss.l2_loss(pred, self.posebone_prior_loss.bone_mean, torch.zeros_like(has_gt))
+            elif k == "pose":
+                loss_params[k] = self.parameter_loss(pred.reshape(batch_size, -1),
+                                                     gt[:, :24].reshape(batch_size, -1),
+                                                     has_gt)
+                pose_axis_angle = matrix_to_axis_angle(pred)
+                # v1
+                loss_params[k + "_re"] = self.posebone_prior_loss.l2_loss(pose_axis_angle.reshape(batch_size, -1), self.posebone_prior_loss.pose_mean, 1 - has_gt) + \
+                                         self.posebone_prior_loss.l2_loss(pose_axis_angle.reshape(batch_size, -1), self.posebone_prior_loss.pose_mean, has_gt) * 0.02
+                # v2
+                # loss_params[k + "_re"] = self.posebone_prior_loss.l2_loss(pose_axis_angle.reshape(batch_size, -1), self.posebone_prior_loss.pose_mean, torch.zeros_like(has_gt))
+            else:
+                loss_params[k] = self.parameter_loss(pred.reshape(batch_size, -1),
+                                                     gt.reshape(batch_size, -1),
+                                                     has_gt)
+
+        loss_config = self.cfg.LOSS_WEIGHTS.AVES
+        loss = loss_config['KEYPOINTS_3D'] * loss_keypoints_3d + \
+               loss_config['KEYPOINTS_2D'] * loss_keypoints_2d + \
+               sum([loss_params[k] * loss_config[k.upper()] for k in loss_params]) + \
+               loss_config['MASK'] * loss_mask
+
+        losses = dict(loss_aves=loss.detach(),
+                      loss_aves_keypoints_2d=loss_keypoints_2d.detach(),
+                      loss_aves_keypoints_3d=loss_keypoints_3d.detach(),
+                      loss_aves_mask=loss_mask.detach(),
+                      )
+        for k, v in loss_params.items():
+            losses['loss_aves_' + k] = v.detach()
+
+        return loss, losses
+
+    def compute_smal_loss(self, batch: Dict, output: Dict) -> torch.Tensor:
+        """
+        Compute SMAL losses given the input batch and the regression output
+        Args:
+            batch (Dict): Dictionary containing batch data
+            output (Dict): Dictionary containing the regression output
+        Returns:
+            torch.Tensor : Total loss for current batch
+        """
+        dataset_source = (batch["supercategory"] < 5)
+
+        pred_params = output['pred_params']
+        pred_keypoints_2d = output['pred_keypoints_2d']
+        pred_keypoints_3d = output['pred_keypoints_3d']
+
+        batch_size = pred_params['pose'].shape[0]
+
+        # Get annotations
+        gt_keypoints_2d = batch['keypoints_2d'][dataset_source]
+        gt_keypoints_3d = batch['keypoints_3d'][dataset_source]
+        gt_params = {k: v[dataset_source] for k, v in batch['smal_params'].items()}
+        has_params = {k: v[dataset_source] for k, v in batch['has_smal_params'].items()}
+        is_axis_angle = {k: v[dataset_source] for k, v in batch['smal_params_is_axis_angle'].items()}
+
+        # Compute 2D/3D keypoint losses
+        loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
+        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
+
+        # Compute loss on SMAL parameters
+        loss_smal_params = {}
+        for k, pred in pred_params.items():
+            gt = gt_params[k].view(batch_size, -1)
+            if is_axis_angle[k].all():
+                gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
+            has_gt = has_params[k]
+            if k == "betas":
+                loss_smal_params[k] = self.parameter_loss(pred.reshape(batch_size, -1),
+                                                          gt.reshape(batch_size, -1),
+                                                          has_gt) + \
+                                      self.shape_prior_loss(pred, batch["category"][dataset_source], has_gt)
+            elif k == "bone":
+                continue
+            else:
+                loss_smal_params[k] = self.parameter_loss(pred.reshape(batch_size, -1),
+                                                          gt.reshape(batch_size, -1),
+                                                          has_gt) + \
+                                      self.pose_prior_loss(torch.cat((pred_params["global_orient"],
+                                                                      pred_params["pose"]),
+                                                                     dim=1), has_gt) / 2.
+
+        loss_config = self.cfg.LOSS_WEIGHTS.SMAL
+        loss = loss_config['KEYPOINTS_3D'] * loss_keypoints_3d + \
+               loss_config['KEYPOINTS_2D'] * loss_keypoints_2d + \
+               sum([loss_smal_params[k] * loss_config[k.upper()] for k in loss_smal_params])
+
+        losses = dict(loss_smal=loss.detach(),
+                      loss_smal_keypoints_2d=loss_keypoints_2d.detach(),
+                      loss_smal_keypoints_3d=loss_keypoints_3d.detach(),
+                      )
+        for k, v in loss_smal_params.items():
+            losses['loss_smal_' + k] = v.detach()
+
+        return loss, losses
+
+    def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
+        """
+        Compute losses given the input batch and the regression output
+        Args:
+            batch (Dict): Dictionary containing batch data
+            output (Dict): Dictionary containing the regression output
+            train (bool): Flag indicating whether it is training or validation mode
+        Returns:
+            torch.Tensor : Total loss for current batch
+        """
+        x = batch['img']
+        device, dtype = x.device, x.dtype
+        if 'aves_output' in output:
+            loss_aves, losses_aves = self.compute_aves_loss(batch, output['aves_output'])
+        else:
+            loss_aves, losses_aves = torch.tensor(0.0, device=device, dtype=dtype), {}
+        if 'smal_output' in output:
+            loss_smal, losses_smal = self.compute_smal_loss(batch, output['smal_output'])
+        else:
+            loss_smal, losses_smal = torch.tensor(0.0, device=device, dtype=dtype), {}
+        loss_supcon = self.supcon_loss(output['cls_feats'], labels=batch['category']) if self.cfg.MODEL.BACKBONE.get("USE_CLS", False) \
+            else torch.tensor(0.0, device=device, dtype=dtype)
+        loss = loss_aves + loss_smal + loss_supcon * self.cfg.LOSS_WEIGHTS['SUPCON']
+
+        # Saving loss
+        losses = {}
+        losses['loss'] = loss.detach()
+        losses['loss_supcon'] = loss_supcon.detach()
+        for k, v in losses_aves.items():
+            losses[k] = v.detach()
+        for k, v in losses_smal.items():
+            losses[k] = v.detach()
+        output['losses'] = losses
+        return loss
+
+    # Tensorboard logging should run from first rank only
+    @pl.utilities.rank_zero.rank_zero_only
+    def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True,
+                            write_to_summary_writer: bool = True) -> None:
+        """
+        Log results to Tensorboard
+        Args:
+            batch (Dict): Dictionary containing batch data
+            output (Dict): Dictionary containing the regression output
+            step_count (int): Global training step count
+            train (bool): Flag indicating whether it is training or validation mode
+        """
+
+        mode = 'train' if train else 'val'
+        batch_size = batch['keypoints_2d'].shape[0]
+        images = batch['img']
+        masks = batch['mask']
+        # de-normalize: mul std then add mean
+        images = (images) * (torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1))
+        images = (images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1))
+        masks = masks.unsqueeze(1).repeat(1, 3, 1, 1)
+
+        gt_keypoints_2d = batch['keypoints_2d']
+        losses = output['losses']
+        if write_to_summary_writer:
+            summary_writer = self.logger.experiment
+            for loss_name, val in losses.items():
+                summary_writer.add_scalar(mode + '/' + loss_name, val.detach().item(), step_count)
+            if train is False:
+                for metric_name, val in output['metric'].items():
+                    summary_writer.add_scalar(mode + '/' + metric_name, val, step_count)
+
+        rend_imgs = []
+        num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)
+        dataset_source = (batch["supercategory"] < 5)[:num_images]  # bird for index 0
+
+        num_aves = (num_images - dataset_source[:num_images].sum()).item()
+        if num_aves:
+            rend_imgs_aves = self.aves_mesh_renderer.visualize_tensorboard(output['aves_output']['pred_vertices'][:num_aves].detach().cpu().numpy(),
+                                                                           output['aves_output']['pred_cam_t'][:num_aves].detach().cpu().numpy(),
+                                                                           images[:num_images][~dataset_source].cpu().numpy(),
+                                                                           self.cfg.AVES.get("FOCAL_LENGTH", 2167),
+                                                                           output['aves_output']['pred_keypoints_2d'][:num_aves].detach().cpu().numpy(),
+                                                                           gt_keypoints_2d[:num_images][~dataset_source][:, :18].cpu().numpy(),
+                                                                           )
+            rend_imgs.extend(rend_imgs_aves)
+
+        num_smal = dataset_source[:num_images].sum().item()
+        if num_smal:
+            rend_imgs_smal = self.smal_mesh_renderer.visualize_tensorboard(output['smal_output']['pred_vertices'][:num_smal].detach().cpu().numpy(),
+                                                                           output['smal_output']['pred_cam_t'][:num_smal].detach().cpu().numpy(),
+                                                                           images[:num_images][dataset_source].cpu().numpy(),
+                                                                           self.cfg.SMAL.get("FOCAL_LENGTH", 1000),
+                                                                           output['smal_output']['pred_keypoints_2d'][:num_smal].detach().cpu().numpy(),
+                                                                           gt_keypoints_2d[:num_images][dataset_source].cpu().numpy(),
+                                                                           )
+            rend_imgs.extend(rend_imgs_smal)
+
+        rend_imgs = make_grid(rend_imgs, nrow=5, padding=2)
+        if write_to_summary_writer:
+            summary_writer.add_image('%s/predictions' % mode, rend_imgs, step_count)
+
+        return rend_imgs
+
+    def forward(self, batch: Dict) -> Dict:
+        """
+        Run a forward step of the network in val mode
+        Args:
+            batch (Dict): Dictionary containing batch data
+        Returns:
+            Dict: Dictionary containing the regression output
+        """
+        return self.forward_step(batch, train=False)
+
+    def training_step(self, batch: Dict) -> Dict:
+        """
+        Run a full training step
+        Args:
+            batch (Dict): Dictionary containing {'img', 'mask', 'keypoints_2d', 'keypoints_3d', 'orig_keypoints_2d',
+                          'aves_params', 'aves_params_is_axis_angle', 'focal_length'}
+        Returns:
+            Dict: Dictionary containing regression output.
+        """
+        batch = batch['img']
+        optimizer = self.optimizers(use_pl_optimizer=True)
+
+        batch_size = batch['img'].shape[0]
+        output = self.forward_step(batch, train=True)
+        if self.cfg.get('UPDATE_GT_SPIN', False):
+            self.update_batch_gt_spin(batch, output)
+        loss = self.compute_loss(batch, output, train=True)
+
+        # Error if NaN
+        if torch.isnan(loss):
+            raise ValueError('Loss is NaN')
+
+        optimizer.zero_grad()
+        self.manual_backward(loss)
+        # Clip gradient
+        if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
+            gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL,
+                                                error_if_nonfinite=True)
+            self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+
+        optimizer.step()
+        if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
+            self.tensorboard_logging(batch, output, self.global_step, train=True)
+
+        self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False,
+                 batch_size=batch_size, sync_dist=True)
+
+        return output
+
+    def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
+        pass
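One detail worth calling out in forward_one_parametric_model: the head predicts a weak-perspective camera (s, tx, ty), and the full translation is recovered as t_z = 2 * f / (IMAGE_SIZE * s). Below is a self-contained restatement of that conversion with made-up focal length and camera values, so the scaling convention can be checked in isolation.

import torch

IMAGE_SIZE = 256
focal_length = torch.tensor([[1000.0, 1000.0]])   # (fx, fy), illustrative value
pred_cam = torch.tensor([[0.8, 0.05, -0.02]])     # (scale, tx, ty), illustrative value

# Same formula as in forward_one_parametric_model: depth from the weak-perspective scale.
pred_cam_t = torch.stack([pred_cam[:, 1],
                          pred_cam[:, 2],
                          2 * focal_length[:, 0] / (IMAGE_SIZE * pred_cam[:, 0] + 1e-9)], dim=-1)
print(pred_cam_t)  # tensor([[ 0.0500, -0.0200,  9.7656]])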
amr/models/aves_warapper.py
ADDED
@@ -0,0 +1,136 @@
+import torch
+import torch.nn.functional as F
+from dataclasses import dataclass
+from smplx.utils import ModelOutput
+from typing import Optional, NewType
+
+
+Tensor = NewType('Tensor', torch.Tensor)
+
+
+@dataclass
+class AVESOutput(ModelOutput):
+    betas: Optional[Tensor] = None
+    pose: Optional[Tensor] = None
+    bone: Optional[Tensor] = None
+
+
+class LBS(torch.nn.Module):
+    '''
+    Implementation of linear blend skinning, with additional bone and scale
+    Input:
+        V (BN, V, 3): vertices to pose and shape
+        pose (BN, J, 3, 3) or (BN, J, 3): pose in rot or axis-angle
+        bone (BN, K): allow for direct change of relative joint distances
+        scale (1): scale the whole kinematic tree
+    '''
+    def __init__(self, J, parents, weights):
+        super(LBS, self).__init__()
+        self.n_joints = J.shape[1]
+        self.register_buffer('h_joints', F.pad(J.unsqueeze(-1), [0, 0, 0, 1], value=0))
+        self.register_buffer('kin_tree', torch.cat([J[:, [0], :], J[:, 1:] - J[:, parents[1:]]], dim=1).unsqueeze(-1))
+
+        self.register_buffer('parents', parents)
+        self.register_buffer('weights', weights[None].float())
+
+    def __call__(self, V, pose, bone, scale, to_rotmats=False):
+        batch_size = len(V)
+        device = pose.device
+        V = F.pad(V.unsqueeze(-1), [0, 0, 0, 1], value=1)
+        kin_tree = (scale * self.kin_tree) * bone[:, :, None, None]
+        pose = pose.view([batch_size, -1, 3, 3])
+        T = torch.zeros([batch_size, self.n_joints, 4, 4]).float().to(device)
+        T[:, :, -1, -1] = 1
+        T[:, :, :3, :] = torch.cat([pose, kin_tree], dim=-1)
+        T_rel = [T[:, 0]]
+        for i in range(1, self.n_joints):
+            T_rel.append(T_rel[self.parents[i]] @ T[:, i])
+        T_rel = torch.stack(T_rel, dim=1)
+        T_rel[:, :, :, [-1]] -= T_rel.clone() @ (self.h_joints * scale)
+        T_ = self.weights @ T_rel.view(batch_size, self.n_joints, -1)
+        T_ = T_.view(batch_size, -1, 4, 4)
+        V = T_ @ V
+        return V[:, :, :3, 0]
+
+
+class AVES(torch.nn.Module):
+    def __init__(self, **kwargs):
+        super(AVES, self).__init__()
+        # kinematic tree, and map to keypoints from vertices
+        self.register_buffer('kintree_table', kwargs['kintree_table'])
+        self.register_buffer('parents', kwargs['kintree_table'][0])
+        self.register_buffer('weights', kwargs['weights'])
+        self.register_buffer('vert2kpt', kwargs['vert2kpt'])
+        self.register_buffer('face', kwargs['F'])
+
+        # mean shape and default joints
+        rot = torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], dtype=torch.float32)
+        rot = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=torch.float32) @ rot
+        # rot = torch.eye(3, dtype=torch.float32)
+        V = (rot @ kwargs['V'].T).T.unsqueeze(0)
+        J = (rot @ kwargs['J'].T).T.unsqueeze(0)
+        self.register_buffer('V', V)
+        self.register_buffer('J', J)
+        self.LBS = LBS(self.J, self.parents, self.weights)
+
+        # pose and bone prior
+        self.register_buffer('p_m', kwargs['pose_mean'])
+        self.register_buffer('b_m', kwargs['bone_mean'])
+        self.register_buffer('p_cov', kwargs['pose_cov'])
+        self.register_buffer('b_cov', kwargs['bone_cov'])
+
+        # standardized blend shape basis
+        B = kwargs['Beta']
+        sigma = kwargs['Beta_sigma']
+        B = B * sigma[:, None, None]
+        self.register_buffer('B', B)
+
+        # PCA coefficient that is optimized to match the original template shape
+        ### so in the __call__ function, if beta is set to self.beta_original,
+        ### it will return the template shape from ECCV2020 (marcbadger/avian-mesh).
+        self.register_buffer('beta_original', kwargs['beta_original'])
+
+    def __call__(self, global_orient, pose, bone, transl=None,
+                 scale=1, betas=None, pose2rot=False, **kwargs):
+        '''
+        Input:
+            global_pose [bn, 3] tensor for batched global_pose on root joint
+            body_pose [bn, 72] tensor for batched body pose
+            bone_length [bn, 24] tensor for bone length; the bone variable
+                captures non-rigid joint articulation in this model
+
+            beta [bn, 15] shape PCA coefficients
+                If beta is None, it will return the mean shape
+                If beta is self.beta_original, it will return the original template shape
+
+        '''
+        device = global_orient.device
+        batch_size = global_orient.shape[0]
+        V = self.V.repeat([batch_size, 1, 1]) * scale
+        J = self.J.repeat([batch_size, 1, 1]) * scale
+
+        # multi-bird shape space
+        if betas is not None:
+            V = V + torch.einsum('bk, kmn->bmn', betas, self.B)
+
+        # concatenate bone and pose
+        bone = torch.cat([torch.ones([batch_size, 1]).to(device), bone], dim=1)
+        pose = torch.cat([global_orient, pose], dim=1)
+
+        # LBS
+        verts = self.LBS(V, pose, bone, scale, to_rotmats=pose2rot)
+        if transl is not None:
+            verts = verts + transl[:, None, :]
+
+        # Calculate 3D keypoints from the posed vertices
+        keypoints = torch.einsum('bni,kn->bki', verts, self.vert2kpt)
+
+        output = AVESOutput(
+            vertices=verts,
+            joints=keypoints,
+            betas=betas,
+            global_orient=global_orient,
+            pose=pose,
+            bone=bone,
+            transl=transl,
+            full_pose=None,
+        )
+        return output
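Since AVES is a plain nn.Module, it can be probed with rest-pose parameters once its buffers have been loaded from the pretrained dictionary. A hedged sketch follows, assuming aves_cfg is the dict that AniMerPlusPlus.__init__ loads via torch.load(cfg.AVES.MODEL_PATH), and that the dimensions follow the docstring above (24 body joints plus the root, 24 bone lengths, 15 shape coefficients); the exact joint count is an assumption, not something this file states explicitly.

import torch
from amr.models.aves_warapper import AVES

aves = AVES(**aves_cfg)                                  # aves_cfg: pretrained parameter dict (assumed available)

bn = 2
global_orient = torch.eye(3).repeat(bn, 1, 1, 1)         # (bn, 1, 3, 3) identity root rotation
pose = torch.eye(3).repeat(bn, 24, 1, 1)                 # (bn, 24, 3, 3) rest pose as rotation matrices
bone = torch.ones(bn, 24)                                # unit relative bone lengths
betas = torch.zeros(bn, 15)                              # mean shape

out = aves(global_orient=global_orient, pose=pose, bone=bone, betas=betas, pose2rot=False)
print(out.vertices.shape, out.joints.shape)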
amr/models/backbones/__init__.py
ADDED
@@ -0,0 +1,9 @@
+from .vit_moe import vithmoe
+from torch import nn
+import torchvision
+
+def create_backbone(cfg):
+    if cfg.MODEL.BACKBONE.TYPE == 'vithmoe':
+        return vithmoe(cfg)
+    else:
+        raise NotImplementedError('Backbone type is not implemented')
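create_backbone is a thin factory keyed on cfg.MODEL.BACKBONE.TYPE; only the vithmoe branch is wired up in this commit. Below is a sketch of how an extra backbone type could be added without touching any caller; the resnet50 branch is purely illustrative and not part of the repository.

from torch import nn
import torchvision

def create_backbone(cfg):
    if cfg.MODEL.BACKBONE.TYPE == 'vithmoe':
        from .vit_moe import vithmoe
        return vithmoe(cfg)
    elif cfg.MODEL.BACKBONE.TYPE == 'resnet50':              # hypothetical extra branch
        resnet = torchvision.models.resnet50(weights=None)
        return nn.Sequential(*list(resnet.children())[:-2])  # keep the conv feature map, drop avgpool + fc
    else:
        raise NotImplementedError('Backbone type is not implemented')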
amr/models/backbones/__pycache__/__init__.cpython-310.pyc ADDED (binary file, 489 Bytes)
amr/models/backbones/__pycache__/__init__.cpython-312.pyc ADDED (binary file, 1.71 kB)
amr/models/backbones/__pycache__/__init__.cpython-39.pyc ADDED (binary file, 401 Bytes)
amr/models/backbones/__pycache__/rope_deit.cpython-310.pyc ADDED (binary file, 2.47 kB)
amr/models/backbones/__pycache__/vit.cpython-310.pyc ADDED (binary file, 11.9 kB)