SondosM commited on
Commit
92de416
·
verified ·
1 Parent(s): e20a58e

Delete WiLoR

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. WiLoR/README.md +0 -93
  2. WiLoR/assets/teaser.png +0 -3
  3. WiLoR/demo.py +0 -142
  4. WiLoR/demo_img/test1.jpg +0 -0
  5. WiLoR/demo_img/test2.png +0 -3
  6. WiLoR/demo_img/test3.jpg +0 -0
  7. WiLoR/demo_img/test4.jpg +0 -3
  8. WiLoR/demo_img/test5.jpeg +0 -3
  9. WiLoR/demo_img/test6.jpg +0 -3
  10. WiLoR/demo_img/test7.jpg +0 -0
  11. WiLoR/demo_img/test8.jpg +0 -3
  12. WiLoR/download_videos.py +0 -58
  13. WiLoR/gradio_demo.py +0 -192
  14. WiLoR/license.txt +0 -402
  15. WiLoR/mano_data/mano_mean_params.npz +0 -3
  16. WiLoR/pretrained_models/dataset_config.yaml +0 -58
  17. WiLoR/pretrained_models/model_config.yaml +0 -119
  18. WiLoR/requirements.txt +0 -20
  19. WiLoR/whim/Dataset_instructions.md +0 -31
  20. WiLoR/whim/test_video_ids.json +0 -1
  21. WiLoR/whim/train_video_ids.json +0 -0
  22. WiLoR/wilor/configs/__init__.py +0 -114
  23. WiLoR/wilor/configs/__pycache__/__init__.cpython-311.pyc +0 -0
  24. WiLoR/wilor/datasets/utils.py +0 -994
  25. WiLoR/wilor/datasets/vitdet_dataset.py +0 -95
  26. WiLoR/wilor/models/__init__.py +0 -36
  27. WiLoR/wilor/models/__pycache__/__init__.cpython-311.pyc +0 -0
  28. WiLoR/wilor/models/__pycache__/discriminator.cpython-311.pyc +0 -0
  29. WiLoR/wilor/models/__pycache__/losses.cpython-311.pyc +0 -0
  30. WiLoR/wilor/models/__pycache__/mano_wrapper.cpython-311.pyc +0 -0
  31. WiLoR/wilor/models/__pycache__/wilor.cpython-311.pyc +0 -0
  32. WiLoR/wilor/models/backbones/__init__.py +0 -17
  33. WiLoR/wilor/models/backbones/__pycache__/__init__.cpython-310.pyc +0 -0
  34. WiLoR/wilor/models/backbones/__pycache__/__init__.cpython-311.pyc +0 -0
  35. WiLoR/wilor/models/backbones/__pycache__/vit.cpython-310.pyc +0 -0
  36. WiLoR/wilor/models/backbones/__pycache__/vit.cpython-311.pyc +0 -0
  37. WiLoR/wilor/models/backbones/vit.py +0 -410
  38. WiLoR/wilor/models/discriminator.py +0 -98
  39. WiLoR/wilor/models/heads/__init__.py +0 -1
  40. WiLoR/wilor/models/heads/__pycache__/__init__.cpython-310.pyc +0 -0
  41. WiLoR/wilor/models/heads/__pycache__/__init__.cpython-311.pyc +0 -0
  42. WiLoR/wilor/models/heads/__pycache__/refinement_net.cpython-310.pyc +0 -0
  43. WiLoR/wilor/models/heads/__pycache__/refinement_net.cpython-311.pyc +0 -0
  44. WiLoR/wilor/models/heads/refinement_net.py +0 -204
  45. WiLoR/wilor/models/losses.py +0 -92
  46. WiLoR/wilor/models/mano_wrapper.py +0 -40
  47. WiLoR/wilor/models/wilor.py +0 -376
  48. WiLoR/wilor/utils/__init__.py +0 -25
  49. WiLoR/wilor/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  50. WiLoR/wilor/utils/__pycache__/geometry.cpython-311.pyc +0 -0
WiLoR/README.md DELETED
@@ -1,93 +0,0 @@
1
- <div align="center">
2
-
3
- # WiLoR: End-to-end 3D hand localization and reconstruction in-the-wild
4
-
5
- [Rolandos Alexandros Potamias](https://rolpotamias.github.io)<sup>1</sup> &emsp; [Jinglei Zhang]()<sup>2</sup> &emsp; [Jiankang Deng](https://jiankangdeng.github.io/)<sup>1</sup> &emsp; [Stefanos Zafeiriou](https://www.imperial.ac.uk/people/s.zafeiriou)<sup>1</sup>
6
-
7
- <sup>1</sup>Imperial College London, UK <br>
8
- <sup>2</sup>Shanghai Jiao Tong University, China
9
-
10
- <font color="blue"><strong>CVPR 2025</strong></font>
11
-
12
- <a href='https://rolpotamias.github.io/WiLoR/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
13
- <a href='https://arxiv.org/abs/2409.12259'><img src='https://img.shields.io/badge/Paper-arXiv-red'></a>
14
- <a href='https://huggingface.co/spaces/rolpotamias/WiLoR'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-green'></a>
15
- <a href='https://colab.research.google.com/drive/1bNnYFECmJbbvCNZAKtQcxJGxf0DZppsB?usp=sharing'><img src='https://colab.research.google.com/assets/colab-badge.svg'></a>
16
- </div>
17
-
18
- <div align="center">
19
-
20
- [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/wilor-end-to-end-3d-hand-localization-and/3d-hand-pose-estimation-on-freihand)](https://paperswithcode.com/sota/3d-hand-pose-estimation-on-freihand?p=wilor-end-to-end-3d-hand-localization-and)
21
- [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/wilor-end-to-end-3d-hand-localization-and/3d-hand-pose-estimation-on-ho-3d)](https://paperswithcode.com/sota/3d-hand-pose-estimation-on-ho-3d?p=wilor-end-to-end-3d-hand-localization-and)
22
-
23
- </div>
24
-
25
- This is the official implementation of **[WiLoR](https://rolpotamias.github.io/WiLoR/)**, an state-of-the-art hand localization and reconstruction model:
26
-
27
- ![teaser](assets/teaser.png)
28
-
29
- ## Installation
30
- ### [Update] Quick Installation
31
- Thanks to [@warmshao](https://github.com/warmshao) WiLoR can now be installed using a single pip command:
32
- ```
33
- pip install git+https://github.com/warmshao/WiLoR-mini
34
- ```
35
- Please head to [WiLoR-mini](https://github.com/warmshao/WiLoR-mini) for additional details.
36
-
37
- **Note:** the above code is a simplified version of WiLoR and can be used for demo only.
38
- If you wish to use WiLoR for other tasks it is suggested to follow the original installation instructued bellow:
39
- ### Original Installation
40
- ```
41
- git clone --recursive https://github.com/rolpotamias/WiLoR.git
42
- cd WiLoR
43
- ```
44
-
45
- The code has been tested with PyTorch 2.0.0 and CUDA 11.7. It is suggested to use an anaconda environment to install the the required dependencies:
46
- ```bash
47
- conda create --name wilor python=3.10
48
- conda activate wilor
49
-
50
- pip install torch torchvision --index-url https://download.pytorch.org/whl/cu117
51
- # Install requirements
52
- pip install -r requirements.txt
53
- ```
54
- Download the pretrained models using:
55
- ```bash
56
- wget https://huggingface.co/spaces/rolpotamias/WiLoR/resolve/main/pretrained_models/detector.pt -P ./pretrained_models/
57
- wget https://huggingface.co/spaces/rolpotamias/WiLoR/resolve/main/pretrained_models/wilor_final.ckpt -P ./pretrained_models/
58
- ```
59
- It is also required to download MANO model from [MANO website](https://mano.is.tue.mpg.de).
60
- Create an account by clicking Sign Up and download the models (mano_v*_*.zip). Unzip and place the right hand model `MANO_RIGHT.pkl` under the `mano_data/` folder.
61
- Note that MANO model falls under the [MANO license](https://mano.is.tue.mpg.de/license.html).
62
- ## Demo
63
- ```bash
64
- python demo.py --img_folder demo_img --out_folder demo_out --save_mesh
65
- ```
66
- ## Start a local gradio demo
67
- You can start a local demo for inference by running:
68
- ```bash
69
- python gradio_demo.py
70
- ```
71
- ## WHIM Dataset
72
- To download WHIM dataset please follow the instructions [here](./whim/Dataset_instructions.md)
73
-
74
- ## Acknowledgements
75
- Parts of the code are taken or adapted from the following repos:
76
- - [HaMeR](https://github.com/geopavlakos/hamer/)
77
- - [Ultralytics](https://github.com/ultralytics/ultralytics)
78
-
79
- ## License
80
- WiLoR models fall under the [CC-BY-NC--ND License](./license.txt). This repository depends also on [Ultralytics library](https://github.com/ultralytics/ultralytics) and [MANO Model](https://mano.is.tue.mpg.de/license.html), which are fall under their own licenses. By using this repository, you must also comply with the terms of these external licenses.
81
- ## Citing
82
- If you find WiLoR useful for your research, please consider citing our paper:
83
-
84
- ```bibtex
85
- @misc{potamias2024wilor,
86
- title={WiLoR: End-to-end 3D Hand Localization and Reconstruction in-the-wild},
87
- author={Rolandos Alexandros Potamias and Jinglei Zhang and Jiankang Deng and Stefanos Zafeiriou},
88
- year={2024},
89
- eprint={2409.12259},
90
- archivePrefix={arXiv},
91
- primaryClass={cs.CV}
92
- }
93
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/assets/teaser.png DELETED

Git LFS Details

  • SHA256: d5f07ada2f470af0619716c0ce4f60d9dfd3da1673d06c28c97d85abb84eadc0
  • Pointer size: 132 Bytes
  • Size of remote file: 9.21 MB
WiLoR/demo.py DELETED
@@ -1,142 +0,0 @@
1
- from pathlib import Path
2
- import torch
3
- import argparse
4
- import os
5
- import cv2
6
- import numpy as np
7
- import json
8
- from typing import Dict, Optional
9
-
10
- from wilor.models import WiLoR, load_wilor
11
- from wilor.utils import recursive_to
12
- from wilor.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
13
- from wilor.utils.renderer import Renderer, cam_crop_to_full
14
- from ultralytics import YOLO
15
- LIGHT_PURPLE=(0.25098039, 0.274117647, 0.65882353)
16
-
17
- def main():
18
- parser = argparse.ArgumentParser(description='WiLoR demo code')
19
- parser.add_argument('--img_folder', type=str, default='images', help='Folder with input images')
20
- parser.add_argument('--out_folder', type=str, default='out_demo', help='Output folder to save rendered results')
21
- parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='If set, save meshes to disk also')
22
- parser.add_argument('--rescale_factor', type=float, default=2.0, help='Factor for padding the bbox')
23
- parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg'], help='List of file extensions to consider')
24
-
25
- args = parser.parse_args()
26
-
27
- # Download and load checkpoints
28
- model, model_cfg = load_wilor(checkpoint_path = './pretrained_models/wilor_final.ckpt' , cfg_path= './pretrained_models/model_config.yaml')
29
- detector = YOLO('./pretrained_models/detector.pt')
30
- # Setup the renderer
31
- renderer = Renderer(model_cfg, faces=model.mano.faces)
32
- renderer_side = Renderer(model_cfg, faces=model.mano.faces)
33
-
34
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
35
- model = model.to(device)
36
- detector = detector.to(device)
37
- model.eval()
38
-
39
- # Make output directory if it does not exist
40
- os.makedirs(args.out_folder, exist_ok=True)
41
-
42
- # Get all demo images ends with .jpg or .png
43
- img_paths = [img for end in args.file_type for img in Path(args.img_folder).glob(end)]
44
- # Iterate over all images in folder
45
- for img_path in img_paths:
46
- img_cv2 = cv2.imread(str(img_path))
47
- detections = detector(img_cv2, conf = 0.3, verbose=False)[0]
48
- bboxes = []
49
- is_right = []
50
- for det in detections:
51
- Bbox = det.boxes.data.cpu().detach().squeeze().numpy()
52
- is_right.append(det.boxes.cls.cpu().detach().squeeze().item())
53
- bboxes.append(Bbox[:4].tolist())
54
-
55
- if len(bboxes) == 0:
56
- continue
57
- boxes = np.stack(bboxes)
58
- right = np.stack(is_right)
59
- dataset = ViTDetDataset(model_cfg, img_cv2, boxes, right, rescale_factor=args.rescale_factor)
60
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)
61
-
62
- all_verts = []
63
- all_cam_t = []
64
- all_right = []
65
- all_joints= []
66
- all_kpts = []
67
-
68
- for batch in dataloader:
69
- batch = recursive_to(batch, device)
70
-
71
- with torch.no_grad():
72
- out = model(batch)
73
-
74
- multiplier = (2*batch['right']-1)
75
- pred_cam = out['pred_cam']
76
- pred_cam[:,1] = multiplier*pred_cam[:,1]
77
- box_center = batch["box_center"].float()
78
- box_size = batch["box_size"].float()
79
- img_size = batch["img_size"].float()
80
- scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max()
81
- pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length).detach().cpu().numpy()
82
-
83
-
84
- # Render the result
85
- batch_size = batch['img'].shape[0]
86
- for n in range(batch_size):
87
- # Get filename from path img_path
88
- img_fn, _ = os.path.splitext(os.path.basename(img_path))
89
-
90
- verts = out['pred_vertices'][n].detach().cpu().numpy()
91
- joints = out['pred_keypoints_3d'][n].detach().cpu().numpy()
92
-
93
- is_right = batch['right'][n].cpu().numpy()
94
- verts[:,0] = (2*is_right-1)*verts[:,0]
95
- joints[:,0] = (2*is_right-1)*joints[:,0]
96
- cam_t = pred_cam_t_full[n]
97
- kpts_2d = project_full_img(verts, cam_t, scaled_focal_length, img_size[n])
98
-
99
- all_verts.append(verts)
100
- all_cam_t.append(cam_t)
101
- all_right.append(is_right)
102
- all_joints.append(joints)
103
- all_kpts.append(kpts_2d)
104
-
105
-
106
- # Save all meshes to disk
107
- if args.save_mesh:
108
- camera_translation = cam_t.copy()
109
- tmesh = renderer.vertices_to_trimesh(verts, camera_translation, LIGHT_PURPLE, is_right=is_right)
110
- tmesh.export(os.path.join(args.out_folder, f'{img_fn}_{n}.obj'))
111
-
112
- # Render front view
113
- if len(all_verts) > 0:
114
- misc_args = dict(
115
- mesh_base_color=LIGHT_PURPLE,
116
- scene_bg_color=(1, 1, 1),
117
- focal_length=scaled_focal_length,
118
- )
119
- cam_view = renderer.render_rgba_multiple(all_verts, cam_t=all_cam_t, render_res=img_size[n], is_right=all_right, **misc_args)
120
-
121
- # Overlay image
122
- input_img = img_cv2.astype(np.float32)[:,:,::-1]/255.0
123
- input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel
124
- input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:]
125
-
126
- cv2.imwrite(os.path.join(args.out_folder, f'{img_fn}.jpg'), 255*input_img_overlay[:, :, ::-1])
127
-
128
- def project_full_img(points, cam_trans, focal_length, img_res):
129
- camera_center = [img_res[0] / 2., img_res[1] / 2.]
130
- K = torch.eye(3)
131
- K[0,0] = focal_length
132
- K[1,1] = focal_length
133
- K[0,2] = camera_center[0]
134
- K[1,2] = camera_center[1]
135
- points = points + cam_trans
136
- points = points / points[..., -1:]
137
-
138
- V_2d = (K @ points.T).T
139
- return V_2d[..., :-1]
140
-
141
- if __name__ == '__main__':
142
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/demo_img/test1.jpg DELETED
Binary file (50 kB)
 
WiLoR/demo_img/test2.png DELETED

Git LFS Details

  • SHA256: 589f5d12593acbcbcb9ec07b288b04f6d7e70542e1312ceee3ea992ba0f41ff9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
WiLoR/demo_img/test3.jpg DELETED
Binary file (34.4 kB)
 
WiLoR/demo_img/test4.jpg DELETED

Git LFS Details

  • SHA256: efb16543caa936aa671ad1cb28ca2c6129ba8cba58d08476ed9538fd12de9265
  • Pointer size: 131 Bytes
  • Size of remote file: 315 kB
WiLoR/demo_img/test5.jpeg DELETED

Git LFS Details

  • SHA256: 84d161aa4f1a335ec3971c5d050338e7c13b9e3c90231c0de7e677094a172eae
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
WiLoR/demo_img/test6.jpg DELETED

Git LFS Details

  • SHA256: 617a3a3d04a1e17e4285dab5bca2003080923df66953df93c85ddfdaa383e8f5
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
WiLoR/demo_img/test7.jpg DELETED
Binary file (60.6 kB)
 
WiLoR/demo_img/test8.jpg DELETED

Git LFS Details

  • SHA256: 886ef1a8981bef175707353b2adea60168657a926c1dd5a95789c4907d881907
  • Pointer size: 131 Bytes
  • Size of remote file: 398 kB
WiLoR/download_videos.py DELETED
@@ -1,58 +0,0 @@
1
- import os
2
- import json
3
- import numpy as np
4
- import argparse
5
- from pytubefix import YouTube
6
-
7
- parser = argparse.ArgumentParser()
8
-
9
- parser.add_argument("--root", type=str, help="Directory of WiLoR")
10
- parser.add_argument("--mode", type=str, choices=['train', 'test'], default= 'train', help="Train/Test set")
11
-
12
- args = parser.parse_args()
13
-
14
- with open(os.path.join(args.root, f'./whim/{args.mode}_video_ids.json')) as f:
15
- video_dict = json.load(f)
16
-
17
- Video_IDs = video_dict.keys()
18
- failed_IDs = []
19
- os.makedirs(os.path.join(args.root, 'Videos'), exist_ok=True)
20
-
21
- for Video_ID in Video_IDs:
22
- res = video_dict[Video_ID]['res'][0]
23
- try:
24
- YouTube('https://youtu.be/'+Video_ID).streams.filter(only_video=True,
25
- file_extension='mp4',
26
- res =f'{res}p'
27
- ).order_by('resolution').desc().first().download(
28
- output_path=os.path.join(args.root, 'Videos') ,
29
- filename = Video_ID +'.mp4')
30
- except:
31
- print(f'Failed {Video_ID}')
32
- failed_IDs.append(Video_ID)
33
- continue
34
-
35
-
36
- cap = cv2.VideoCapture(os.path.join(args.root, 'Videos', Video_ID + '.mp4'))
37
- if (cap.isOpened()== False):
38
- print(f"Error opening video stream {os.path.join(args.root, 'Videos', Video_ID + '.mp4')}")
39
-
40
- VIDEO_LEN = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
41
- length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
42
- fps = cap.get(cv2.CAP_PROP_FPS)
43
-
44
- fps_org = video_dict[Video_ID]['fps']
45
- fps_rate = round(fps / fps_org)
46
-
47
- all_frames = os.listdir(os.path.join(args.root, 'WHIM', args.mode, 'anno', Video_ID))
48
-
49
- for frame in all_frames:
50
- frame_gt = int(frame[:-4])
51
- frame_idx = (frame_gt * fps_rate)
52
-
53
- cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
54
- ret, img_cv2 = cap.read()
55
-
56
- cv2.imwrite(os.path.join(args.root, 'WHIM', args.mode, 'anno', Video_ID, frame +'.jpg' ), img_cv2.astype(np.float32))
57
-
58
- np.save(os.path.join(args.root, 'failed_videos.npy'), failed_IDs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/gradio_demo.py DELETED
@@ -1,192 +0,0 @@
1
- import os
2
- import sys
3
- os.environ["PYOPENGL_PLATFORM"] = "egl"
4
- os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"
5
- # os.system('pip install /home/user/app/pyrender')
6
- # sys.path.append('/home/user/app/pyrender')
7
-
8
- import gradio as gr
9
- #import spaces
10
- import cv2
11
- import numpy as np
12
- import torch
13
- from ultralytics import YOLO
14
- from pathlib import Path
15
- import argparse
16
- import json
17
- from typing import Dict, Optional
18
-
19
- from wilor.models import WiLoR, load_wilor
20
- from wilor.utils import recursive_to
21
- from wilor.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
22
- from wilor.utils.renderer import Renderer, cam_crop_to_full
23
- device = torch.device('cpu') if torch.cuda.is_available() else torch.device('cuda')
24
-
25
- LIGHT_PURPLE=(0.25098039, 0.274117647, 0.65882353)
26
-
27
- model, model_cfg = load_wilor(checkpoint_path = './pretrained_models/wilor_final.ckpt' , cfg_path= './pretrained_models/model_config.yaml')
28
- # Setup the renderer
29
- renderer = Renderer(model_cfg, faces=model.mano.faces)
30
- model = model.to(device)
31
- model.eval()
32
-
33
- detector = YOLO(f'./pretrained_models/detector.pt').to(device)
34
-
35
- def render_reconstruction(image, conf, IoU_threshold=0.3):
36
- input_img, num_dets, reconstructions = run_wilow_model(image, conf, IoU_threshold=0.5)
37
- if num_dets> 0:
38
- # Render front view
39
-
40
- misc_args = dict(
41
- mesh_base_color=LIGHT_PURPLE,
42
- scene_bg_color=(1, 1, 1),
43
- focal_length=reconstructions['focal'],
44
- )
45
-
46
- cam_view = renderer.render_rgba_multiple(reconstructions['verts'],
47
- cam_t=reconstructions['cam_t'],
48
- render_res=reconstructions['img_size'],
49
- is_right=reconstructions['right'], **misc_args)
50
-
51
- # Overlay image
52
-
53
- input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel
54
- input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:]
55
-
56
- return input_img_overlay, f'{num_dets} hands detected'
57
- else:
58
- return input_img, f'{num_dets} hands detected'
59
-
60
- #@spaces.GPU()
61
- def run_wilow_model(image, conf, IoU_threshold=0.5):
62
- img_cv2 = image[...,::-1]
63
- img_vis = image.copy()
64
-
65
- detections = detector(img_cv2, conf=conf, verbose=False, iou=IoU_threshold)[0]
66
-
67
- bboxes = []
68
- is_right = []
69
- for det in detections:
70
- Bbox = det.boxes.data.cpu().detach().squeeze().numpy()
71
- Conf = det.boxes.conf.data.cpu().detach()[0].numpy().reshape(-1).astype(np.float16)
72
- Side = det.boxes.cls.data.cpu().detach()
73
- #Bbox[:2] -= np.int32(0.1 * Bbox[:2])
74
- #Bbox[2:] += np.int32(0.1 * Bbox[ 2:])
75
- is_right.append(det.boxes.cls.cpu().detach().squeeze().item())
76
- bboxes.append(Bbox[:4].tolist())
77
-
78
- color = (255*0.208, 255*0.647 ,255*0.603 ) if Side==0. else (255*1, 255*0.78039, 255*0.2353)
79
- label = f'L - {Conf[0]:.3f}' if Side==0 else f'R - {Conf[0]:.3f}'
80
-
81
- cv2.rectangle(img_vis, (int(Bbox[0]), int(Bbox[1])), (int(Bbox[2]), int(Bbox[3])), color , 3)
82
- (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
83
- cv2.rectangle(img_vis, (int(Bbox[0]), int(Bbox[1]) - 20), (int(Bbox[0]) + w, int(Bbox[1])), color, -1)
84
- cv2.putText(img_vis, label, (int(Bbox[0]), int(Bbox[1]) - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2)
85
-
86
- if len(bboxes) != 0:
87
- boxes = np.stack(bboxes)
88
- right = np.stack(is_right)
89
- dataset = ViTDetDataset(model_cfg, img_cv2, boxes, right, rescale_factor=2.0 )
90
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0)
91
-
92
- all_verts = []
93
- all_cam_t = []
94
- all_right = []
95
- all_joints= []
96
-
97
- for batch in dataloader:
98
- batch = recursive_to(batch, device)
99
-
100
- with torch.no_grad():
101
- out = model(batch)
102
-
103
- multiplier = (2*batch['right']-1)
104
- pred_cam = out['pred_cam']
105
- pred_cam[:,1] = multiplier*pred_cam[:,1]
106
- box_center = batch["box_center"].float()
107
- box_size = batch["box_size"].float()
108
- img_size = batch["img_size"].float()
109
- scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max()
110
- pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length).detach().cpu().numpy()
111
-
112
-
113
- batch_size = batch['img'].shape[0]
114
- for n in range(batch_size):
115
-
116
- verts = out['pred_vertices'][n].detach().cpu().numpy()
117
- joints = out['pred_keypoints_3d'][n].detach().cpu().numpy()
118
-
119
- is_right = batch['right'][n].cpu().numpy()
120
- verts[:,0] = (2*is_right-1)*verts[:,0]
121
- joints[:,0] = (2*is_right-1)*joints[:,0]
122
-
123
- cam_t = pred_cam_t_full[n]
124
-
125
- all_verts.append(verts)
126
- all_cam_t.append(cam_t)
127
- all_right.append(is_right)
128
- all_joints.append(joints)
129
-
130
- reconstructions = {'verts': all_verts, 'cam_t': all_cam_t, 'right': all_right, 'img_size': img_size[n], 'focal': scaled_focal_length}
131
- return img_vis.astype(np.float32)/255.0, len(detections), reconstructions
132
- else:
133
- return img_vis.astype(np.float32)/255.0, len(detections), None
134
-
135
-
136
-
137
- header = ('''
138
- <div class="embed_hidden" style="text-align: center;">
139
- <h1> <b>WiLoR</b>: End-to-end 3D hand localization and reconstruction in-the-wild</h1>
140
- <h3>
141
- <a href="https://rolpotamias.github.io" target="_blank" rel="noopener noreferrer">Rolandos Alexandros Potamias</a><sup>1</sup>,
142
- <a href="" target="_blank" rel="noopener noreferrer">Jinglei Zhang</a><sup>2</sup>,
143
- <br>
144
- <a href="https://jiankangdeng.github.io/" target="_blank" rel="noopener noreferrer">Jiankang Deng</a><sup>1</sup>,
145
- <a href="https://wp.doc.ic.ac.uk/szafeiri/" target="_blank" rel="noopener noreferrer">Stefanos Zafeiriou</a><sup>1</sup>
146
- </h3>
147
- <h3>
148
- <sup>1</sup>Imperial College London;
149
- <sup>2</sup>Shanghai Jiao Tong University
150
- </h3>
151
- </div>
152
- <div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
153
- <a href=''><img src='https://img.shields.io/badge/Arxiv-......-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
154
- <a href='https://rolpotamias.github.io/pdfs/WiLoR.pdf'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a>
155
- <a href='https://rolpotamias.github.io/WiLoR/'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
156
- <a href='https://github.com/rolpotamias/WiLoR'><img src='https://img.shields.io/badge/GitHub-Code-black?style=flat&logo=github&logoColor=white'></a>
157
- ''')
158
-
159
-
160
- with gr.Blocks(title="WiLoR: End-to-end 3D hand localization and reconstruction in-the-wild", css=".gradio-container") as demo:
161
-
162
- gr.Markdown(header)
163
-
164
- with gr.Row():
165
- with gr.Column():
166
- input_image = gr.Image(label="Input image", type="numpy")
167
- threshold = gr.Slider(value=0.3, minimum=0.05, maximum=0.95, step=0.05, label='Detection Confidence Threshold')
168
- #nms = gr.Slider(value=0.5, minimum=0.05, maximum=0.95, step=0.05, label='IoU NMS Threshold')
169
- submit = gr.Button("Submit", variant="primary")
170
-
171
-
172
- with gr.Column():
173
- reconstruction = gr.Image(label="Reconstructions", type="numpy")
174
- hands_detected = gr.Textbox(label="Hands Detected")
175
-
176
- submit.click(fn=render_reconstruction, inputs=[input_image, threshold], outputs=[reconstruction, hands_detected])
177
-
178
- with gr.Row():
179
- example_images = gr.Examples([
180
-
181
- ['./demo_img/test1.jpg'],
182
- ['./demo_img/test2.png'],
183
- ['./demo_img/test3.jpg'],
184
- ['./demo_img/test4.jpg'],
185
- ['./demo_img/test5.jpeg'],
186
- ['./demo_img/test6.jpg'],
187
- ['./demo_img/test7.jpg'],
188
- ['./demo_img/test8.jpg'],
189
- ],
190
- inputs=input_image)
191
-
192
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/license.txt DELETED
@@ -1,402 +0,0 @@
1
- Attribution-NonCommercial-NoDerivatives 4.0 International
2
-
3
- =======================================================================
4
-
5
- Creative Commons Corporation ("Creative Commons") is not a law firm and
6
- does not provide legal services or legal advice. Distribution of
7
- Creative Commons public licenses does not create a lawyer-client or
8
- other relationship. Creative Commons makes its licenses and related
9
- information available on an "as-is" basis. Creative Commons gives no
10
- warranties regarding its licenses, any material licensed under their
11
- terms and conditions, or any related information. Creative Commons
12
- disclaims all liability for damages resulting from their use to the
13
- fullest extent possible.
14
-
15
- Using Creative Commons Public Licenses
16
-
17
- Creative Commons public licenses provide a standard set of terms and
18
- conditions that creators and other rights holders may use to share
19
- original works of authorship and other material subject to copyright
20
- and certain other rights specified in the public license below. The
21
- following considerations are for informational purposes only, are not
22
- exhaustive, and do not form part of our licenses.
23
-
24
- Considerations for licensors: Our public licenses are
25
- intended for use by those authorized to give the public
26
- permission to use material in ways otherwise restricted by
27
- copyright and certain other rights. Our licenses are
28
- irrevocable. Licensors should read and understand the terms
29
- and conditions of the license they choose before applying it.
30
- Licensors should also secure all rights necessary before
31
- applying our licenses so that the public can reuse the
32
- material as expected. Licensors should clearly mark any
33
- material not subject to the license. This includes other CC-
34
- licensed material, or material used under an exception or
35
- limitation to copyright. More considerations for licensors:
36
- wiki.creativecommons.org/Considerations_for_licensors
37
-
38
- Considerations for the public: By using one of our public
39
- licenses, a licensor grants the public permission to use the
40
- licensed material under specified terms and conditions. If
41
- the licensor's permission is not necessary for any reason--for
42
- example, because of any applicable exception or limitation to
43
- copyright--then that use is not regulated by the license. Our
44
- licenses grant only permissions under copyright and certain
45
- other rights that a licensor has authority to grant. Use of
46
- the licensed material may still be restricted for other
47
- reasons, including because others have copyright or other
48
- rights in the material. A licensor may make special requests,
49
- such as asking that all changes be marked or described.
50
- Although not required by our licenses, you are encouraged to
51
- respect those requests where reasonable. More considerations
52
- for the public:
53
- wiki.creativecommons.org/Considerations_for_licensees
54
-
55
- =======================================================================
56
-
57
- Creative Commons Attribution-NonCommercial-NoDerivatives 4.0
58
- International Public License
59
-
60
- By exercising the Licensed Rights (defined below), You accept and agree
61
- to be bound by the terms and conditions of this Creative Commons
62
- Attribution-NonCommercial-NoDerivatives 4.0 International Public
63
- License ("Public License"). To the extent this Public License may be
64
- interpreted as a contract, You are granted the Licensed Rights in
65
- consideration of Your acceptance of these terms and conditions, and the
66
- Licensor grants You such rights in consideration of benefits the
67
- Licensor receives from making the Licensed Material available under
68
- these terms and conditions.
69
-
70
-
71
- Section 1 -- Definitions.
72
-
73
- a. Adapted Material means material subject to Copyright and Similar
74
- Rights that is derived from or based upon the Licensed Material
75
- and in which the Licensed Material is translated, altered,
76
- arranged, transformed, or otherwise modified in a manner requiring
77
- permission under the Copyright and Similar Rights held by the
78
- Licensor. For purposes of this Public License, where the Licensed
79
- Material is a musical work, performance, or sound recording,
80
- Adapted Material is always produced where the Licensed Material is
81
- synched in timed relation with a moving image.
82
-
83
- b. Copyright and Similar Rights means copyright and/or similar rights
84
- closely related to copyright including, without limitation,
85
- performance, broadcast, sound recording, and Sui Generis Database
86
- Rights, without regard to how the rights are labeled or
87
- categorized. For purposes of this Public License, the rights
88
- specified in Section 2(b)(1)-(2) are not Copyright and Similar
89
- Rights.
90
-
91
- c. Effective Technological Measures means those measures that, in the
92
- absence of proper authority, may not be circumvented under laws
93
- fulfilling obligations under Article 11 of the WIPO Copyright
94
- Treaty adopted on December 20, 1996, and/or similar international
95
- agreements.
96
-
97
- d. Exceptions and Limitations means fair use, fair dealing, and/or
98
- any other exception or limitation to Copyright and Similar Rights
99
- that applies to Your use of the Licensed Material.
100
-
101
- e. Licensed Material means the artistic or literary work, database,
102
- or other material to which the Licensor applied this Public
103
- License.
104
-
105
- f. Licensed Rights means the rights granted to You subject to the
106
- terms and conditions of this Public License, which are limited to
107
- all Copyright and Similar Rights that apply to Your use of the
108
- Licensed Material and that the Licensor has authority to license.
109
-
110
- g. Licensor means the individual(s) or entity(ies) granting rights
111
- under this Public License.
112
-
113
- h. NonCommercial means not primarily intended for or directed towards
114
- commercial advantage or monetary compensation. For purposes of
115
- this Public License, the exchange of the Licensed Material for
116
- other material subject to Copyright and Similar Rights by digital
117
- file-sharing or similar means is NonCommercial provided there is
118
- no payment of monetary compensation in connection with the
119
- exchange.
120
-
121
- i. Share means to provide material to the public by any means or
122
- process that requires permission under the Licensed Rights, such
123
- as reproduction, public display, public performance, distribution,
124
- dissemination, communication, or importation, and to make material
125
- available to the public including in ways that members of the
126
- public may access the material from a place and at a time
127
- individually chosen by them.
128
-
129
- j. Sui Generis Database Rights means rights other than copyright
130
- resulting from Directive 96/9/EC of the European Parliament and of
131
- the Council of 11 March 1996 on the legal protection of databases,
132
- as amended and/or succeeded, as well as other essentially
133
- equivalent rights anywhere in the world.
134
-
135
- k. You means the individual or entity exercising the Licensed Rights
136
- under this Public License. Your has a corresponding meaning.
137
-
138
-
139
- Section 2 -- Scope.
140
-
141
- a. License grant.
142
-
143
- 1. Subject to the terms and conditions of this Public License,
144
- the Licensor hereby grants You a worldwide, royalty-free,
145
- non-sublicensable, non-exclusive, irrevocable license to
146
- exercise the Licensed Rights in the Licensed Material to:
147
-
148
- a. reproduce and Share the Licensed Material, in whole or
149
- in part, for NonCommercial purposes only; and
150
-
151
- b. produce and reproduce, but not Share, Adapted Material
152
- for NonCommercial purposes only.
153
-
154
- 2. Exceptions and Limitations. For the avoidance of doubt, where
155
- Exceptions and Limitations apply to Your use, this Public
156
- License does not apply, and You do not need to comply with
157
- its terms and conditions.
158
-
159
- 3. Term. The term of this Public License is specified in Section
160
- 6(a).
161
-
162
- 4. Media and formats; technical modifications allowed. The
163
- Licensor authorizes You to exercise the Licensed Rights in
164
- all media and formats whether now known or hereafter created,
165
- and to make technical modifications necessary to do so. The
166
- Licensor waives and/or agrees not to assert any right or
167
- authority to forbid You from making technical modifications
168
- necessary to exercise the Licensed Rights, including
169
- technical modifications necessary to circumvent Effective
170
- Technological Measures. For purposes of this Public License,
171
- simply making modifications authorized by this Section 2(a)
172
- (4) never produces Adapted Material.
173
-
174
- 5. Downstream recipients.
175
-
176
- a. Offer from the Licensor -- Licensed Material. Every
177
- recipient of the Licensed Material automatically
178
- receives an offer from the Licensor to exercise the
179
- Licensed Rights under the terms and conditions of this
180
- Public License.
181
-
182
- b. No downstream restrictions. You may not offer or impose
183
- any additional or different terms or conditions on, or
184
- apply any Effective Technological Measures to, the
185
- Licensed Material if doing so restricts exercise of the
186
- Licensed Rights by any recipient of the Licensed
187
- Material.
188
-
189
- 6. No endorsement. Nothing in this Public License constitutes or
190
- may be construed as permission to assert or imply that You
191
- are, or that Your use of the Licensed Material is, connected
192
- with, or sponsored, endorsed, or granted official status by,
193
- the Licensor or others designated to receive attribution as
194
- provided in Section 3(a)(1)(A)(i).
195
-
196
- b. Other rights.
197
-
198
- 1. Moral rights, such as the right of integrity, are not
199
- licensed under this Public License, nor are publicity,
200
- privacy, and/or other similar personality rights; however, to
201
- the extent possible, the Licensor waives and/or agrees not to
202
- assert any such rights held by the Licensor to the limited
203
- extent necessary to allow You to exercise the Licensed
204
- Rights, but not otherwise.
205
-
206
- 2. Patent and trademark rights are not licensed under this
207
- Public License.
208
-
209
- 3. To the extent possible, the Licensor waives any right to
210
- collect royalties from You for the exercise of the Licensed
211
- Rights, whether directly or through a collecting society
212
- under any voluntary or waivable statutory or compulsory
213
- licensing scheme. In all other cases the Licensor expressly
214
- reserves any right to collect such royalties, including when
215
- the Licensed Material is used other than for NonCommercial
216
- purposes.
217
-
218
-
219
- Section 3 -- License Conditions.
220
-
221
- Your exercise of the Licensed Rights is expressly made subject to the
222
- following conditions.
223
-
224
- a. Attribution.
225
-
226
- 1. If You Share the Licensed Material, You must:
227
-
228
- a. retain the following if it is supplied by the Licensor
229
- with the Licensed Material:
230
-
231
- i. identification of the creator(s) of the Licensed
232
- Material and any others designated to receive
233
- attribution, in any reasonable manner requested by
234
- the Licensor (including by pseudonym if
235
- designated);
236
-
237
- ii. a copyright notice;
238
-
239
- iii. a notice that refers to this Public License;
240
-
241
- iv. a notice that refers to the disclaimer of
242
- warranties;
243
-
244
- v. a URI or hyperlink to the Licensed Material to the
245
- extent reasonably practicable;
246
-
247
- b. indicate if You modified the Licensed Material and
248
- retain an indication of any previous modifications; and
249
-
250
- c. indicate the Licensed Material is licensed under this
251
- Public License, and include the text of, or the URI or
252
- hyperlink to, this Public License.
253
-
254
- For the avoidance of doubt, You do not have permission under
255
- this Public License to Share Adapted Material.
256
-
257
- 2. You may satisfy the conditions in Section 3(a)(1) in any
258
- reasonable manner based on the medium, means, and context in
259
- which You Share the Licensed Material. For example, it may be
260
- reasonable to satisfy the conditions by providing a URI or
261
- hyperlink to a resource that includes the required
262
- information.
263
-
264
- 3. If requested by the Licensor, You must remove any of the
265
- information required by Section 3(a)(1)(A) to the extent
266
- reasonably practicable.
267
-
268
-
269
- Section 4 -- Sui Generis Database Rights.
270
-
271
- Where the Licensed Rights include Sui Generis Database Rights that
272
- apply to Your use of the Licensed Material:
273
-
274
- a. for the avoidance of doubt, Section 2(a)(1) grants You the right
275
- to extract, reuse, reproduce, and Share all or a substantial
276
- portion of the contents of the database for NonCommercial purposes
277
- only and provided You do not Share Adapted Material;
278
-
279
- b. if You include all or a substantial portion of the database
280
- contents in a database in which You have Sui Generis Database
281
- Rights, then the database in which You have Sui Generis Database
282
- Rights (but not its individual contents) is Adapted Material; and
283
-
284
- c. You must comply with the conditions in Section 3(a) if You Share
285
- all or a substantial portion of the contents of the database.
286
-
287
- For the avoidance of doubt, this Section 4 supplements and does not
288
- replace Your obligations under this Public License where the Licensed
289
- Rights include other Copyright and Similar Rights.
290
-
291
-
292
- Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
-
294
- a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
- EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
- AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
- ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
- IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
- WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
- PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
- ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
- KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
- ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
-
305
- b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
- TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
- NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
- INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
- COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
- USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
- ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
- DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
- IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
-
315
- c. The disclaimer of warranties and limitation of liability provided
316
- above shall be interpreted in a manner that, to the extent
317
- possible, most closely approximates an absolute disclaimer and
318
- waiver of all liability.
319
-
320
-
321
- Section 6 -- Term and Termination.
322
-
323
- a. This Public License applies for the term of the Copyright and
324
- Similar Rights licensed here. However, if You fail to comply with
325
- this Public License, then Your rights under this Public License
326
- terminate automatically.
327
-
328
- b. Where Your right to use the Licensed Material has terminated under
329
- Section 6(a), it reinstates:
330
-
331
- 1. automatically as of the date the violation is cured, provided
332
- it is cured within 30 days of Your discovery of the
333
- violation; or
334
-
335
- 2. upon express reinstatement by the Licensor.
336
-
337
- For the avoidance of doubt, this Section 6(b) does not affect any
338
- right the Licensor may have to seek remedies for Your violations
339
- of this Public License.
340
-
341
- c. For the avoidance of doubt, the Licensor may also offer the
342
- Licensed Material under separate terms or conditions or stop
343
- distributing the Licensed Material at any time; however, doing so
344
- will not terminate this Public License.
345
-
346
- d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
347
- License.
348
-
349
-
350
- Section 7 -- Other Terms and Conditions.
351
-
352
- a. The Licensor shall not be bound by any additional or different
353
- terms or conditions communicated by You unless expressly agreed.
354
-
355
- b. Any arrangements, understandings, or agreements regarding the
356
- Licensed Material not stated herein are separate from and
357
- independent of the terms and conditions of this Public License.
358
-
359
-
360
- Section 8 -- Interpretation.
361
-
362
- a. For the avoidance of doubt, this Public License does not, and
363
- shall not be interpreted to, reduce, limit, restrict, or impose
364
- conditions on any use of the Licensed Material that could lawfully
365
- be made without permission under this Public License.
366
-
367
- b. To the extent possible, if any provision of this Public License is
368
- deemed unenforceable, it shall be automatically reformed to the
369
- minimum extent necessary to make it enforceable. If the provision
370
- cannot be reformed, it shall be severed from this Public License
371
- without affecting the enforceability of the remaining terms and
372
- conditions.
373
-
374
- c. No term or condition of this Public License will be waived and no
375
- failure to comply consented to unless expressly agreed to by the
376
- Licensor.
377
-
378
- d. Nothing in this Public License constitutes or may be interpreted
379
- as a limitation upon, or waiver of, any privileges and immunities
380
- that apply to the Licensor or You, including from the legal
381
- processes of any jurisdiction or authority.
382
-
383
- =======================================================================
384
-
385
- Creative Commons is not a party to its public
386
- licenses. Notwithstanding, Creative Commons may elect to apply one of
387
- its public licenses to material it publishes and in those instances
388
- will be considered the “Licensor.” The text of the Creative Commons
389
- public licenses is dedicated to the public domain under the CC0 Public
390
- Domain Dedication. Except for the limited purpose of indicating that
391
- material is shared under a Creative Commons public license or as
392
- otherwise permitted by the Creative Commons policies published at
393
- creativecommons.org/policies, Creative Commons does not authorize the
394
- use of the trademark "Creative Commons" or any other trademark or logo
395
- of Creative Commons without its prior written consent including,
396
- without limitation, in connection with any unauthorized modifications
397
- to any of its public licenses or any other arrangements,
398
- understandings, or agreements concerning use of licensed material. For
399
- the avoidance of doubt, this paragraph does not form part of the
400
- public licenses.
401
-
402
- Creative Commons may be contacted at creativecommons.org.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/mano_data/mano_mean_params.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:efc0ec58e4a5cef78f3abfb4e8f91623b8950be9eff8b8e0dbb0d036ebc63988
3
- size 1178
 
 
 
 
WiLoR/pretrained_models/dataset_config.yaml DELETED
@@ -1,58 +0,0 @@
1
- ARCTIC-TRAIN:
2
- TYPE: ImageDataset
3
- URLS: wilor_training_data/dataset_tars/arctic-train/{000000..000176}.tar
4
- epoch_size: 177000
5
- BEDLAM-TRAIN:
6
- TYPE: ImageDataset
7
- URLS: wilor_training_data/dataset_tars/bedlam-train/{000000..000300}.tar
8
- epoch_size: 301000
9
- COCOW-TRAIN:
10
- TYPE: ImageDataset
11
- URLS: wilor_training_data/dataset_tars/cocow-train/{000000..000036}.tar
12
- epoch_size: 78666
13
- DEX-TRAIN:
14
- TYPE: ImageDataset
15
- URLS: wilor_training_data/dataset_tars/dex-train/{000000..000406}.tar
16
- epoch_size: 406888
17
- FREIHAND-MOCAP:
18
- DATASET_FILE: wilor_training_data/freihand_mocap.npz
19
- FREIHAND-TRAIN:
20
- TYPE: ImageDataset
21
- URLS: wilor_training_data/dataset_tars/freihand-train/{000000..000130}.tar
22
- epoch_size: 130240
23
- H2O3D-TRAIN:
24
- TYPE: ImageDataset
25
- URLS: wilor_training_data/dataset_tars/h2o3d-train/{000000..000060}.tar
26
- epoch_size: 121996
27
- HALPE-TRAIN:
28
- TYPE: ImageDataset
29
- URLS: wilor_training_data/dataset_tars/halpe-train/{000000..000022}.tar
30
- epoch_size: 34289
31
- HO3D-TRAIN:
32
- TYPE: ImageDataset
33
- URLS: wilor_training_data/dataset_tars/ho3d-train/{000000..000083}.tar
34
- epoch_size: 83325
35
- HOT3D-TRAIN:
36
- TYPE: ImageDataset
37
- URLS: wilor_training_data/dataset_tars/hot3d-train/{000000..000571}.tar
38
- epoch_size: 572000
39
- INTERHAND26M-TRAIN:
40
- TYPE: ImageDataset
41
- URLS: wilor_training_data/dataset_tars/interhand26m-train/{000000..001056}.tar
42
- epoch_size: 1424632
43
- MPIINZSL-TRAIN:
44
- TYPE: ImageDataset
45
- URLS: wilor_training_data/dataset_tars/mpiinzsl-train/{000000..000015}.tar
46
- epoch_size: 15184
47
- MTC-TRAIN:
48
- TYPE: ImageDataset
49
- URLS: wilor_training_data/dataset_tars/mtc-train/{000000..000306}.tar
50
- epoch_size: 363947
51
- REINTER-TRAIN:
52
- TYPE: ImageDataset
53
- URLS: wilor_training_data/dataset_tars/reinter-train/{000000..000418}.tar
54
- epoch_size: 419000
55
- RHD-TRAIN:
56
- TYPE: ImageDataset
57
- URLS: wilor_training_data/dataset_tars/rhd-train/{000000..000041}.tar
58
- epoch_size: 61705
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/pretrained_models/model_config.yaml DELETED
@@ -1,119 +0,0 @@
1
- task_name: train
2
- tags:
3
- - dev
4
- train: true
5
- test: false
6
- ckpt_path: null
7
- seed: null
8
- DATASETS:
9
- TRAIN:
10
- FREIHAND-TRAIN:
11
- WEIGHT: 0.2
12
- INTERHAND26M-TRAIN:
13
- WEIGHT: 0.1
14
- MTC-TRAIN:
15
- WEIGHT: 0.05
16
- RHD-TRAIN:
17
- WEIGHT: 0.05
18
- COCOW-TRAIN:
19
- WEIGHT: 0.05
20
- HALPE-TRAIN:
21
- WEIGHT: 0.05
22
- MPIINZSL-TRAIN:
23
- WEIGHT: 0.05
24
- HO3D-TRAIN:
25
- WEIGHT: 0.05
26
- H2O3D-TRAIN:
27
- WEIGHT: 0.05
28
- DEX-TRAIN:
29
- WEIGHT: 0.05
30
- BEDLAM-TRAIN:
31
- WEIGHT: 0.05
32
- REINTER-TRAIN:
33
- WEIGHT: 0.1
34
- HOT3D-TRAIN:
35
- WEIGHT: 0.05
36
- ARCTIC-TRAIN:
37
- WEIGHT: 0.1
38
- VAL:
39
- FREIHAND-TRAIN:
40
- WEIGHT: 1.0
41
- MOCAP: FREIHAND-MOCAP
42
- BETAS_REG: true
43
- CONFIG:
44
- SCALE_FACTOR: 0.3
45
- ROT_FACTOR: 30
46
- TRANS_FACTOR: 0.02
47
- COLOR_SCALE: 0.2
48
- ROT_AUG_RATE: 0.6
49
- TRANS_AUG_RATE: 0.5
50
- DO_FLIP: false
51
- FLIP_AUG_RATE: 0.0
52
- EXTREME_CROP_AUG_RATE: 0.0
53
- EXTREME_CROP_AUG_LEVEL: 1
54
- extras:
55
- ignore_warnings: false
56
- enforce_tags: true
57
- print_config: true
58
- exp_name: WiLoR
59
- MANO:
60
- DATA_DIR: mano_data
61
- MODEL_PATH: ${MANO.DATA_DIR}
62
- GENDER: neutral
63
- NUM_HAND_JOINTS: 15
64
- MEAN_PARAMS: ${MANO.DATA_DIR}/mano_mean_params.npz
65
- CREATE_BODY_POSE: false
66
- EXTRA:
67
- FOCAL_LENGTH: 5000
68
- NUM_LOG_IMAGES: 4
69
- NUM_LOG_SAMPLES_PER_IMAGE: 8
70
- PELVIS_IND: 0
71
- GENERAL:
72
- TOTAL_STEPS: 1000000
73
- LOG_STEPS: 1000
74
- VAL_STEPS: 1000
75
- CHECKPOINT_STEPS: 1000
76
- CHECKPOINT_SAVE_TOP_K: 1
77
- NUM_WORKERS: 8
78
- PREFETCH_FACTOR: 2
79
- TRAIN:
80
- LR: 1.0e-05
81
- WEIGHT_DECAY: 0.0001
82
- BATCH_SIZE: 32
83
- LOSS_REDUCTION: mean
84
- NUM_TRAIN_SAMPLES: 2
85
- NUM_TEST_SAMPLES: 64
86
- POSE_2D_NOISE_RATIO: 0.01
87
- SMPL_PARAM_NOISE_RATIO: 0.005
88
- MODEL:
89
- IMAGE_SIZE: 256
90
- IMAGE_MEAN:
91
- - 0.485
92
- - 0.456
93
- - 0.406
94
- IMAGE_STD:
95
- - 0.229
96
- - 0.224
97
- - 0.225
98
- BACKBONE:
99
- TYPE: vit
100
- PRETRAINED_WEIGHTS: training_data/vitpose_backbone.pth
101
- MANO_HEAD:
102
- TYPE: transformer_decoder
103
- IN_CHANNELS: 2048
104
- TRANSFORMER_DECODER:
105
- depth: 6
106
- heads: 8
107
- mlp_dim: 1024
108
- dim_head: 64
109
- dropout: 0.0
110
- emb_dropout: 0.0
111
- norm: layer
112
- context_dim: 1280
113
- LOSS_WEIGHTS:
114
- KEYPOINTS_3D: 0.05
115
- KEYPOINTS_2D: 0.01
116
- GLOBAL_ORIENT: 0.001
117
- HAND_POSE: 0.001
118
- BETAS: 0.0005
119
- ADVERSARIAL: 0.0005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/requirements.txt DELETED
@@ -1,20 +0,0 @@
1
- numpy
2
- opencv-python
3
- pyrender
4
- pytorch-lightning
5
- scikit-image
6
- smplx==0.1.28
7
- yacs
8
- chumpy @ git+https://github.com/mattloper/chumpy
9
- timm
10
- einops
11
- xtcocotools
12
- pandas
13
- hydra-core
14
- hydra-submitit-launcher
15
- hydra-colorlog
16
- pyrootutils
17
- rich
18
- webdataset
19
- gradio
20
- ultralytics==8.1.34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/whim/Dataset_instructions.md DELETED
@@ -1,31 +0,0 @@
1
- ## WHIM Dataset
2
-
3
- **Annotations**
4
-
5
- The image annotations can be downloaded from the following Drive:
6
-
7
- ```
8
- https://drive.google.com/drive/folders/1d9Fw7LfnF5oJuA6yE8T3xA-u9p6H5ObZ
9
- ```
10
-
11
- **[Alternative]**: The image annotations can be also downloaded from Hugging Face:
12
- ```
13
- https://huggingface.co/datasets/rolpotamias/WHIM
14
- ```
15
- If you are using Hugging Face you might need to merge the training zip files into a single file before uncompressing:
16
- ```
17
- cat train_split.zip* > ~/train_split.zip
18
- ```
19
-
20
- **Images**
21
-
22
- To download the corresponding images you need to first download the YouTube videos and extract the specific frames.
23
- You will need to install ''pytubefix'' or any similar package to download YouTube videos:
24
- ```
25
- pip install -Iv pytubefix==8.12.2
26
- ```
27
- You can then run the following command to download the corresponding train/test images:
28
- ```
29
- python download_videos.py --mode {train/test}
30
- ```
31
- Please make sure that the data are downloaded in the same directory.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/whim/test_video_ids.json DELETED
@@ -1 +0,0 @@
1
- {"YynYZyoETto": {"res": [360, 480], "length": 4678, "fps": 29.97002997002997}, "_iirwC_DvJ0": {"res": [480, 854], "length": 7994, "fps": 29.97002997002997}, "ZMnb9TTsx98": {"res": [1080, 1920], "length": 8109, "fps": 29.97002997002997}, "IrSIHJ0-AaU": {"res": [360, 640], "length": 880, "fps": 30.0}, "w2ULyzWkZ3k": {"res": [1080, 1920], "length": 17032, "fps": 29.97002997002997}, "ivyqQreoVQA": {"res": [1080, 1440], "length": 9610, "fps": 29.97002997002997}, "R07f8kg1h8o": {"res": [1080, 1920], "length": 10726, "fps": 23.976023976023978}, "7S9q1kAVmc0": {"res": [720, 1280], "length": 53757, "fps": 25.0}, "_Ce7G35GIqA": {"res": [720, 1280], "length": 1620, "fps": 30.0}, "lhHkJ3InQOE": {"res": [240, 320], "length": 1600, "fps": 11.988011988011989}, "NXRHcCScubA": {"res": [1080, 1920], "length": 9785, "fps": 29.97002997002997}, "DjFX4idkS3o": {"res": [720, 1280], "length": 5046, "fps": 29.97002997002997}, "06kKvQp4SfM": {"res": [720, 1280], "length": 2661, "fps": 30.0}, "8NqJiAu9W3Y": {"res": [720, 1280], "length": 4738, "fps": 29.97002997002997}, "nN5Y--biYv4": {"res": [720, 1280], "length": 38380, "fps": 29.97}, "OiAlJIaWOBg": {"res": [720, 1280], "length": 10944, "fps": 30.0}, "nJa_omJBzoU": {"res": [720, 1280], "length": 4311, "fps": 29.97002997002997}, "ff_xcsFJ8Pw": {"res": [720, 1280], "length": 5631, "fps": 29.97}, "Y1mNu5iFwMg": {"res": [720, 1280], "length": 7060, "fps": 30.0}, "Ipe9xJCfuTM": {"res": [1080, 1920], "length": 52419, "fps": 29.97002997002997}, "vRkcw9SRems": {"res": [1080, 1920], "length": 10282, "fps": 23.976023976023978}, "ChIJjJyBjQ0": {"res": [1080, 1920], "length": 20228, "fps": 29.97002997002997}, "bxZtXdVvfpc": {"res": [1080, 1920], "length": 2369, "fps": 23.976023976023978}, "MPeXy2U4yJM": {"res": [1080, 1920], "length": 6760, "fps": 24.0}, "wnKnoui3THA": {"res": [1080, 1920], "length": 7934, "fps": 25.0}, "gnArvcWaH6I": {"res": [480, 720], "length": 6864, "fps": 29.97002997002997}}
 
 
WiLoR/whim/train_video_ids.json DELETED
The diff for this file is too large to render. See raw diff
 
WiLoR/wilor/configs/__init__.py DELETED
@@ -1,114 +0,0 @@
1
- import os
2
- from typing import Dict
3
- from yacs.config import CfgNode as CN
4
-
5
- CACHE_DIR_PRETRAINED = "./pretrained_models/"
6
-
7
- def to_lower(x: Dict) -> Dict:
8
- """
9
- Convert all dictionary keys to lowercase
10
- Args:
11
- x (dict): Input dictionary
12
- Returns:
13
- dict: Output dictionary with all keys converted to lowercase
14
- """
15
- return {k.lower(): v for k, v in x.items()}
16
-
17
- _C = CN(new_allowed=True)
18
-
19
- _C.GENERAL = CN(new_allowed=True)
20
- _C.GENERAL.RESUME = True
21
- _C.GENERAL.TIME_TO_RUN = 3300
22
- _C.GENERAL.VAL_STEPS = 100
23
- _C.GENERAL.LOG_STEPS = 100
24
- _C.GENERAL.CHECKPOINT_STEPS = 20000
25
- _C.GENERAL.CHECKPOINT_DIR = "checkpoints"
26
- _C.GENERAL.SUMMARY_DIR = "tensorboard"
27
- _C.GENERAL.NUM_GPUS = 1
28
- _C.GENERAL.NUM_WORKERS = 4
29
- _C.GENERAL.MIXED_PRECISION = True
30
- _C.GENERAL.ALLOW_CUDA = True
31
- _C.GENERAL.PIN_MEMORY = False
32
- _C.GENERAL.DISTRIBUTED = False
33
- _C.GENERAL.LOCAL_RANK = 0
34
- _C.GENERAL.USE_SYNCBN = False
35
- _C.GENERAL.WORLD_SIZE = 1
36
-
37
- _C.TRAIN = CN(new_allowed=True)
38
- _C.TRAIN.NUM_EPOCHS = 100
39
- _C.TRAIN.BATCH_SIZE = 32
40
- _C.TRAIN.SHUFFLE = True
41
- _C.TRAIN.WARMUP = False
42
- _C.TRAIN.NORMALIZE_PER_IMAGE = False
43
- _C.TRAIN.CLIP_GRAD = False
44
- _C.TRAIN.CLIP_GRAD_VALUE = 1.0
45
- _C.LOSS_WEIGHTS = CN(new_allowed=True)
46
-
47
- _C.DATASETS = CN(new_allowed=True)
48
-
49
- _C.MODEL = CN(new_allowed=True)
50
- _C.MODEL.IMAGE_SIZE = 224
51
-
52
- _C.EXTRA = CN(new_allowed=True)
53
- _C.EXTRA.FOCAL_LENGTH = 5000
54
-
55
- _C.DATASETS.CONFIG = CN(new_allowed=True)
56
- _C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
57
- _C.DATASETS.CONFIG.ROT_FACTOR = 30
58
- _C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
59
- _C.DATASETS.CONFIG.COLOR_SCALE = 0.2
60
- _C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
61
- _C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
62
- _C.DATASETS.CONFIG.DO_FLIP = False
63
- _C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
64
- _C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
65
-
66
- def default_config() -> CN:
67
- """
68
- Get a yacs CfgNode object with the default config values.
69
- """
70
- # Return a clone so that the defaults will not be altered
71
- # This is for the "local variable" use pattern
72
- return _C.clone()
73
-
74
- def dataset_config(name='datasets_tar.yaml') -> CN:
75
- """
76
- Get dataset config file
77
- Returns:
78
- CfgNode: Dataset config as a yacs CfgNode object.
79
- """
80
- cfg = CN(new_allowed=True)
81
- config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), name)
82
- cfg.merge_from_file(config_file)
83
- cfg.freeze()
84
- return cfg
85
-
86
- def dataset_eval_config() -> CN:
87
- return dataset_config('datasets_eval.yaml')
88
-
89
- def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN:
90
- """
91
- Read a config file and optionally merge it with the default config file.
92
- Args:
93
- config_file (str): Path to config file.
94
- merge (bool): Whether to merge with the default config or not.
95
- Returns:
96
- CfgNode: Config as a yacs CfgNode object.
97
- """
98
- if merge:
99
- cfg = default_config()
100
- else:
101
- cfg = CN(new_allowed=True)
102
- cfg.merge_from_file(config_file)
103
-
104
- if update_cachedir:
105
- def update_path(path: str) -> str:
106
- if os.path.isabs(path):
107
- return path
108
- return os.path.join(CACHE_DIR_PRETRAINED, path)
109
-
110
- cfg.MANO.MODEL_PATH = update_path(cfg.MANO.MODEL_PATH)
111
- cfg.MANO.MEAN_PARAMS = update_path(cfg.MANO.MEAN_PARAMS)
112
-
113
- cfg.freeze()
114
- return cfg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/wilor/configs/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (6.02 kB)
 
WiLoR/wilor/datasets/utils.py DELETED
@@ -1,994 +0,0 @@
1
- """
2
- Parts of the code are taken or adapted from
3
- https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py
4
- """
5
- import torch
6
- import numpy as np
7
- from skimage.transform import rotate, resize
8
- from skimage.filters import gaussian
9
- import random
10
- import cv2
11
- from typing import List, Dict, Tuple
12
- from yacs.config import CfgNode
13
-
14
- def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
15
- """Increase the size of the bounding box to match the target shape."""
16
- if target_aspect_ratio is None:
17
- return input_shape
18
-
19
- try:
20
- w , h = input_shape
21
- except (ValueError, TypeError):
22
- return input_shape
23
-
24
- w_t, h_t = target_aspect_ratio
25
- if h / w < h_t / w_t:
26
- h_new = max(w * h_t / w_t, h)
27
- w_new = w
28
- else:
29
- h_new = h
30
- w_new = max(h * w_t / h_t, w)
31
- if h_new < h or w_new < w:
32
- breakpoint()
33
- return np.array([w_new, h_new])
34
-
35
- def do_augmentation(aug_config: CfgNode) -> Tuple:
36
- """
37
- Compute random augmentation parameters.
38
- Args:
39
- aug_config (CfgNode): Config containing augmentation parameters.
40
- Returns:
41
- scale (float): Box rescaling factor.
42
- rot (float): Random image rotation.
43
- do_flip (bool): Whether to flip image or not.
44
- do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT).
45
- color_scale (List): Color rescaling factor
46
- tx (float): Random translation along the x axis.
47
- ty (float): Random translation along the y axis.
48
- """
49
-
50
- tx = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
51
- ty = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
52
- scale = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.SCALE_FACTOR + 1.0
53
- rot = np.clip(np.random.randn(), -2.0,
54
- 2.0) * aug_config.ROT_FACTOR if random.random() <= aug_config.ROT_AUG_RATE else 0
55
- do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE
56
- do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE
57
- extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0)
58
- # extreme_crop_lvl = 0
59
- c_up = 1.0 + aug_config.COLOR_SCALE
60
- c_low = 1.0 - aug_config.COLOR_SCALE
61
- color_scale = [random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)]
62
- return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty
63
-
64
- def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array:
65
- """
66
- Rotate a 2D point on the x-y plane.
67
- Args:
68
- pt_2d (np.array): Input 2D point with shape (2,).
69
- rot_rad (float): Rotation angle
70
- Returns:
71
- np.array: Rotated 2D point.
72
- """
73
- x = pt_2d[0]
74
- y = pt_2d[1]
75
- sn, cs = np.sin(rot_rad), np.cos(rot_rad)
76
- xx = x * cs - y * sn
77
- yy = x * sn + y * cs
78
- return np.array([xx, yy], dtype=np.float32)
79
-
80
-
81
- def gen_trans_from_patch_cv(c_x: float, c_y: float,
82
- src_width: float, src_height: float,
83
- dst_width: float, dst_height: float,
84
- scale: float, rot: float) -> np.array:
85
- """
86
- Create transformation matrix for the bounding box crop.
87
- Args:
88
- c_x (float): Bounding box center x coordinate in the original image.
89
- c_y (float): Bounding box center y coordinate in the original image.
90
- src_width (float): Bounding box width.
91
- src_height (float): Bounding box height.
92
- dst_width (float): Output box width.
93
- dst_height (float): Output box height.
94
- scale (float): Rescaling factor for the bounding box (augmentation).
95
- rot (float): Random rotation applied to the box.
96
- Returns:
97
- trans (np.array): Target geometric transformation.
98
- """
99
- # augment size with scale
100
- src_w = src_width * scale
101
- src_h = src_height * scale
102
- src_center = np.zeros(2)
103
- src_center[0] = c_x
104
- src_center[1] = c_y
105
- # augment rotation
106
- rot_rad = np.pi * rot / 180
107
- src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
108
- src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
109
-
110
- dst_w = dst_width
111
- dst_h = dst_height
112
- dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
113
- dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
114
- dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
115
-
116
- src = np.zeros((3, 2), dtype=np.float32)
117
- src[0, :] = src_center
118
- src[1, :] = src_center + src_downdir
119
- src[2, :] = src_center + src_rightdir
120
-
121
- dst = np.zeros((3, 2), dtype=np.float32)
122
- dst[0, :] = dst_center
123
- dst[1, :] = dst_center + dst_downdir
124
- dst[2, :] = dst_center + dst_rightdir
125
-
126
- trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
127
-
128
- return trans
129
-
130
-
131
- def trans_point2d(pt_2d: np.array, trans: np.array):
132
- """
133
- Transform a 2D point using translation matrix trans.
134
- Args:
135
- pt_2d (np.array): Input 2D point with shape (2,).
136
- trans (np.array): Transformation matrix.
137
- Returns:
138
- np.array: Transformed 2D point.
139
- """
140
- src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T
141
- dst_pt = np.dot(trans, src_pt)
142
- return dst_pt[0:2]
143
-
144
- def get_transform(center, scale, res, rot=0):
145
- """Generate transformation matrix."""
146
- """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
147
- h = 200 * scale
148
- t = np.zeros((3, 3))
149
- t[0, 0] = float(res[1]) / h
150
- t[1, 1] = float(res[0]) / h
151
- t[0, 2] = res[1] * (-float(center[0]) / h + .5)
152
- t[1, 2] = res[0] * (-float(center[1]) / h + .5)
153
- t[2, 2] = 1
154
- if not rot == 0:
155
- rot = -rot # To match direction of rotation from cropping
156
- rot_mat = np.zeros((3, 3))
157
- rot_rad = rot * np.pi / 180
158
- sn, cs = np.sin(rot_rad), np.cos(rot_rad)
159
- rot_mat[0, :2] = [cs, -sn]
160
- rot_mat[1, :2] = [sn, cs]
161
- rot_mat[2, 2] = 1
162
- # Need to rotate around center
163
- t_mat = np.eye(3)
164
- t_mat[0, 2] = -res[1] / 2
165
- t_mat[1, 2] = -res[0] / 2
166
- t_inv = t_mat.copy()
167
- t_inv[:2, 2] *= -1
168
- t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
169
- return t
170
-
171
-
172
- def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
173
- """Transform pixel location to different reference."""
174
- """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
175
- t = get_transform(center, scale, res, rot=rot)
176
- if invert:
177
- t = np.linalg.inv(t)
178
- new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
179
- new_pt = np.dot(t, new_pt)
180
- if as_int:
181
- new_pt = new_pt.astype(int)
182
- return new_pt[:2] + 1
183
-
184
- def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0):
185
- c_x = (ul[0] + br[0])/2
186
- c_y = (ul[1] + br[1])/2
187
- bb_width = patch_width = br[0] - ul[0]
188
- bb_height = patch_height = br[1] - ul[1]
189
- trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, 1.0, 0)
190
- img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
191
- flags=cv2.INTER_LINEAR,
192
- borderMode=border_mode,
193
- borderValue=border_value
194
- )
195
-
196
- # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
197
- if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
198
- img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)),
199
- flags=cv2.INTER_LINEAR,
200
- borderMode=cv2.BORDER_CONSTANT,
201
- )
202
-
203
- return img_patch
204
-
205
- def generate_image_patch_skimage(img: np.array, c_x: float, c_y: float,
206
- bb_width: float, bb_height: float,
207
- patch_width: float, patch_height: float,
208
- do_flip: bool, scale: float, rot: float,
209
- border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
210
- """
211
- Crop image according to the supplied bounding box.
212
- Args:
213
- img (np.array): Input image of shape (H, W, 3)
214
- c_x (float): Bounding box center x coordinate in the original image.
215
- c_y (float): Bounding box center y coordinate in the original image.
216
- bb_width (float): Bounding box width.
217
- bb_height (float): Bounding box height.
218
- patch_width (float): Output box width.
219
- patch_height (float): Output box height.
220
- do_flip (bool): Whether to flip image or not.
221
- scale (float): Rescaling factor for the bounding box (augmentation).
222
- rot (float): Random rotation applied to the box.
223
- Returns:
224
- img_patch (np.array): Cropped image patch of shape (patch_height, patch_height, 3)
225
- trans (np.array): Transformation matrix.
226
- """
227
-
228
- img_height, img_width, img_channels = img.shape
229
- if do_flip:
230
- img = img[:, ::-1, :]
231
- c_x = img_width - c_x - 1
232
-
233
- trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
234
-
235
- #img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), flags=cv2.INTER_LINEAR)
236
-
237
- # skimage
238
- center = np.zeros(2)
239
- center[0] = c_x
240
- center[1] = c_y
241
- res = np.zeros(2)
242
- res[0] = patch_width
243
- res[1] = patch_height
244
- # assumes bb_width = bb_height
245
- # assumes patch_width = patch_height
246
- assert bb_width == bb_height, f'{bb_width=} != {bb_height=}'
247
- assert patch_width == patch_height, f'{patch_width=} != {patch_height=}'
248
- scale1 = scale*bb_width/200.
249
-
250
- # Upper left point
251
- ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1
252
- # Bottom right point
253
- br = np.array(transform([res[0] + 1,
254
- res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1
255
-
256
- # Padding so that when rotated proper amount of context is included
257
- try:
258
- pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1
259
- except:
260
- breakpoint()
261
- if not rot == 0:
262
- ul -= pad
263
- br += pad
264
-
265
-
266
- if False:
267
- # Old way of cropping image
268
- ul_int = ul.astype(int)
269
- br_int = br.astype(int)
270
- new_shape = [br_int[1] - ul_int[1], br_int[0] - ul_int[0]]
271
- if len(img.shape) > 2:
272
- new_shape += [img.shape[2]]
273
- new_img = np.zeros(new_shape)
274
-
275
- # Range to fill new array
276
- new_x = max(0, -ul_int[0]), min(br_int[0], len(img[0])) - ul_int[0]
277
- new_y = max(0, -ul_int[1]), min(br_int[1], len(img)) - ul_int[1]
278
- # Range to sample from original image
279
- old_x = max(0, ul_int[0]), min(len(img[0]), br_int[0])
280
- old_y = max(0, ul_int[1]), min(len(img), br_int[1])
281
- new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
282
- old_x[0]:old_x[1]]
283
-
284
- # New way of cropping image
285
- new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32)
286
-
287
- # print(f'{new_img.shape=}')
288
- # print(f'{new_img1.shape=}')
289
- # print(f'{np.allclose(new_img, new_img1)=}')
290
- # print(f'{img.dtype=}')
291
-
292
-
293
- if not rot == 0:
294
- # Remove padding
295
-
296
- new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot)
297
- new_img = new_img[pad:-pad, pad:-pad]
298
-
299
- if new_img.shape[0] < 1 or new_img.shape[1] < 1:
300
- print(f'{img.shape=}')
301
- print(f'{new_img.shape=}')
302
- print(f'{ul=}')
303
- print(f'{br=}')
304
- print(f'{pad=}')
305
- print(f'{rot=}')
306
-
307
- breakpoint()
308
-
309
- # resize image
310
- new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res)
311
-
312
- new_img = np.clip(new_img, 0, 255).astype(np.uint8)
313
-
314
- return new_img, trans
315
-
316
-
317
- def generate_image_patch_cv2(img: np.array, c_x: float, c_y: float,
318
- bb_width: float, bb_height: float,
319
- patch_width: float, patch_height: float,
320
- do_flip: bool, scale: float, rot: float,
321
- border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
322
- """
323
- Crop the input image and return the crop and the corresponding transformation matrix.
324
- Args:
325
- img (np.array): Input image of shape (H, W, 3)
326
- c_x (float): Bounding box center x coordinate in the original image.
327
- c_y (float): Bounding box center y coordinate in the original image.
328
- bb_width (float): Bounding box width.
329
- bb_height (float): Bounding box height.
330
- patch_width (float): Output box width.
331
- patch_height (float): Output box height.
332
- do_flip (bool): Whether to flip image or not.
333
- scale (float): Rescaling factor for the bounding box (augmentation).
334
- rot (float): Random rotation applied to the box.
335
- Returns:
336
- img_patch (np.array): Cropped image patch of shape (patch_height, patch_height, 3)
337
- trans (np.array): Transformation matrix.
338
- """
339
-
340
- img_height, img_width, img_channels = img.shape
341
- if do_flip:
342
- img = img[:, ::-1, :]
343
- c_x = img_width - c_x - 1
344
-
345
-
346
- trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
347
-
348
- img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
349
- flags=cv2.INTER_LINEAR,
350
- borderMode=border_mode,
351
- borderValue=border_value,
352
- )
353
- # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
354
- if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
355
- img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)),
356
- flags=cv2.INTER_LINEAR,
357
- borderMode=cv2.BORDER_CONSTANT,
358
- )
359
-
360
- return img_patch, trans
361
-
362
-
363
- def convert_cvimg_to_tensor(cvimg: np.array):
364
- """
365
- Convert image from HWC to CHW format.
366
- Args:
367
- cvimg (np.array): Image of shape (H, W, 3) as loaded by OpenCV.
368
- Returns:
369
- np.array: Output image of shape (3, H, W).
370
- """
371
- # from h,w,c(OpenCV) to c,h,w
372
- img = cvimg.copy()
373
- img = np.transpose(img, (2, 0, 1))
374
- # from int to float
375
- img = img.astype(np.float32)
376
- return img
377
-
378
- def fliplr_params(mano_params: Dict, has_mano_params: Dict) -> Tuple[Dict, Dict]:
379
- """
380
- Flip MANO parameters when flipping the image.
381
- Args:
382
- mano_params (Dict): MANO parameter annotations.
383
- has_mano_params (Dict): Whether MANO annotations are valid.
384
- Returns:
385
- Dict, Dict: Flipped MANO parameters and valid flags.
386
- """
387
- global_orient = mano_params['global_orient'].copy()
388
- hand_pose = mano_params['hand_pose'].copy()
389
- betas = mano_params['betas'].copy()
390
- has_global_orient = has_mano_params['global_orient'].copy()
391
- has_hand_pose = has_mano_params['hand_pose'].copy()
392
- has_betas = has_mano_params['betas'].copy()
393
-
394
- global_orient[1::3] *= -1
395
- global_orient[2::3] *= -1
396
- hand_pose[1::3] *= -1
397
- hand_pose[2::3] *= -1
398
-
399
- mano_params = {'global_orient': global_orient.astype(np.float32),
400
- 'hand_pose': hand_pose.astype(np.float32),
401
- 'betas': betas.astype(np.float32)
402
- }
403
-
404
- has_mano_params = {'global_orient': has_global_orient,
405
- 'hand_pose': has_hand_pose,
406
- 'betas': has_betas
407
- }
408
-
409
- return mano_params, has_mano_params
410
-
411
-
412
- def fliplr_keypoints(joints: np.array, width: float, flip_permutation: List[int]) -> np.array:
413
- """
414
- Flip 2D or 3D keypoints.
415
- Args:
416
- joints (np.array): Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence.
417
- flip_permutation (List): Permutation to apply after flipping.
418
- Returns:
419
- np.array: Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively.
420
- """
421
- joints = joints.copy()
422
- # Flip horizontal
423
- joints[:, 0] = width - joints[:, 0] - 1
424
- joints = joints[flip_permutation, :]
425
-
426
- return joints
427
-
428
- def keypoint_3d_processing(keypoints_3d: np.array, flip_permutation: List[int], rot: float, do_flip: float) -> np.array:
429
- """
430
- Process 3D keypoints (rotation/flipping).
431
- Args:
432
- keypoints_3d (np.array): Input array of shape (N, 4) containing the 3D keypoints and confidence.
433
- flip_permutation (List): Permutation to apply after flipping.
434
- rot (float): Random rotation applied to the keypoints.
435
- do_flip (bool): Whether to flip keypoints or not.
436
- Returns:
437
- np.array: Transformed 3D keypoints with shape (N, 4).
438
- """
439
- if do_flip:
440
- keypoints_3d = fliplr_keypoints(keypoints_3d, 1, flip_permutation)
441
- # in-plane rotation
442
- rot_mat = np.eye(3)
443
- if not rot == 0:
444
- rot_rad = -rot * np.pi / 180
445
- sn,cs = np.sin(rot_rad), np.cos(rot_rad)
446
- rot_mat[0,:2] = [cs, -sn]
447
- rot_mat[1,:2] = [sn, cs]
448
- keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1])
449
- # flip the x coordinates
450
- keypoints_3d = keypoints_3d.astype('float32')
451
- return keypoints_3d
452
-
453
- def rot_aa(aa: np.array, rot: float) -> np.array:
454
- """
455
- Rotate axis angle parameters.
456
- Args:
457
- aa (np.array): Axis-angle vector of shape (3,).
458
- rot (np.array): Rotation angle in degrees.
459
- Returns:
460
- np.array: Rotated axis-angle vector.
461
- """
462
- # pose parameters
463
- R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
464
- [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
465
- [0, 0, 1]])
466
- # find the rotation of the hand in camera frame
467
- per_rdg, _ = cv2.Rodrigues(aa)
468
- # apply the global rotation to the global orientation
469
- resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
470
- aa = (resrot.T)[0]
471
- return aa.astype(np.float32)
472
-
473
- def mano_param_processing(mano_params: Dict, has_mano_params: Dict, rot: float, do_flip: bool) -> Tuple[Dict, Dict]:
474
- """
475
- Apply random augmentations to the MANO parameters.
476
- Args:
477
- mano_params (Dict): MANO parameter annotations.
478
- has_mano_params (Dict): Whether mano annotations are valid.
479
- rot (float): Random rotation applied to the keypoints.
480
- do_flip (bool): Whether to flip keypoints or not.
481
- Returns:
482
- Dict, Dict: Transformed MANO parameters and valid flags.
483
- """
484
- if do_flip:
485
- mano_params, has_mano_params = fliplr_params(mano_params, has_mano_params)
486
- mano_params['global_orient'] = rot_aa(mano_params['global_orient'], rot)
487
- return mano_params, has_mano_params
488
-
489
-
490
-
491
- def get_example(img_path: str|np.ndarray, center_x: float, center_y: float,
492
- width: float, height: float,
493
- keypoints_2d: np.array, keypoints_3d: np.array,
494
- mano_params: Dict, has_mano_params: Dict,
495
- flip_kp_permutation: List[int],
496
- patch_width: int, patch_height: int,
497
- mean: np.array, std: np.array,
498
- do_augment: bool, is_right: bool, augm_config: CfgNode,
499
- is_bgr: bool = True,
500
- use_skimage_antialias: bool = False,
501
- border_mode: int = cv2.BORDER_CONSTANT,
502
- return_trans: bool = False) -> Tuple:
503
- """
504
- Get an example from the dataset and (possibly) apply random augmentations.
505
- Args:
506
- img_path (str): Image filename
507
- center_x (float): Bounding box center x coordinate in the original image.
508
- center_y (float): Bounding box center y coordinate in the original image.
509
- width (float): Bounding box width.
510
- height (float): Bounding box height.
511
- keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
512
- keypoints_3d (np.array): Array with shape (N,4) containing the 3D keypoints.
513
- mano_params (Dict): MANO parameter annotations.
514
- has_mano_params (Dict): Whether MANO annotations are valid.
515
- flip_kp_permutation (List): Permutation to apply to the keypoints after flipping.
516
- patch_width (float): Output box width.
517
- patch_height (float): Output box height.
518
- mean (np.array): Array of shape (3,) containing the mean for normalizing the input image.
519
- std (np.array): Array of shape (3,) containing the std for normalizing the input image.
520
- do_augment (bool): Whether to apply data augmentation or not.
521
- aug_config (CfgNode): Config containing augmentation parameters.
522
- Returns:
523
- return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size
524
- img_patch (np.array): Cropped image patch of shape (3, patch_height, patch_height)
525
- keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints.
526
- keypoints_3d (np.array): Array with shape (N,4) containing the transformed 3D keypoints.
527
- mano_params (Dict): Transformed MANO parameters.
528
- has_mano_params (Dict): Valid flag for transformed MANO parameters.
529
- img_size (np.array): Image size of the original image.
530
- """
531
- if isinstance(img_path, str):
532
- # 1. load image
533
- cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
534
- if not isinstance(cvimg, np.ndarray):
535
- raise IOError("Fail to read %s" % img_path)
536
- elif isinstance(img_path, np.ndarray):
537
- cvimg = img_path
538
- else:
539
- raise TypeError('img_path must be either a string or a numpy array')
540
- img_height, img_width, img_channels = cvimg.shape
541
-
542
- img_size = np.array([img_height, img_width])
543
-
544
- # 2. get augmentation params
545
- if do_augment:
546
- scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
547
- else:
548
- scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0, 1.0, 1.0], 0., 0.
549
-
550
- # if it's a left hand, we flip
551
- if not is_right:
552
- do_flip = True
553
-
554
- if width < 1 or height < 1:
555
- breakpoint()
556
-
557
- if do_extreme_crop:
558
- if extreme_crop_lvl == 0:
559
- center_x1, center_y1, width1, height1 = extreme_cropping(center_x, center_y, width, height, keypoints_2d)
560
- elif extreme_crop_lvl == 1:
561
- center_x1, center_y1, width1, height1 = extreme_cropping_aggressive(center_x, center_y, width, height, keypoints_2d)
562
-
563
- THRESH = 4
564
- if width1 < THRESH or height1 < THRESH:
565
- # print(f'{do_extreme_crop=}')
566
- # print(f'width: {width}, height: {height}')
567
- # print(f'width1: {width1}, height1: {height1}')
568
- # print(f'center_x: {center_x}, center_y: {center_y}')
569
- # print(f'center_x1: {center_x1}, center_y1: {center_y1}')
570
- # print(f'keypoints_2d: {keypoints_2d}')
571
- # print(f'\n\n', flush=True)
572
- # breakpoint()
573
- pass
574
- # print(f'skip ==> width1: {width1}, height1: {height1}, width: {width}, height: {height}')
575
- else:
576
- center_x, center_y, width, height = center_x1, center_y1, width1, height1
577
-
578
- center_x += width * tx
579
- center_y += height * ty
580
-
581
- # Process 3D keypoints
582
- keypoints_3d = keypoint_3d_processing(keypoints_3d, flip_kp_permutation, rot, do_flip)
583
-
584
- # 3. generate image patch
585
- if use_skimage_antialias:
586
- # Blur image to avoid aliasing artifacts
587
- downsampling_factor = (patch_width / (width*scale))
588
- if downsampling_factor > 1.1:
589
- cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True, truncate=3.0)
590
-
591
- img_patch_cv, trans = generate_image_patch_cv2(cvimg,
592
- center_x, center_y,
593
- width, height,
594
- patch_width, patch_height,
595
- do_flip, scale, rot,
596
- border_mode=border_mode)
597
-
598
- # img_patch_cv, trans = generate_image_patch_skimage(cvimg,
599
- # center_x, center_y,
600
- # width, height,
601
- # patch_width, patch_height,
602
- # do_flip, scale, rot,
603
- # border_mode=border_mode)
604
-
605
- image = img_patch_cv.copy()
606
- if is_bgr:
607
- image = image[:, :, ::-1]
608
- img_patch_cv = image.copy()
609
- img_patch = convert_cvimg_to_tensor(image)
610
-
611
-
612
- mano_params, has_mano_params = mano_param_processing(mano_params, has_mano_params, rot, do_flip)
613
-
614
- # apply normalization
615
- for n_c in range(min(img_channels, 3)):
616
- img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
617
- if mean is not None and std is not None:
618
- img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]
619
- if do_flip:
620
- keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, flip_kp_permutation)
621
-
622
-
623
- for n_jt in range(len(keypoints_2d)):
624
- keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
625
- keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5
626
-
627
- if not return_trans:
628
- return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size
629
- else:
630
- return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size, trans
631
-
632
- def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
633
- """
634
- Extreme cropping: Crop the box up to the hip locations.
635
- Args:
636
- center_x (float): x coordinate of the bounding box center.
637
- center_y (float): y coordinate of the bounding box center.
638
- width (float): Bounding box width.
639
- height (float): Bounding box height.
640
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
641
- Returns:
642
- center_x (float): x coordinate of the new bounding box center.
643
- center_y (float): y coordinate of the new bounding box center.
644
- width (float): New bounding box width.
645
- height (float): New bounding box height.
646
- """
647
- keypoints_2d = keypoints_2d.copy()
648
- lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25+0, 25+1, 25+4, 25+5]
649
- keypoints_2d[lower_body_keypoints, :] = 0
650
- if keypoints_2d[:, -1].sum() > 1:
651
- center, scale = get_bbox(keypoints_2d)
652
- center_x = center[0]
653
- center_y = center[1]
654
- width = 1.1 * scale[0]
655
- height = 1.1 * scale[1]
656
- return center_x, center_y, width, height
657
-
658
-
659
- def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
660
- """
661
- Extreme cropping: Crop the box up to the shoulder locations.
662
- Args:
663
- center_x (float): x coordinate of the bounding box center.
664
- center_y (float): y coordinate of the bounding box center.
665
- width (float): Bounding box width.
666
- height (float): Bounding box height.
667
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
668
- Returns:
669
- center_x (float): x coordinate of the new bounding box center.
670
- center_y (float): y coordinate of the new bounding box center.
671
- width (float): New bounding box width.
672
- height (float): New bounding box height.
673
- """
674
- keypoints_2d = keypoints_2d.copy()
675
- lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16]]
676
- keypoints_2d[lower_body_keypoints, :] = 0
677
- center, scale = get_bbox(keypoints_2d)
678
- if keypoints_2d[:, -1].sum() > 1:
679
- center, scale = get_bbox(keypoints_2d)
680
- center_x = center[0]
681
- center_y = center[1]
682
- width = 1.2 * scale[0]
683
- height = 1.2 * scale[1]
684
- return center_x, center_y, width, height
685
-
686
- def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
687
- """
688
- Extreme cropping: Crop the box and keep on only the head.
689
- Args:
690
- center_x (float): x coordinate of the bounding box center.
691
- center_y (float): y coordinate of the bounding box center.
692
- width (float): Bounding box width.
693
- height (float): Bounding box height.
694
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
695
- Returns:
696
- center_x (float): x coordinate of the new bounding box center.
697
- center_y (float): y coordinate of the new bounding box center.
698
- width (float): New bounding box width.
699
- height (float): New bounding box height.
700
- """
701
- keypoints_2d = keypoints_2d.copy()
702
- lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16]]
703
- keypoints_2d[lower_body_keypoints, :] = 0
704
- if keypoints_2d[:, -1].sum() > 1:
705
- center, scale = get_bbox(keypoints_2d)
706
- center_x = center[0]
707
- center_y = center[1]
708
- width = 1.3 * scale[0]
709
- height = 1.3 * scale[1]
710
- return center_x, center_y, width, height
711
-
712
- def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
713
- """
714
- Extreme cropping: Crop the box and keep on only the torso.
715
- Args:
716
- center_x (float): x coordinate of the bounding box center.
717
- center_y (float): y coordinate of the bounding box center.
718
- width (float): Bounding box width.
719
- height (float): Bounding box height.
720
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
721
- Returns:
722
- center_x (float): x coordinate of the new bounding box center.
723
- center_y (float): y coordinate of the new bounding box center.
724
- width (float): New bounding box width.
725
- height (float): New bounding box height.
726
- """
727
- keypoints_2d = keypoints_2d.copy()
728
- nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18]]
729
- keypoints_2d[nontorso_body_keypoints, :] = 0
730
- if keypoints_2d[:, -1].sum() > 1:
731
- center, scale = get_bbox(keypoints_2d)
732
- center_x = center[0]
733
- center_y = center[1]
734
- width = 1.1 * scale[0]
735
- height = 1.1 * scale[1]
736
- return center_x, center_y, width, height
737
-
738
- def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
739
- """
740
- Extreme cropping: Crop the box and keep on only the right arm.
741
- Args:
742
- center_x (float): x coordinate of the bounding box center.
743
- center_y (float): y coordinate of the bounding box center.
744
- width (float): Bounding box width.
745
- height (float): Bounding box height.
746
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
747
- Returns:
748
- center_x (float): x coordinate of the new bounding box center.
749
- center_y (float): y coordinate of the new bounding box center.
750
- width (float): New bounding box width.
751
- height (float): New bounding box height.
752
- """
753
- keypoints_2d = keypoints_2d.copy()
754
- nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
755
- keypoints_2d[nonrightarm_body_keypoints, :] = 0
756
- if keypoints_2d[:, -1].sum() > 1:
757
- center, scale = get_bbox(keypoints_2d)
758
- center_x = center[0]
759
- center_y = center[1]
760
- width = 1.1 * scale[0]
761
- height = 1.1 * scale[1]
762
- return center_x, center_y, width, height
763
-
764
- def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
765
- """
766
- Extreme cropping: Crop the box and keep on only the left arm.
767
- Args:
768
- center_x (float): x coordinate of the bounding box center.
769
- center_y (float): y coordinate of the bounding box center.
770
- width (float): Bounding box width.
771
- height (float): Bounding box height.
772
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
773
- Returns:
774
- center_x (float): x coordinate of the new bounding box center.
775
- center_y (float): y coordinate of the new bounding box center.
776
- width (float): New bounding box width.
777
- height (float): New bounding box height.
778
- """
779
- keypoints_2d = keypoints_2d.copy()
780
- nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]]
781
- keypoints_2d[nonleftarm_body_keypoints, :] = 0
782
- if keypoints_2d[:, -1].sum() > 1:
783
- center, scale = get_bbox(keypoints_2d)
784
- center_x = center[0]
785
- center_y = center[1]
786
- width = 1.1 * scale[0]
787
- height = 1.1 * scale[1]
788
- return center_x, center_y, width, height
789
-
790
- def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
791
- """
792
- Extreme cropping: Crop the box and keep on only the legs.
793
- Args:
794
- center_x (float): x coordinate of the bounding box center.
795
- center_y (float): y coordinate of the bounding box center.
796
- width (float): Bounding box width.
797
- height (float): Bounding box height.
798
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
799
- Returns:
800
- center_x (float): x coordinate of the new bounding box center.
801
- center_y (float): y coordinate of the new bounding box center.
802
- width (float): New bounding box width.
803
- height (float): New bounding box height.
804
- """
805
- keypoints_2d = keypoints_2d.copy()
806
- nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]]
807
- keypoints_2d[nonlegs_body_keypoints, :] = 0
808
- if keypoints_2d[:, -1].sum() > 1:
809
- center, scale = get_bbox(keypoints_2d)
810
- center_x = center[0]
811
- center_y = center[1]
812
- width = 1.1 * scale[0]
813
- height = 1.1 * scale[1]
814
- return center_x, center_y, width, height
815
-
816
- def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
817
- """
818
- Extreme cropping: Crop the box and keep on only the right leg.
819
- Args:
820
- center_x (float): x coordinate of the bounding box center.
821
- center_y (float): y coordinate of the bounding box center.
822
- width (float): Bounding box width.
823
- height (float): Bounding box height.
824
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
825
- Returns:
826
- center_x (float): x coordinate of the new bounding box center.
827
- center_y (float): y coordinate of the new bounding box center.
828
- width (float): New bounding box width.
829
- height (float): New bounding box height.
830
- """
831
- keypoints_2d = keypoints_2d.copy()
832
- nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + [25 + i for i in [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
833
- keypoints_2d[nonrightleg_body_keypoints, :] = 0
834
- if keypoints_2d[:, -1].sum() > 1:
835
- center, scale = get_bbox(keypoints_2d)
836
- center_x = center[0]
837
- center_y = center[1]
838
- width = 1.1 * scale[0]
839
- height = 1.1 * scale[1]
840
- return center_x, center_y, width, height
841
-
842
- def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
843
- """
844
- Extreme cropping: Crop the box and keep on only the left leg.
845
- Args:
846
- center_x (float): x coordinate of the bounding box center.
847
- center_y (float): y coordinate of the bounding box center.
848
- width (float): Bounding box width.
849
- height (float): Bounding box height.
850
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
851
- Returns:
852
- center_x (float): x coordinate of the new bounding box center.
853
- center_y (float): y coordinate of the new bounding box center.
854
- width (float): New bounding box width.
855
- height (float): New bounding box height.
856
- """
857
- keypoints_2d = keypoints_2d.copy()
858
- nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
859
- keypoints_2d[nonleftleg_body_keypoints, :] = 0
860
- if keypoints_2d[:, -1].sum() > 1:
861
- center, scale = get_bbox(keypoints_2d)
862
- center_x = center[0]
863
- center_y = center[1]
864
- width = 1.1 * scale[0]
865
- height = 1.1 * scale[1]
866
- return center_x, center_y, width, height
867
-
868
- def full_body(keypoints_2d: np.array) -> bool:
869
- """
870
- Check if all main body joints are visible.
871
- Args:
872
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
873
- Returns:
874
- bool: True if all main body joints are visible.
875
- """
876
-
877
- body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
878
- body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]]
879
- return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(body_keypoints)
880
-
881
- def upper_body(keypoints_2d: np.array):
882
- """
883
- Check if all upper body joints are visible.
884
- Args:
885
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
886
- Returns:
887
- bool: True if all main body joints are visible.
888
- """
889
- lower_body_keypoints_openpose = [10, 11, 13, 14]
890
- lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]]
891
- upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18]
892
- upper_body_keypoints = [25+8, 25+9, 25+12, 25+13, 25+17, 25+18]
893
- return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0)\
894
- and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2)
895
-
896
- def get_bbox(keypoints_2d: np.array, rescale: float = 1.2) -> Tuple:
897
- """
898
- Get center and scale for bounding box from openpose detections.
899
- Args:
900
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
901
- rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
902
- Returns:
903
- center (np.array): Array of shape (2,) containing the new bounding box center.
904
- scale (float): New bounding box scale.
905
- """
906
- valid = keypoints_2d[:,-1] > 0
907
- valid_keypoints = keypoints_2d[valid][:,:-1]
908
- center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0))
909
- bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0))
910
- # adjust bounding box tightness
911
- scale = bbox_size
912
- scale *= rescale
913
- return center, scale
914
-
915
- def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
916
- """
917
- Perform extreme cropping
918
- Args:
919
- center_x (float): x coordinate of bounding box center.
920
- center_y (float): y coordinate of bounding box center.
921
- width (float): bounding box width.
922
- height (float): bounding box height.
923
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
924
- rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
925
- Returns:
926
- center_x (float): x coordinate of bounding box center.
927
- center_y (float): y coordinate of bounding box center.
928
- width (float): bounding box width.
929
- height (float): bounding box height.
930
- """
931
- p = torch.rand(1).item()
932
- if full_body(keypoints_2d):
933
- if p < 0.7:
934
- center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
935
- elif p < 0.9:
936
- center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
937
- else:
938
- center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
939
- elif upper_body(keypoints_2d):
940
- if p < 0.9:
941
- center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
942
- else:
943
- center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
944
-
945
- return center_x, center_y, max(width, height), max(width, height)
946
-
947
- def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
948
- """
949
- Perform aggressive extreme cropping
950
- Args:
951
- center_x (float): x coordinate of bounding box center.
952
- center_y (float): y coordinate of bounding box center.
953
- width (float): bounding box width.
954
- height (float): bounding box height.
955
- keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
956
- rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
957
- Returns:
958
- center_x (float): x coordinate of bounding box center.
959
- center_y (float): y coordinate of bounding box center.
960
- width (float): bounding box width.
961
- height (float): bounding box height.
962
- """
963
- p = torch.rand(1).item()
964
- if full_body(keypoints_2d):
965
- if p < 0.2:
966
- center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
967
- elif p < 0.3:
968
- center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
969
- elif p < 0.4:
970
- center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
971
- elif p < 0.5:
972
- center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
973
- elif p < 0.6:
974
- center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
975
- elif p < 0.7:
976
- center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
977
- elif p < 0.8:
978
- center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d)
979
- elif p < 0.9:
980
- center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d)
981
- else:
982
- center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d)
983
- elif upper_body(keypoints_2d):
984
- if p < 0.2:
985
- center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
986
- elif p < 0.4:
987
- center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
988
- elif p < 0.6:
989
- center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
990
- elif p < 0.8:
991
- center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
992
- else:
993
- center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
994
- return center_x, center_y, max(width, height), max(width, height)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/wilor/datasets/vitdet_dataset.py DELETED
@@ -1,95 +0,0 @@
1
- from typing import Dict
2
-
3
- import cv2
4
- import numpy as np
5
- from skimage.filters import gaussian
6
- from yacs.config import CfgNode
7
- import torch
8
-
9
- from .utils import (convert_cvimg_to_tensor,
10
- expand_to_aspect_ratio,
11
- generate_image_patch_cv2)
12
-
13
- DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
14
- DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
15
-
16
- class ViTDetDataset(torch.utils.data.Dataset):
17
-
18
- def __init__(self,
19
- cfg: CfgNode,
20
- img_cv2: np.array,
21
- boxes: np.array,
22
- right: np.array,
23
- rescale_factor=2.5,
24
- train: bool = False,
25
- **kwargs):
26
- super().__init__()
27
- self.cfg = cfg
28
- self.img_cv2 = img_cv2
29
- # self.boxes = boxes
30
-
31
- assert train == False, "ViTDetDataset is only for inference"
32
- self.train = train
33
- self.img_size = cfg.MODEL.IMAGE_SIZE
34
- self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
35
- self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)
36
-
37
- # Preprocess annotations
38
- boxes = boxes.astype(np.float32)
39
- self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0
40
- self.scale = rescale_factor * (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0
41
- self.personid = np.arange(len(boxes), dtype=np.int32)
42
- self.right = right.astype(np.float32)
43
-
44
- def __len__(self) -> int:
45
- return len(self.personid)
46
-
47
- def __getitem__(self, idx: int) -> Dict[str, np.array]:
48
-
49
- center = self.center[idx].copy()
50
- center_x = center[0]
51
- center_y = center[1]
52
-
53
- scale = self.scale[idx]
54
- BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
55
- bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
56
-
57
- patch_width = patch_height = self.img_size
58
-
59
- right = self.right[idx].copy()
60
- flip = right == 0
61
-
62
- # 3. generate image patch
63
- # if use_skimage_antialias:
64
- cvimg = self.img_cv2.copy()
65
- if True:
66
- # Blur image to avoid aliasing artifacts
67
- downsampling_factor = ((bbox_size*1.0) / patch_width)
68
- #print(f'{downsampling_factor=}')
69
- downsampling_factor = downsampling_factor / 2.0
70
- if downsampling_factor > 1.1:
71
- cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True)
72
-
73
-
74
- img_patch_cv, trans = generate_image_patch_cv2(cvimg,
75
- center_x, center_y,
76
- bbox_size, bbox_size,
77
- patch_width, patch_height,
78
- flip, 1.0, 0,
79
- border_mode=cv2.BORDER_CONSTANT)
80
- img_patch_cv = img_patch_cv[:, :, ::-1]
81
- img_patch = convert_cvimg_to_tensor(img_patch_cv)
82
-
83
- # apply normalization
84
- for n_c in range(min(self.img_cv2.shape[2], 3)):
85
- img_patch[n_c, :, :] = (img_patch[n_c, :, :] - self.mean[n_c]) / self.std[n_c]
86
-
87
- item = {
88
- 'img': img_patch,
89
- 'personid': int(self.personid[idx]),
90
- }
91
- item['box_center'] = self.center[idx].copy()
92
- item['box_size'] = bbox_size
93
- item['img_size'] = 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]])
94
- item['right'] = self.right[idx].copy()
95
- return item
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/wilor/models/__init__.py DELETED
@@ -1,36 +0,0 @@
1
- from .mano_wrapper import MANO
2
- from .wilor import WiLoR
3
-
4
- from .discriminator import Discriminator
5
-
6
- def load_wilor(checkpoint_path, cfg_path):
7
- from pathlib import Path
8
- from wilor.configs import get_config
9
- print('Loading ', checkpoint_path)
10
- model_cfg = get_config(cfg_path, update_cachedir=True)
11
-
12
- # Override some config values, to crop bbox correctly
13
- if ('vit' in model_cfg.MODEL.BACKBONE.TYPE) and ('BBOX_SHAPE' not in model_cfg.MODEL):
14
-
15
- model_cfg.defrost()
16
- assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
17
- model_cfg.MODEL.BBOX_SHAPE = [192,256]
18
- model_cfg.freeze()
19
-
20
- # Update config to be compatible with demo
21
- if ('PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE):
22
- model_cfg.defrost()
23
- model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS')
24
- model_cfg.freeze()
25
-
26
- # Update config to be compatible with demo
27
-
28
- if ('DATA_DIR' in model_cfg.MANO):
29
- model_cfg.defrost()
30
- model_cfg.MANO.DATA_DIR = './mano_data/'
31
- model_cfg.MANO.MODEL_PATH = './mano_data/'
32
- model_cfg.MANO.MEAN_PARAMS = './mano_data/mano_mean_params.npz'
33
- model_cfg.freeze()
34
-
35
- model = WiLoR.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg)
36
- return model, model_cfg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/wilor/models/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (2.02 kB)
 
WiLoR/wilor/models/__pycache__/discriminator.cpython-311.pyc DELETED
Binary file (6.52 kB)
 
WiLoR/wilor/models/__pycache__/losses.cpython-311.pyc DELETED
Binary file (6.87 kB)
 
WiLoR/wilor/models/__pycache__/mano_wrapper.cpython-311.pyc DELETED
Binary file (3.43 kB)
 
WiLoR/wilor/models/__pycache__/wilor.cpython-311.pyc DELETED
Binary file (24.1 kB)
 
WiLoR/wilor/models/backbones/__init__.py DELETED
@@ -1,17 +0,0 @@
1
- from .vit import vit
2
-
3
- def create_backbone(cfg):
4
- if cfg.MODEL.BACKBONE.TYPE == 'vit':
5
- return vit(cfg)
6
- elif cfg.MODEL.BACKBONE.TYPE == 'fast_vit':
7
- import torch
8
- import sys
9
- from timm.models import create_model
10
- #from models.modules.mobileone import reparameterize_model
11
- fast_vit = create_model("fastvit_ma36", drop_path_rate=0.2)
12
- checkpoint = torch.load('./pretrained_models/fastvit_ma36.pt')
13
- fast_vit.load_state_dict(checkpoint['state_dict'])
14
- return fast_vit
15
-
16
- else:
17
- raise NotImplementedError('Backbone type is not implemented')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/wilor/models/backbones/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (737 Bytes)
 
WiLoR/wilor/models/backbones/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (1.09 kB)
 
WiLoR/wilor/models/backbones/__pycache__/vit.cpython-310.pyc DELETED
Binary file (13.2 kB)
 
WiLoR/wilor/models/backbones/__pycache__/vit.cpython-311.pyc DELETED
Binary file (26.9 kB)
 
WiLoR/wilor/models/backbones/vit.py DELETED
@@ -1,410 +0,0 @@
1
- # Copyright (c) OpenMMLab. All rights reserved.
2
- import math
3
- import numpy as np
4
- import torch
5
- from functools import partial
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- import torch.utils.checkpoint as checkpoint
9
- from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
10
- from timm.models.layers import drop_path, to_2tuple, trunc_normal_
11
-
12
- def vit(cfg):
13
- return ViT(
14
- img_size=(256, 192),
15
- patch_size=16,
16
- embed_dim=1280,
17
- depth=32,
18
- num_heads=16,
19
- ratio=1,
20
- use_checkpoint=False,
21
- mlp_ratio=4,
22
- qkv_bias=True,
23
- drop_path_rate=0.55,
24
- cfg = cfg
25
- )
26
-
27
- def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
28
- """
29
- Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
30
- dimension for the original embeddings.
31
- Args:
32
- abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
33
- has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
34
- hw (Tuple): size of input image tokens.
35
-
36
- Returns:
37
- Absolute positional embeddings after processing with shape (1, H, W, C)
38
- """
39
- cls_token = None
40
- B, L, C = abs_pos.shape
41
- if has_cls_token:
42
- cls_token = abs_pos[:, 0:1]
43
- abs_pos = abs_pos[:, 1:]
44
-
45
- if ori_h != h or ori_w != w:
46
- new_abs_pos = F.interpolate(
47
- abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
48
- size=(h, w),
49
- mode="bicubic",
50
- align_corners=False,
51
- ).permute(0, 2, 3, 1).reshape(B, -1, C)
52
-
53
- else:
54
- new_abs_pos = abs_pos
55
-
56
- if cls_token is not None:
57
- new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
58
- return new_abs_pos
59
-
60
- class DropPath(nn.Module):
61
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
62
- """
63
- def __init__(self, drop_prob=None):
64
- super(DropPath, self).__init__()
65
- self.drop_prob = drop_prob
66
-
67
- def forward(self, x):
68
- return drop_path(x, self.drop_prob, self.training)
69
-
70
- def extra_repr(self):
71
- return 'p={}'.format(self.drop_prob)
72
-
73
- class Mlp(nn.Module):
74
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
75
- super().__init__()
76
- out_features = out_features or in_features
77
- hidden_features = hidden_features or in_features
78
- self.fc1 = nn.Linear(in_features, hidden_features)
79
- self.act = act_layer()
80
- self.fc2 = nn.Linear(hidden_features, out_features)
81
- self.drop = nn.Dropout(drop)
82
-
83
- def forward(self, x):
84
- x = self.fc1(x)
85
- x = self.act(x)
86
- x = self.fc2(x)
87
- x = self.drop(x)
88
- return x
89
-
90
- class Attention(nn.Module):
91
- def __init__(
92
- self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
93
- proj_drop=0., attn_head_dim=None,):
94
- super().__init__()
95
- self.num_heads = num_heads
96
- head_dim = dim // num_heads
97
- self.dim = dim
98
-
99
- if attn_head_dim is not None:
100
- head_dim = attn_head_dim
101
- all_head_dim = head_dim * self.num_heads
102
-
103
- self.scale = qk_scale or head_dim ** -0.5
104
-
105
- self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
106
-
107
- self.attn_drop = nn.Dropout(attn_drop)
108
- self.proj = nn.Linear(all_head_dim, dim)
109
- self.proj_drop = nn.Dropout(proj_drop)
110
-
111
- def forward(self, x):
112
- B, N, C = x.shape
113
- qkv = self.qkv(x)
114
- qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
115
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
116
-
117
- q = q * self.scale
118
- attn = (q @ k.transpose(-2, -1))
119
-
120
- attn = attn.softmax(dim=-1)
121
- attn = self.attn_drop(attn)
122
-
123
- x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
124
- x = self.proj(x)
125
- x = self.proj_drop(x)
126
-
127
- return x
128
-
129
- class Block(nn.Module):
130
-
131
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
132
- drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
133
- norm_layer=nn.LayerNorm, attn_head_dim=None
134
- ):
135
- super().__init__()
136
-
137
- self.norm1 = norm_layer(dim)
138
- self.attn = Attention(
139
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
140
- attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
141
- )
142
-
143
- # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
144
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
145
- self.norm2 = norm_layer(dim)
146
- mlp_hidden_dim = int(dim * mlp_ratio)
147
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
148
-
149
- def forward(self, x):
150
- x = x + self.drop_path(self.attn(self.norm1(x)))
151
- x = x + self.drop_path(self.mlp(self.norm2(x)))
152
- return x
153
-
154
-
155
- class PatchEmbed(nn.Module):
156
- """ Image to Patch Embedding
157
- """
158
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
159
- super().__init__()
160
- img_size = to_2tuple(img_size)
161
- patch_size = to_2tuple(patch_size)
162
- num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
163
- self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
164
- self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
165
- self.img_size = img_size
166
- self.patch_size = patch_size
167
- self.num_patches = num_patches
168
-
169
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
170
-
171
- def forward(self, x, **kwargs):
172
- B, C, H, W = x.shape
173
- x = self.proj(x)
174
- Hp, Wp = x.shape[2], x.shape[3]
175
-
176
- x = x.flatten(2).transpose(1, 2)
177
- return x, (Hp, Wp)
178
-
179
-
180
- class HybridEmbed(nn.Module):
181
- """ CNN Feature Map Embedding
182
- Extract feature map from CNN, flatten, project to embedding dim.
183
- """
184
- def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
185
- super().__init__()
186
- assert isinstance(backbone, nn.Module)
187
- img_size = to_2tuple(img_size)
188
- self.img_size = img_size
189
- self.backbone = backbone
190
- if feature_size is None:
191
- with torch.no_grad():
192
- training = backbone.training
193
- if training:
194
- backbone.eval()
195
- o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
196
- feature_size = o.shape[-2:]
197
- feature_dim = o.shape[1]
198
- backbone.train(training)
199
- else:
200
- feature_size = to_2tuple(feature_size)
201
- feature_dim = self.backbone.feature_info.channels()[-1]
202
- self.num_patches = feature_size[0] * feature_size[1]
203
- self.proj = nn.Linear(feature_dim, embed_dim)
204
-
205
- def forward(self, x):
206
- x = self.backbone(x)[-1]
207
- x = x.flatten(2).transpose(1, 2)
208
- x = self.proj(x)
209
- return x
210
-
211
-
212
- class ViT(nn.Module):
213
-
214
- def __init__(self,
215
- img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
216
- num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
217
- drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
218
- frozen_stages=-1, ratio=1, last_norm=True,
219
- patch_padding='pad', freeze_attn=False, freeze_ffn=False,cfg=None,
220
- ):
221
- # Protect mutable default arguments
222
- super(ViT, self).__init__()
223
- norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
224
- self.num_classes = num_classes
225
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
226
- self.frozen_stages = frozen_stages
227
- self.use_checkpoint = use_checkpoint
228
- self.patch_padding = patch_padding
229
- self.freeze_attn = freeze_attn
230
- self.freeze_ffn = freeze_ffn
231
- self.depth = depth
232
-
233
- if hybrid_backbone is not None:
234
- self.patch_embed = HybridEmbed(
235
- hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
236
- else:
237
- self.patch_embed = PatchEmbed(
238
- img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
239
- num_patches = self.patch_embed.num_patches
240
-
241
- ##########################################
242
- self.cfg = cfg
243
- self.joint_rep_type = cfg.MODEL.MANO_HEAD.get('JOINT_REP', '6d')
244
- self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
245
- npose = self.joint_rep_dim * (cfg.MANO.NUM_HAND_JOINTS + 1)
246
- self.npose = npose
247
- mean_params = np.load(cfg.MANO.MEAN_PARAMS)
248
- init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
249
- self.register_buffer('init_cam', init_cam)
250
- init_hand_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
251
- init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
252
- self.register_buffer('init_hand_pose', init_hand_pose)
253
- self.register_buffer('init_betas', init_betas)
254
-
255
- self.pose_emb = nn.Linear(self.joint_rep_dim , embed_dim)
256
- self.shape_emb = nn.Linear(10 , embed_dim)
257
- self.cam_emb = nn.Linear(3 , embed_dim)
258
-
259
- self.decpose = nn.Linear(self.num_features, 6)
260
- self.decshape = nn.Linear(self.num_features, 10)
261
- self.deccam = nn.Linear(self.num_features, 3)
262
- if cfg.MODEL.MANO_HEAD.get('INIT_DECODER_XAVIER', False):
263
- # True by default in MLP. False by default in Transformer
264
- nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
265
- nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
266
- nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
267
-
268
-
269
- ##########################################
270
-
271
- # since the pretraining model has class token
272
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
273
-
274
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
275
-
276
- self.blocks = nn.ModuleList([
277
- Block(
278
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
279
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
280
- )
281
- for i in range(depth)])
282
-
283
- self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
284
-
285
- if self.pos_embed is not None:
286
- trunc_normal_(self.pos_embed, std=.02)
287
-
288
- self._freeze_stages()
289
-
290
- def _freeze_stages(self):
291
- """Freeze parameters."""
292
- if self.frozen_stages >= 0:
293
- self.patch_embed.eval()
294
- for param in self.patch_embed.parameters():
295
- param.requires_grad = False
296
-
297
- for i in range(1, self.frozen_stages + 1):
298
- m = self.blocks[i]
299
- m.eval()
300
- for param in m.parameters():
301
- param.requires_grad = False
302
-
303
- if self.freeze_attn:
304
- for i in range(0, self.depth):
305
- m = self.blocks[i]
306
- m.attn.eval()
307
- m.norm1.eval()
308
- for param in m.attn.parameters():
309
- param.requires_grad = False
310
- for param in m.norm1.parameters():
311
- param.requires_grad = False
312
-
313
- if self.freeze_ffn:
314
- self.pos_embed.requires_grad = False
315
- self.patch_embed.eval()
316
- for param in self.patch_embed.parameters():
317
- param.requires_grad = False
318
- for i in range(0, self.depth):
319
- m = self.blocks[i]
320
- m.mlp.eval()
321
- m.norm2.eval()
322
- for param in m.mlp.parameters():
323
- param.requires_grad = False
324
- for param in m.norm2.parameters():
325
- param.requires_grad = False
326
-
327
- def init_weights(self):
328
- """Initialize the weights in backbone.
329
- Args:
330
- pretrained (str, optional): Path to pre-trained weights.
331
- Defaults to None.
332
- """
333
- def _init_weights(m):
334
- if isinstance(m, nn.Linear):
335
- trunc_normal_(m.weight, std=.02)
336
- if isinstance(m, nn.Linear) and m.bias is not None:
337
- nn.init.constant_(m.bias, 0)
338
- elif isinstance(m, nn.LayerNorm):
339
- nn.init.constant_(m.bias, 0)
340
- nn.init.constant_(m.weight, 1.0)
341
-
342
- self.apply(_init_weights)
343
-
344
- def get_num_layers(self):
345
- return len(self.blocks)
346
-
347
- @torch.jit.ignore
348
- def no_weight_decay(self):
349
- return {'pos_embed', 'cls_token'}
350
-
351
- def forward_features(self, x):
352
- B, C, H, W = x.shape
353
- x, (Hp, Wp) = self.patch_embed(x)
354
-
355
- if self.pos_embed is not None:
356
- # fit for multiple GPU training
357
- # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
358
- x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
359
- # X [B, 192, 1280]
360
- # x cat [ mean_pose, mean_shape, mean_cam] tokens
361
- pose_tokens = self.pose_emb(self.init_hand_pose.reshape(1, self.cfg.MANO.NUM_HAND_JOINTS + 1, self.joint_rep_dim)).repeat(B, 1, 1)
362
- shape_tokens = self.shape_emb(self.init_betas).unsqueeze(1).repeat(B, 1, 1)
363
- cam_tokens = self.cam_emb(self.init_cam).unsqueeze(1).repeat(B, 1, 1)
364
-
365
- x = torch.cat([pose_tokens, shape_tokens, cam_tokens, x], 1)
366
- for blk in self.blocks:
367
- if self.use_checkpoint:
368
- x = checkpoint.checkpoint(blk, x)
369
- else:
370
- x = blk(x)
371
-
372
- x = self.last_norm(x)
373
-
374
-
375
- pose_feat = x[:, :(self.cfg.MANO.NUM_HAND_JOINTS + 1)]
376
- shape_feat = x[:, (self.cfg.MANO.NUM_HAND_JOINTS + 1):1+(self.cfg.MANO.NUM_HAND_JOINTS + 1)]
377
- cam_feat = x[:, 1+(self.cfg.MANO.NUM_HAND_JOINTS + 1):2+(self.cfg.MANO.NUM_HAND_JOINTS + 1)]
378
-
379
- #print(pose_feat.shape, shape_feat.shape, cam_feat.shape)
380
- pred_hand_pose = self.decpose(pose_feat).reshape(B, -1) + self.init_hand_pose #B , 96
381
- pred_betas = self.decshape(shape_feat).reshape(B, -1) + self.init_betas #B , 10
382
- pred_cam = self.deccam(cam_feat).reshape(B, -1) + self.init_cam #B , 3
383
-
384
- pred_mano_feats = {}
385
- pred_mano_feats['hand_pose'] = pred_hand_pose
386
- pred_mano_feats['betas'] = pred_betas
387
- pred_mano_feats['cam'] = pred_cam
388
-
389
-
390
- joint_conversion_fn = {
391
- '6d': rot6d_to_rotmat,
392
- 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
393
- }[self.joint_rep_type]
394
-
395
- pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(B, self.cfg.MANO.NUM_HAND_JOINTS+1, 3, 3)
396
- pred_mano_params = {'global_orient': pred_hand_pose[:, [0]],
397
- 'hand_pose': pred_hand_pose[:, 1:],
398
- 'betas': pred_betas}
399
-
400
- img_feat = x[:, 2+(self.cfg.MANO.NUM_HAND_JOINTS + 1):].reshape(B, Hp, Wp, -1).permute(0, 3, 1, 2)
401
- return pred_mano_params, pred_cam, pred_mano_feats, img_feat
402
-
403
- def forward(self, x):
404
- x = self.forward_features(x)
405
- return x
406
-
407
- def train(self, mode=True):
408
- """Convert the model into training mode."""
409
- super().train(mode)
410
- self._freeze_stages()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/wilor/models/discriminator.py DELETED
@@ -1,98 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- class Discriminator(nn.Module):
5
-
6
- def __init__(self):
7
- """
8
- Pose + Shape discriminator proposed in HMR
9
- """
10
- super(Discriminator, self).__init__()
11
-
12
- self.num_joints = 15
13
- # poses_alone
14
- self.D_conv1 = nn.Conv2d(9, 32, kernel_size=1)
15
- nn.init.xavier_uniform_(self.D_conv1.weight)
16
- nn.init.zeros_(self.D_conv1.bias)
17
- self.relu = nn.ReLU(inplace=True)
18
- self.D_conv2 = nn.Conv2d(32, 32, kernel_size=1)
19
- nn.init.xavier_uniform_(self.D_conv2.weight)
20
- nn.init.zeros_(self.D_conv2.bias)
21
- pose_out = []
22
- for i in range(self.num_joints):
23
- pose_out_temp = nn.Linear(32, 1)
24
- nn.init.xavier_uniform_(pose_out_temp.weight)
25
- nn.init.zeros_(pose_out_temp.bias)
26
- pose_out.append(pose_out_temp)
27
- self.pose_out = nn.ModuleList(pose_out)
28
-
29
- # betas
30
- self.betas_fc1 = nn.Linear(10, 10)
31
- nn.init.xavier_uniform_(self.betas_fc1.weight)
32
- nn.init.zeros_(self.betas_fc1.bias)
33
- self.betas_fc2 = nn.Linear(10, 5)
34
- nn.init.xavier_uniform_(self.betas_fc2.weight)
35
- nn.init.zeros_(self.betas_fc2.bias)
36
- self.betas_out = nn.Linear(5, 1)
37
- nn.init.xavier_uniform_(self.betas_out.weight)
38
- nn.init.zeros_(self.betas_out.bias)
39
-
40
- # poses_joint
41
- self.D_alljoints_fc1 = nn.Linear(32*self.num_joints, 1024)
42
- nn.init.xavier_uniform_(self.D_alljoints_fc1.weight)
43
- nn.init.zeros_(self.D_alljoints_fc1.bias)
44
- self.D_alljoints_fc2 = nn.Linear(1024, 1024)
45
- nn.init.xavier_uniform_(self.D_alljoints_fc2.weight)
46
- nn.init.zeros_(self.D_alljoints_fc2.bias)
47
- self.D_alljoints_out = nn.Linear(1024, 1)
48
- nn.init.xavier_uniform_(self.D_alljoints_out.weight)
49
- nn.init.zeros_(self.D_alljoints_out.bias)
50
-
51
-
52
- def forward(self, poses: torch.Tensor, betas: torch.Tensor) -> torch.Tensor:
53
- """
54
- Forward pass of the discriminator.
55
- Args:
56
- poses (torch.Tensor): Tensor of shape (B, 23, 3, 3) containing a batch of MANO hand poses (excluding the global orientation).
57
- betas (torch.Tensor): Tensor of shape (B, 10) containign a batch of MANO beta coefficients.
58
- Returns:
59
- torch.Tensor: Discriminator output with shape (B, 25)
60
- """
61
- #bn = poses.shape[0]
62
- # poses B x 207
63
- #poses = poses.reshape(bn, -1)
64
- # poses B x num_joints x 1 x 9
65
- poses = poses.reshape(-1, self.num_joints, 1, 9)
66
- bn = poses.shape[0]
67
- # poses B x 9 x num_joints x 1
68
- poses = poses.permute(0, 3, 1, 2).contiguous()
69
-
70
- # poses_alone
71
- poses = self.D_conv1(poses)
72
- poses = self.relu(poses)
73
- poses = self.D_conv2(poses)
74
- poses = self.relu(poses)
75
-
76
- poses_out = []
77
- for i in range(self.num_joints):
78
- poses_out_ = self.pose_out[i](poses[:, :, i, 0])
79
- poses_out.append(poses_out_)
80
- poses_out = torch.cat(poses_out, dim=1)
81
-
82
- # betas
83
- betas = self.betas_fc1(betas)
84
- betas = self.relu(betas)
85
- betas = self.betas_fc2(betas)
86
- betas = self.relu(betas)
87
- betas_out = self.betas_out(betas)
88
-
89
- # poses_joint
90
- poses = poses.reshape(bn,-1)
91
- poses_all = self.D_alljoints_fc1(poses)
92
- poses_all = self.relu(poses_all)
93
- poses_all = self.D_alljoints_fc2(poses_all)
94
- poses_all = self.relu(poses_all)
95
- poses_all_out = self.D_alljoints_out(poses_all)
96
-
97
- disc_out = torch.cat((poses_out, betas_out, poses_all_out), 1)
98
- return disc_out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/wilor/models/heads/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .refinement_net import RefineNet
 
 
WiLoR/wilor/models/heads/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (200 Bytes)
 
WiLoR/wilor/models/heads/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (219 Bytes)
 
WiLoR/wilor/models/heads/__pycache__/refinement_net.cpython-310.pyc DELETED
Binary file (7.63 kB)
 
WiLoR/wilor/models/heads/__pycache__/refinement_net.cpython-311.pyc DELETED
Binary file (15 kB)
 
WiLoR/wilor/models/heads/refinement_net.py DELETED
@@ -1,204 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import math
5
- from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
6
- from typing import Optional
7
-
8
- def make_linear_layers(feat_dims, relu_final=True, use_bn=False):
9
- layers = []
10
- for i in range(len(feat_dims)-1):
11
- layers.append(nn.Linear(feat_dims[i], feat_dims[i+1]))
12
-
13
- # Do not use ReLU for final estimation
14
- if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and relu_final):
15
- if use_bn:
16
- layers.append(nn.BatchNorm1d(feat_dims[i+1]))
17
- layers.append(nn.ReLU(inplace=True))
18
-
19
- return nn.Sequential(*layers)
20
-
21
- def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
22
- layers = []
23
- for i in range(len(feat_dims)-1):
24
- layers.append(
25
- nn.Conv2d(
26
- in_channels=feat_dims[i],
27
- out_channels=feat_dims[i+1],
28
- kernel_size=kernel,
29
- stride=stride,
30
- padding=padding
31
- ))
32
- # Do not use BN and ReLU for final estimation
33
- if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
34
- layers.append(nn.BatchNorm2d(feat_dims[i+1]))
35
- layers.append(nn.ReLU(inplace=True))
36
-
37
- return nn.Sequential(*layers)
38
-
39
- def make_deconv_layers(feat_dims, bnrelu_final=True):
40
- layers = []
41
- for i in range(len(feat_dims)-1):
42
- layers.append(
43
- nn.ConvTranspose2d(
44
- in_channels=feat_dims[i],
45
- out_channels=feat_dims[i+1],
46
- kernel_size=4,
47
- stride=2,
48
- padding=1,
49
- output_padding=0,
50
- bias=False))
51
-
52
- # Do not use BN and ReLU for final estimation
53
- if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
54
- layers.append(nn.BatchNorm2d(feat_dims[i+1]))
55
- layers.append(nn.ReLU(inplace=True))
56
-
57
- return nn.Sequential(*layers)
58
-
59
- def sample_joint_features(img_feat, joint_xy):
60
- height, width = img_feat.shape[2:]
61
- x = joint_xy[:, :, 0] / (width - 1) * 2 - 1
62
- y = joint_xy[:, :, 1] / (height - 1) * 2 - 1
63
- grid = torch.stack((x, y), 2)[:, :, None, :]
64
- img_feat = F.grid_sample(img_feat, grid, align_corners=True)[:, :, :, 0] # batch_size, channel_dim, joint_num
65
- img_feat = img_feat.permute(0, 2, 1).contiguous() # batch_size, joint_num, channel_dim
66
- return img_feat
67
-
68
- def perspective_projection(points: torch.Tensor,
69
- translation: torch.Tensor,
70
- focal_length: torch.Tensor,
71
- camera_center: Optional[torch.Tensor] = None,
72
- rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
73
- """
74
- Computes the perspective projection of a set of 3D points.
75
- Args:
76
- points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
77
- translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
78
- focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
79
- camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
80
- rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
81
- Returns:
82
- torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
83
- """
84
- batch_size = points.shape[0]
85
- if rotation is None:
86
- rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
87
- if camera_center is None:
88
- camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
89
- # Populate intrinsic camera matrix K.
90
- K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
91
- K[:,0,0] = focal_length[:,0]
92
- K[:,1,1] = focal_length[:,1]
93
- K[:,2,2] = 1.
94
- K[:,:-1, -1] = camera_center
95
- # Transform points
96
- points = torch.einsum('bij,bkj->bki', rotation, points)
97
- points = points + translation.unsqueeze(1)
98
-
99
- # Apply perspective distortion
100
- projected_points = points / points[:,:,-1].unsqueeze(-1)
101
-
102
- # Apply camera intrinsics
103
- projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
104
-
105
- return projected_points[:, :, :-1]
106
-
107
- class DeConvNet(nn.Module):
108
- def __init__(self, feat_dim=768, upscale=4):
109
- super(DeConvNet, self).__init__()
110
- self.first_conv = make_conv_layers([feat_dim, feat_dim//2], kernel=1, stride=1, padding=0, bnrelu_final=False)
111
- self.deconv = nn.ModuleList([])
112
- for i in range(int(math.log2(upscale))+1):
113
- if i==0:
114
- self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4]))
115
- elif i==1:
116
- self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4, feat_dim//8]))
117
- elif i==2:
118
- self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4, feat_dim//8, feat_dim//8]))
119
-
120
- def forward(self, img_feat):
121
-
122
- face_img_feats = []
123
- img_feat = self.first_conv(img_feat)
124
- face_img_feats.append(img_feat)
125
- for i, deconv in enumerate(self.deconv):
126
- scale = 2**i
127
- img_feat_i = deconv(img_feat)
128
- face_img_feat = img_feat_i
129
- face_img_feats.append(face_img_feat)
130
- return face_img_feats[::-1] # high resolution -> low resolution
131
-
132
- class DeConvNet_v2(nn.Module):
133
- def __init__(self, feat_dim=768):
134
- super(DeConvNet_v2, self).__init__()
135
- self.first_conv = make_conv_layers([feat_dim, feat_dim//2], kernel=1, stride=1, padding=0, bnrelu_final=False)
136
- self.deconv = nn.Sequential(*[nn.ConvTranspose2d(in_channels=feat_dim//2, out_channels=feat_dim//4, kernel_size=4, stride=4, padding=0, output_padding=0, bias=False),
137
- nn.BatchNorm2d(feat_dim//4),
138
- nn.ReLU(inplace=True)])
139
-
140
- def forward(self, img_feat):
141
-
142
- face_img_feats = []
143
- img_feat = self.first_conv(img_feat)
144
- img_feat = self.deconv(img_feat)
145
-
146
- return [img_feat]
147
-
148
- class RefineNet(nn.Module):
149
- def __init__(self, cfg, feat_dim=1280, upscale=3):
150
- super(RefineNet, self).__init__()
151
- #self.deconv = DeConvNet_v2(feat_dim=feat_dim)
152
- #self.out_dim = feat_dim//4
153
-
154
- self.deconv = DeConvNet(feat_dim=feat_dim, upscale=upscale)
155
- self.out_dim = feat_dim//8 + feat_dim//4 + feat_dim//2
156
- self.dec_pose = nn.Linear(self.out_dim, 96)
157
- self.dec_cam = nn.Linear(self.out_dim, 3)
158
- self.dec_shape = nn.Linear(self.out_dim, 10)
159
-
160
- self.cfg = cfg
161
- self.joint_rep_type = cfg.MODEL.MANO_HEAD.get('JOINT_REP', '6d')
162
- self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
163
-
164
- def forward(self, img_feat, verts_3d, pred_cam, pred_mano_feats, focal_length):
165
- B = img_feat.shape[0]
166
-
167
- img_feats = self.deconv(img_feat)
168
-
169
- img_feat_sizes = [img_feat.shape[2] for img_feat in img_feats]
170
-
171
- temp_cams = [torch.stack([pred_cam[:, 1], pred_cam[:, 2],
172
- 2*focal_length[:, 0]/(img_feat_size * pred_cam[:, 0] +1e-9)],dim=-1) for img_feat_size in img_feat_sizes]
173
-
174
- verts_2d = [perspective_projection(verts_3d,
175
- translation=temp_cams[i],
176
- focal_length=focal_length / img_feat_sizes[i]) for i in range(len(img_feat_sizes))]
177
-
178
- vert_feats = [sample_joint_features(img_feats[i], verts_2d[i]).max(1).values for i in range(len(img_feat_sizes))]
179
-
180
- vert_feats = torch.cat(vert_feats, dim=-1)
181
-
182
- delta_pose = self.dec_pose(vert_feats)
183
- delta_betas = self.dec_shape(vert_feats)
184
- delta_cam = self.dec_cam(vert_feats)
185
-
186
-
187
- pred_hand_pose = pred_mano_feats['hand_pose'] + delta_pose
188
- pred_betas = pred_mano_feats['betas'] + delta_betas
189
- pred_cam = pred_mano_feats['cam'] + delta_cam
190
-
191
- joint_conversion_fn = {
192
- '6d': rot6d_to_rotmat,
193
- 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
194
- }[self.joint_rep_type]
195
-
196
- pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(B, self.cfg.MANO.NUM_HAND_JOINTS+1, 3, 3)
197
-
198
- pred_mano_params = {'global_orient': pred_hand_pose[:, [0]],
199
- 'hand_pose': pred_hand_pose[:, 1:],
200
- 'betas': pred_betas}
201
-
202
- return pred_mano_params, pred_cam
203
-
204
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/wilor/models/losses.py DELETED
@@ -1,92 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- class Keypoint2DLoss(nn.Module):
5
-
6
- def __init__(self, loss_type: str = 'l1'):
7
- """
8
- 2D keypoint loss module.
9
- Args:
10
- loss_type (str): Choose between l1 and l2 losses.
11
- """
12
- super(Keypoint2DLoss, self).__init__()
13
- if loss_type == 'l1':
14
- self.loss_fn = nn.L1Loss(reduction='none')
15
- elif loss_type == 'l2':
16
- self.loss_fn = nn.MSELoss(reduction='none')
17
- else:
18
- raise NotImplementedError('Unsupported loss function')
19
-
20
- def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor:
21
- """
22
- Compute 2D reprojection loss on the keypoints.
23
- Args:
24
- pred_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 2] containing projected 2D keypoints (B: batch_size, S: num_samples, N: num_keypoints)
25
- gt_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the ground truth 2D keypoints and confidence.
26
- Returns:
27
- torch.Tensor: 2D keypoint loss.
28
- """
29
- conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
30
- batch_size = conf.shape[0]
31
- loss = (conf * self.loss_fn(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).sum(dim=(1,2))
32
- return loss.sum()
33
-
34
-
35
- class Keypoint3DLoss(nn.Module):
36
-
37
- def __init__(self, loss_type: str = 'l1'):
38
- """
39
- 3D keypoint loss module.
40
- Args:
41
- loss_type (str): Choose between l1 and l2 losses.
42
- """
43
- super(Keypoint3DLoss, self).__init__()
44
- if loss_type == 'l1':
45
- self.loss_fn = nn.L1Loss(reduction='none')
46
- elif loss_type == 'l2':
47
- self.loss_fn = nn.MSELoss(reduction='none')
48
- else:
49
- raise NotImplementedError('Unsupported loss function')
50
-
51
- def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor, pelvis_id: int = 0):
52
- """
53
- Compute 3D keypoint loss.
54
- Args:
55
- pred_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the predicted 3D keypoints (B: batch_size, S: num_samples, N: num_keypoints)
56
- gt_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 4] containing the ground truth 3D keypoints and confidence.
57
- Returns:
58
- torch.Tensor: 3D keypoint loss.
59
- """
60
- batch_size = pred_keypoints_3d.shape[0]
61
- gt_keypoints_3d = gt_keypoints_3d.clone()
62
- pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1)
63
- gt_keypoints_3d[:, :, :-1] = gt_keypoints_3d[:, :, :-1] - gt_keypoints_3d[:, pelvis_id, :-1].unsqueeze(dim=1)
64
- conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
65
- gt_keypoints_3d = gt_keypoints_3d[:, :, :-1]
66
- loss = (conf * self.loss_fn(pred_keypoints_3d, gt_keypoints_3d)).sum(dim=(1,2))
67
- return loss.sum()
68
-
69
- class ParameterLoss(nn.Module):
70
-
71
- def __init__(self):
72
- """
73
- MANO parameter loss module.
74
- """
75
- super(ParameterLoss, self).__init__()
76
- self.loss_fn = nn.MSELoss(reduction='none')
77
-
78
- def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor):
79
- """
80
- Compute MANO parameter loss.
81
- Args:
82
- pred_param (torch.Tensor): Tensor of shape [B, S, ...] containing the predicted parameters (body pose / global orientation / betas)
83
- gt_param (torch.Tensor): Tensor of shape [B, S, ...] containing the ground truth MANO parameters.
84
- Returns:
85
- torch.Tensor: L2 parameter loss loss.
86
- """
87
- batch_size = pred_param.shape[0]
88
- num_dims = len(pred_param.shape)
89
- mask_dimension = [batch_size] + [1] * (num_dims-1)
90
- has_param = has_param.type(pred_param.type()).view(*mask_dimension)
91
- loss_param = (has_param * self.loss_fn(pred_param, gt_param))
92
- return loss_param.sum()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/wilor/models/mano_wrapper.py DELETED
@@ -1,40 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import pickle
4
- from typing import Optional
5
- import smplx
6
- from smplx.lbs import vertices2joints
7
- from smplx.utils import MANOOutput, to_tensor
8
- from smplx.vertex_ids import vertex_ids
9
-
10
-
11
- class MANO(smplx.MANOLayer):
12
- def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs):
13
- """
14
- Extension of the official MANO implementation to support more joints.
15
- Args:
16
- Same as MANOLayer.
17
- joint_regressor_extra (str): Path to extra joint regressor.
18
- """
19
- super(MANO, self).__init__(*args, **kwargs)
20
- mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
21
-
22
- #2, 3, 5, 4, 1
23
- if joint_regressor_extra is not None:
24
- self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
25
- self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long))
26
- self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long))
27
-
28
- def forward(self, *args, **kwargs) -> MANOOutput:
29
- """
30
- Run forward pass. Same as MANO and also append an extra set of joints if joint_regressor_extra is specified.
31
- """
32
- mano_output = super(MANO, self).forward(*args, **kwargs)
33
- extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs)
34
- joints = torch.cat([mano_output.joints, extra_joints], dim=1)
35
- joints = joints[:, self.joint_map, :]
36
- if hasattr(self, 'joint_regressor_extra'):
37
- extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices)
38
- joints = torch.cat([joints, extra_joints], dim=1)
39
- mano_output.joints = joints
40
- return mano_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/wilor/models/wilor.py DELETED
@@ -1,376 +0,0 @@
1
- import torch
2
- import pytorch_lightning as pl
3
- from typing import Any, Dict, Mapping, Tuple
4
-
5
- from yacs.config import CfgNode
6
-
7
- from ..utils import SkeletonRenderer, MeshRenderer
8
- from ..utils.geometry import aa_to_rotmat, perspective_projection
9
- from ..utils.pylogger import get_pylogger
10
- from .backbones import create_backbone
11
- from .heads import RefineNet
12
- from .discriminator import Discriminator
13
- from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss
14
- from . import MANO
15
-
16
- log = get_pylogger(__name__)
17
-
18
- class WiLoR(pl.LightningModule):
19
-
20
- def __init__(self, cfg: CfgNode, init_renderer: bool = True):
21
- """
22
- Setup WiLoR model
23
- Args:
24
- cfg (CfgNode): Config file as a yacs CfgNode
25
- """
26
- super().__init__()
27
-
28
- # Save hyperparameters
29
- self.save_hyperparameters(logger=False, ignore=['init_renderer'])
30
-
31
- self.cfg = cfg
32
- # Create backbone feature extractor
33
- self.backbone = create_backbone(cfg)
34
- if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
35
- log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
36
- self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict'], strict = False)
37
-
38
- # Create RefineNet head
39
- self.refine_net = RefineNet(cfg, feat_dim=1280, upscale=3)
40
-
41
- # Create discriminator
42
- if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
43
- self.discriminator = Discriminator()
44
-
45
- # Define loss functions
46
- self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
47
- self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
48
- self.mano_parameter_loss = ParameterLoss()
49
-
50
- # Instantiate MANO model
51
- mano_cfg = {k.lower(): v for k,v in dict(cfg.MANO).items()}
52
- self.mano = MANO(**mano_cfg)
53
-
54
- # Buffer that shows whetheer we need to initialize ActNorm layers
55
- self.register_buffer('initialized', torch.tensor(False))
56
- # Setup renderer for visualization
57
- if init_renderer:
58
- self.renderer = SkeletonRenderer(self.cfg)
59
- self.mesh_renderer = MeshRenderer(self.cfg, faces=self.mano.faces)
60
- else:
61
- self.renderer = None
62
- self.mesh_renderer = None
63
-
64
-
65
- # Disable automatic optimization since we use adversarial training
66
- self.automatic_optimization = False
67
-
68
- def on_after_backward(self):
69
- for name, param in self.named_parameters():
70
- if param.grad is None:
71
- print(param.shape)
72
- print(name)
73
-
74
-
75
- def get_parameters(self):
76
- #all_params = list(self.mano_head.parameters())
77
- all_params = list(self.backbone.parameters())
78
- return all_params
79
-
80
- def configure_optimizers(self) -> Tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
81
- """
82
- Setup model and distriminator Optimizers
83
- Returns:
84
- Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers
85
- """
86
- param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]
87
-
88
- optimizer = torch.optim.AdamW(params=param_groups,
89
- # lr=self.cfg.TRAIN.LR,
90
- weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
91
- optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(),
92
- lr=self.cfg.TRAIN.LR,
93
- weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
94
-
95
- return optimizer, optimizer_disc
96
-
97
- def forward_step(self, batch: Dict, train: bool = False) -> Dict:
98
- """
99
- Run a forward step of the network
100
- Args:
101
- batch (Dict): Dictionary containing batch data
102
- train (bool): Flag indicating whether it is training or validation mode
103
- Returns:
104
- Dict: Dictionary containing the regression output
105
- """
106
- # Use RGB image as input
107
- x = batch['img']
108
- batch_size = x.shape[0]
109
- # Compute conditioning features using the backbone
110
- # if using ViT backbone, we need to use a different aspect ratio
111
- temp_mano_params, pred_cam, pred_mano_feats, vit_out = self.backbone(x[:,:,:,32:-32]) # B, 1280, 16, 12
112
-
113
-
114
- # Compute camera translation
115
- device = temp_mano_params['hand_pose'].device
116
- dtype = temp_mano_params['hand_pose'].dtype
117
- focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(batch_size, 2, device=device, dtype=dtype)
118
-
119
-
120
- ## Temp MANO
121
- temp_mano_params['global_orient'] = temp_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
122
- temp_mano_params['hand_pose'] = temp_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
123
- temp_mano_params['betas'] = temp_mano_params['betas'].reshape(batch_size, -1)
124
- temp_mano_output = self.mano(**{k: v.float() for k,v in temp_mano_params.items()}, pose2rot=False)
125
- #temp_keypoints_3d = temp_mano_output.joints
126
- temp_vertices = temp_mano_output.vertices
127
-
128
- pred_mano_params, pred_cam = self.refine_net(vit_out, temp_vertices, pred_cam, pred_mano_feats, focal_length)
129
- # Store useful regression outputs to the output dict
130
-
131
-
132
- output = {}
133
- output['pred_cam'] = pred_cam
134
- output['pred_mano_params'] = {k: v.clone() for k,v in pred_mano_params.items()}
135
-
136
- pred_cam_t = torch.stack([pred_cam[:, 1],
137
- pred_cam[:, 2],
138
- 2*focal_length[:, 0]/(self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] +1e-9)],dim=-1)
139
- output['pred_cam_t'] = pred_cam_t
140
- output['focal_length'] = focal_length
141
-
142
- # Compute model vertices, joints and the projected joints
143
- pred_mano_params['global_orient'] = pred_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
144
- pred_mano_params['hand_pose'] = pred_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
145
- pred_mano_params['betas'] = pred_mano_params['betas'].reshape(batch_size, -1)
146
- mano_output = self.mano(**{k: v.float() for k,v in pred_mano_params.items()}, pose2rot=False)
147
- pred_keypoints_3d = mano_output.joints
148
- pred_vertices = mano_output.vertices
149
-
150
- output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
151
- output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
152
- pred_cam_t = pred_cam_t.reshape(-1, 3)
153
- focal_length = focal_length.reshape(-1, 2)
154
-
155
- pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
156
- translation=pred_cam_t,
157
- focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)
158
- output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
159
-
160
- return output
161
-
162
- def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
163
- """
164
- Compute losses given the input batch and the regression output
165
- Args:
166
- batch (Dict): Dictionary containing batch data
167
- output (Dict): Dictionary containing the regression output
168
- train (bool): Flag indicating whether it is training or validation mode
169
- Returns:
170
- torch.Tensor : Total loss for current batch
171
- """
172
-
173
- pred_mano_params = output['pred_mano_params']
174
- pred_keypoints_2d = output['pred_keypoints_2d']
175
- pred_keypoints_3d = output['pred_keypoints_3d']
176
-
177
-
178
- batch_size = pred_mano_params['hand_pose'].shape[0]
179
- device = pred_mano_params['hand_pose'].device
180
- dtype = pred_mano_params['hand_pose'].dtype
181
-
182
- # Get annotations
183
- gt_keypoints_2d = batch['keypoints_2d']
184
- gt_keypoints_3d = batch['keypoints_3d']
185
- gt_mano_params = batch['mano_params']
186
- has_mano_params = batch['has_mano_params']
187
- is_axis_angle = batch['mano_params_is_axis_angle']
188
-
189
- # Compute 3D keypoint loss
190
- loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
191
- loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
192
-
193
- # Compute loss on MANO parameters
194
- loss_mano_params = {}
195
- for k, pred in pred_mano_params.items():
196
- gt = gt_mano_params[k].view(batch_size, -1)
197
- if is_axis_angle[k].all():
198
- gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
199
- has_gt = has_mano_params[k]
200
- loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1), has_gt)
201
-
202
- loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d+\
203
- self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d+\
204
- sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params])
205
-
206
-
207
- losses = dict(loss=loss.detach(),
208
- loss_keypoints_2d=loss_keypoints_2d.detach(),
209
- loss_keypoints_3d=loss_keypoints_3d.detach())
210
-
211
- for k, v in loss_mano_params.items():
212
- losses['loss_' + k] = v.detach()
213
-
214
- output['losses'] = losses
215
-
216
- return loss
217
-
218
- # Tensoroboard logging should run from first rank only
219
- @pl.utilities.rank_zero.rank_zero_only
220
- def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True) -> None:
221
- """
222
- Log results to Tensorboard
223
- Args:
224
- batch (Dict): Dictionary containing batch data
225
- output (Dict): Dictionary containing the regression output
226
- step_count (int): Global training step count
227
- train (bool): Flag indicating whether it is training or validation mode
228
- """
229
-
230
- mode = 'train' if train else 'val'
231
- batch_size = batch['keypoints_2d'].shape[0]
232
- images = batch['img']
233
- images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1,3,1,1)
234
- images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1,3,1,1)
235
- #images = 255*images.permute(0, 2, 3, 1).cpu().numpy()
236
-
237
- pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
238
- pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3)
239
- focal_length = output['focal_length'].detach().reshape(batch_size, 2)
240
- gt_keypoints_3d = batch['keypoints_3d']
241
- gt_keypoints_2d = batch['keypoints_2d']
242
-
243
- losses = output['losses']
244
- pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3)
245
- pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2)
246
- if write_to_summary_writer:
247
- summary_writer = self.logger.experiment
248
- for loss_name, val in losses.items():
249
- summary_writer.add_scalar(mode +'/' + loss_name, val.detach().item(), step_count)
250
- num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)
251
-
252
- gt_keypoints_3d = batch['keypoints_3d']
253
- pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
254
-
255
- # We render the skeletons instead of the full mesh because rendering a lot of meshes will make the training slow.
256
- #predictions = self.renderer(pred_keypoints_3d[:num_images],
257
- # gt_keypoints_3d[:num_images],
258
- # 2 * gt_keypoints_2d[:num_images],
259
- # images=images[:num_images],
260
- # camera_translation=pred_cam_t[:num_images])
261
- predictions = self.mesh_renderer.visualize_tensorboard(pred_vertices[:num_images].cpu().numpy(),
262
- pred_cam_t[:num_images].cpu().numpy(),
263
- images[:num_images].cpu().numpy(),
264
- pred_keypoints_2d[:num_images].cpu().numpy(),
265
- gt_keypoints_2d[:num_images].cpu().numpy(),
266
- focal_length=focal_length[:num_images].cpu().numpy())
267
- if write_to_summary_writer:
268
- summary_writer.add_image('%s/predictions' % mode, predictions, step_count)
269
-
270
- return predictions
271
-
272
- def forward(self, batch: Dict) -> Dict:
273
- """
274
- Run a forward step of the network in val mode
275
- Args:
276
- batch (Dict): Dictionary containing batch data
277
- Returns:
278
- Dict: Dictionary containing the regression output
279
- """
280
- return self.forward_step(batch, train=False)
281
-
282
- def training_step_discriminator(self, batch: Dict,
283
- hand_pose: torch.Tensor,
284
- betas: torch.Tensor,
285
- optimizer: torch.optim.Optimizer) -> torch.Tensor:
286
- """
287
- Run a discriminator training step
288
- Args:
289
- batch (Dict): Dictionary containing mocap batch data
290
- hand_pose (torch.Tensor): Regressed hand pose from current step
291
- betas (torch.Tensor): Regressed betas from current step
292
- optimizer (torch.optim.Optimizer): Discriminator optimizer
293
- Returns:
294
- torch.Tensor: Discriminator loss
295
- """
296
- batch_size = hand_pose.shape[0]
297
- gt_hand_pose = batch['hand_pose']
298
- gt_betas = batch['betas']
299
- gt_rotmat = aa_to_rotmat(gt_hand_pose.view(-1,3)).view(batch_size, -1, 3, 3)
300
- disc_fake_out = self.discriminator(hand_pose.detach(), betas.detach())
301
- loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size
302
- disc_real_out = self.discriminator(gt_rotmat, gt_betas)
303
- loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size
304
- loss_disc = loss_fake + loss_real
305
- loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc
306
- optimizer.zero_grad()
307
- self.manual_backward(loss)
308
- optimizer.step()
309
- return loss_disc.detach()
310
-
311
- def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict:
312
- """
313
- Run a full training step
314
- Args:
315
- joint_batch (Dict): Dictionary containing image and mocap batch data
316
- batch_idx (int): Unused.
317
- batch_idx (torch.Tensor): Unused.
318
- Returns:
319
- Dict: Dictionary containing regression output.
320
- """
321
- batch = joint_batch['img']
322
- mocap_batch = joint_batch['mocap']
323
- optimizer = self.optimizers(use_pl_optimizer=True)
324
- if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
325
- optimizer, optimizer_disc = optimizer
326
-
327
- batch_size = batch['img'].shape[0]
328
- output = self.forward_step(batch, train=True)
329
- pred_mano_params = output['pred_mano_params']
330
- if self.cfg.get('UPDATE_GT_SPIN', False):
331
- self.update_batch_gt_spin(batch, output)
332
- loss = self.compute_loss(batch, output, train=True)
333
- if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
334
- disc_out = self.discriminator(pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1))
335
- loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size
336
- loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv
337
-
338
- # Error if Nan
339
- if torch.isnan(loss):
340
- raise ValueError('Loss is NaN')
341
-
342
- optimizer.zero_grad()
343
- self.manual_backward(loss)
344
- # Clip gradient
345
- if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
346
- gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True)
347
- self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True)
348
- optimizer.step()
349
- if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
350
- loss_disc = self.training_step_discriminator(mocap_batch, pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1), optimizer_disc)
351
- output['losses']['loss_gen'] = loss_adv
352
- output['losses']['loss_disc'] = loss_disc
353
-
354
- if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
355
- self.tensorboard_logging(batch, output, self.global_step, train=True)
356
-
357
- self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False)
358
-
359
- return output
360
-
361
- def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
362
- """
363
- Run a validation step and log to Tensorboard
364
- Args:
365
- batch (Dict): Dictionary containing batch data
366
- batch_idx (int): Unused.
367
- Returns:
368
- Dict: Dictionary containing regression output.
369
- """
370
- # batch_size = batch['img'].shape[0]
371
- output = self.forward_step(batch, train=False)
372
- loss = self.compute_loss(batch, output, train=False)
373
- output['loss'] = loss
374
- self.tensorboard_logging(batch, output, self.global_step, train=False)
375
-
376
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/wilor/utils/__init__.py DELETED
@@ -1,25 +0,0 @@
1
- import torch
2
- from typing import Any
3
-
4
- from .renderer import Renderer
5
- from .mesh_renderer import MeshRenderer
6
- from .skeleton_renderer import SkeletonRenderer
7
- from .pose_utils import eval_pose, Evaluator
8
-
9
- def recursive_to(x: Any, target: torch.device):
10
- """
11
- Recursively transfer a batch of data to the target device
12
- Args:
13
- x (Any): Batch of data.
14
- target (torch.device): Target device.
15
- Returns:
16
- Batch of data where all tensors are transfered to the target device.
17
- """
18
- if isinstance(x, dict):
19
- return {k: recursive_to(v, target) for k, v in x.items()}
20
- elif isinstance(x, torch.Tensor):
21
- return x.to(target)
22
- elif isinstance(x, list):
23
- return [recursive_to(i, target) for i in x]
24
- else:
25
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
WiLoR/wilor/utils/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (1.84 kB)
 
WiLoR/wilor/utils/__pycache__/geometry.cpython-311.pyc DELETED
Binary file (6.73 kB)