thaint2901 commited on
Commit
f3261a0
·
1 Parent(s): f5f446c
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. TPSMM/LICENSE +21 -0
  2. TPSMM/README.md +98 -0
  3. TPSMM/assets/source.png +0 -0
  4. TPSMM/assets/source1.png +0 -0
  5. TPSMM/assets/source2.jpg +0 -0
  6. TPSMM/augmentation.py +344 -0
  7. TPSMM/cog.yaml +40 -0
  8. TPSMM/config/mgif-256.yaml +75 -0
  9. TPSMM/config/taichi-256.yaml +134 -0
  10. TPSMM/config/ted-384.yaml +73 -0
  11. TPSMM/config/vox-256.yaml +74 -0
  12. TPSMM/demo.ipynb +0 -0
  13. TPSMM/demo.py +180 -0
  14. TPSMM/frames_dataset.py +173 -0
  15. TPSMM/logger.py +212 -0
  16. TPSMM/modules/avd_network.py +65 -0
  17. TPSMM/modules/bg_motion_predictor.py +24 -0
  18. TPSMM/modules/dense_motion.py +164 -0
  19. TPSMM/modules/inpainting_network.py +127 -0
  20. TPSMM/modules/keypoint_detector.py +27 -0
  21. TPSMM/modules/model.py +182 -0
  22. TPSMM/modules/util.py +349 -0
  23. TPSMM/pkgs/tpsmm.py +80 -0
  24. TPSMM/predict.py +125 -0
  25. TPSMM/pretrained/vox.pth.tar +3 -0
  26. TPSMM/reconstruction.py +69 -0
  27. TPSMM/requirements.txt +25 -0
  28. TPSMM/run.py +89 -0
  29. TPSMM/tmp.jpg +0 -0
  30. TPSMM/tmp.py +14 -0
  31. TPSMM/train.py +94 -0
  32. TPSMM/train_avd.py +91 -0
  33. app.py +122 -0
  34. assets/0.jpg +0 -0
  35. assets/1.jpg +0 -0
  36. assets/2.jpg +0 -0
  37. assets/3.jpg +0 -0
  38. requirements.txt +12 -0
  39. retinaface/change_batch_onnx.py +43 -0
  40. retinaface/convert_to_onnx.py +135 -0
  41. retinaface/data/__init__.py +3 -0
  42. retinaface/data/config.py +55 -0
  43. retinaface/data/data_augment.py +235 -0
  44. retinaface/data/wider_face.py +101 -0
  45. retinaface/detect.py +152 -0
  46. retinaface/detect_video_raw.py +66 -0
  47. retinaface/layers/__init__.py +2 -0
  48. retinaface/layers/functions/prior_box.py +33 -0
  49. retinaface/layers/modules/__init__.py +3 -0
  50. retinaface/layers/modules/multibox_loss.py +125 -0
TPSMM/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 yoyo-nb
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
TPSMM/README.md ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [CVPR2022] Thin-Plate Spline Motion Model for Image Animation
2
+
3
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
4
+ ![stars](https://img.shields.io/github/stars/yoyo-nb/Thin-Plate-Spline-Motion-Model.svg?style=flat)
5
+ ![GitHub repo size](https://img.shields.io/github/repo-size/yoyo-nb/Thin-Plate-Spline-Motion-Model.svg)
6
+
7
+ Source code of the CVPR'2022 paper "Thin-Plate Spline Motion Model for Image Animation"
8
+
9
+ [**Paper**](https://arxiv.org/abs/2203.14367) **|** [**Supp**](https://cloud.tsinghua.edu.cn/f/f7b8573bb5b04583949f/?dl=1)
10
+
11
+ ### Example animation
12
+
13
+ ![vox](assets/vox.gif)
14
+ ![ted](assets/ted.gif)
15
+
16
+ **PS**: The paper trains the model for 100 epochs for a fair comparison. You can use more data and train for more epochs to get better performance.
17
+
18
+
19
+ ### Web demo for animation
20
+ - Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/CVPR/Image-Animation-using-Thin-Plate-Spline-Motion-Model)
21
+ - Try the web demo for animation here: [![Replicate](https://replicate.com/yoyo-nb/thin-plate-spline-motion-model/badge)](https://replicate.com/yoyo-nb/thin-plate-spline-motion-model)
22
+ - Google Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DREfdpnaBhqISg0fuQlAAIwyGVn1loH_?usp=sharing)
23
+
24
+ ### Pre-trained models
25
+ - [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/30ab8765da364fefa101/)
26
+ - [Google Drive](https://drive.google.com/drive/folders/1pNDo1ODQIb5HVObRtCmubqJikmR7VVLT?usp=sharing)
27
+
28
+ ### Installation
29
+
30
+ We support ```python3```.(Recommended version is Python 3.9).
31
+ To install the dependencies run:
32
+ ```bash
33
+ pip install -r requirements.txt
34
+ ```
35
+
36
+
37
+ ### YAML configs
38
+
39
+ There are several configuration files one for each `dataset` in the `config` folder named as ```config/dataset_name.yaml```.
40
+
41
+ See description of the parameters in the ```config/taichi-256.yaml```.
42
+
43
+ ### Datasets
44
+
45
+ 1) **MGif**. Follow [Monkey-Net](https://github.com/AliaksandrSiarohin/monkey-net).
46
+
47
+ 2) **TaiChiHD** and **VoxCeleb**. Follow instructions from [video-preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing).
48
+
49
+ 3) **TED-talks**. Follow instructions from [MRAA](https://github.com/snap-research/articulated-animation).
50
+
51
+
52
+ ### Training
53
+ To train a model on specific dataset run:
54
+ ```
55
+ CUDA_VISIBLE_DEVICES=0,1 python run.py --config config/dataset_name.yaml --device_ids 0,1
56
+ ```
57
+ A log folder named after the timestamp will be created. Checkpoints, loss values, reconstruction results will be saved to this folder.
58
+
59
+
60
+ #### Training AVD network
61
+ To train a model on specific dataset run:
62
+ ```
63
+ CUDA_VISIBLE_DEVICES=0 python run.py --mode train_avd --checkpoint '{checkpoint_folder}/checkpoint.pth.tar' --config config/dataset_name.yaml
64
+ ```
65
+ Checkpoints, loss values, reconstruction results will be saved to `{checkpoint_folder}`.
66
+
67
+
68
+
69
+ ### Evaluation on video reconstruction
70
+
71
+ To evaluate the reconstruction performance run:
72
+ ```
73
+ CUDA_VISIBLE_DEVICES=0 python run.py --mode reconstruction --config config/dataset_name.yaml --checkpoint '{checkpoint_folder}/checkpoint.pth.tar'
74
+ ```
75
+ The `reconstruction` subfolder will be created in `{checkpoint_folder}`.
76
+ The generated video will be stored to this folder, also generated videos will be stored in ```png``` subfolder in loss-less '.png' format for evaluation.
77
+ To compute metrics, follow instructions from [pose-evaluation](https://github.com/AliaksandrSiarohin/pose-evaluation).
78
+
79
+
80
+ ### Image animation demo
81
+ - notebook: `demo.ipynb`, edit the config cell and run for image animation.
82
+ - python:
83
+ ```bash
84
+ CUDA_VISIBLE_DEVICES=0 python demo.py --config config/vox-256.yaml --checkpoint checkpoints/vox.pth.tar --source_image ./source.jpg --driving_video ./driving.mp4
85
+ ```
86
+
87
+ # Acknowledgments
88
+ The main code is based upon [FOMM](https://github.com/AliaksandrSiarohin/first-order-model) and [MRAA](https://github.com/snap-research/articulated-animation)
89
+
90
+ Thanks for the excellent works!
91
+
92
+ And Thanks to:
93
+
94
+ - [@chenxwh](https://github.com/chenxwh): Add Web Demo & Docker environment [![Replicate](https://replicate.com/yoyo-nb/thin-plate-spline-motion-model/badge)](https://replicate.com/yoyo-nb/thin-plate-spline-motion-model)
95
+
96
+ - [@TalkUHulk](https://github.com/TalkUHulk): The C++/Python demo is provided in [Image-Animation-Turbo-Boost](https://github.com/TalkUHulk/Image-Animation-Turbo-Boost)
97
+
98
+ - [@AK391](https://github.com/AK391): Add huggingface web demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/CVPR/Image-Animation-using-Thin-Plate-Spline-Motion-Model)
TPSMM/assets/source.png ADDED
TPSMM/assets/source1.png ADDED
TPSMM/assets/source2.jpg ADDED
TPSMM/augmentation.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code from https://github.com/hassony2/torch_videovision
3
+ """
4
+
5
+ import numbers
6
+
7
+ import random
8
+ import numpy as np
9
+ import PIL
10
+
11
+ from skimage.transform import resize, rotate
12
+ import torchvision
13
+
14
+ import warnings
15
+
16
+ from skimage import img_as_ubyte, img_as_float
17
+
18
+
19
+ def crop_clip(clip, min_h, min_w, h, w):
20
+ if isinstance(clip[0], np.ndarray):
21
+ cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
22
+
23
+ elif isinstance(clip[0], PIL.Image.Image):
24
+ cropped = [
25
+ img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
26
+ ]
27
+ else:
28
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
29
+ 'but got list of {0}'.format(type(clip[0])))
30
+ return cropped
31
+
32
+
33
+ def pad_clip(clip, h, w):
34
+ im_h, im_w = clip[0].shape[:2]
35
+ pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
36
+ pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
37
+
38
+ return np.pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
39
+
40
+
41
+ def resize_clip(clip, size, interpolation='bilinear'):
42
+ if isinstance(clip[0], np.ndarray):
43
+ if isinstance(size, numbers.Number):
44
+ im_h, im_w, im_c = clip[0].shape
45
+ # Min spatial dim already matches minimal size
46
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
47
+ and im_h == size):
48
+ return clip
49
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
50
+ size = (new_w, new_h)
51
+ else:
52
+ size = size[1], size[0]
53
+
54
+ scaled = [
55
+ resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
56
+ mode='constant', anti_aliasing=True) for img in clip
57
+ ]
58
+ elif isinstance(clip[0], PIL.Image.Image):
59
+ if isinstance(size, numbers.Number):
60
+ im_w, im_h = clip[0].size
61
+ # Min spatial dim already matches minimal size
62
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
63
+ and im_h == size):
64
+ return clip
65
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
66
+ size = (new_w, new_h)
67
+ else:
68
+ size = size[1], size[0]
69
+ if interpolation == 'bilinear':
70
+ pil_inter = PIL.Image.NEAREST
71
+ else:
72
+ pil_inter = PIL.Image.BILINEAR
73
+ scaled = [img.resize(size, pil_inter) for img in clip]
74
+ else:
75
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
76
+ 'but got list of {0}'.format(type(clip[0])))
77
+ return scaled
78
+
79
+
80
+ def get_resize_sizes(im_h, im_w, size):
81
+ if im_w < im_h:
82
+ ow = size
83
+ oh = int(size * im_h / im_w)
84
+ else:
85
+ oh = size
86
+ ow = int(size * im_w / im_h)
87
+ return oh, ow
88
+
89
+
90
+ class RandomFlip(object):
91
+ def __init__(self, time_flip=False, horizontal_flip=False):
92
+ self.time_flip = time_flip
93
+ self.horizontal_flip = horizontal_flip
94
+
95
+ def __call__(self, clip):
96
+ if random.random() < 0.5 and self.time_flip:
97
+ return clip[::-1]
98
+ if random.random() < 0.5 and self.horizontal_flip:
99
+ return [np.fliplr(img) for img in clip]
100
+
101
+ return clip
102
+
103
+
104
+ class RandomResize(object):
105
+ """Resizes a list of (H x W x C) numpy.ndarray to the final size
106
+ The larger the original image is, the more times it takes to
107
+ interpolate
108
+ Args:
109
+ interpolation (str): Can be one of 'nearest', 'bilinear'
110
+ defaults to nearest
111
+ size (tuple): (widht, height)
112
+ """
113
+
114
+ def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
115
+ self.ratio = ratio
116
+ self.interpolation = interpolation
117
+
118
+ def __call__(self, clip):
119
+ scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
120
+
121
+ if isinstance(clip[0], np.ndarray):
122
+ im_h, im_w, im_c = clip[0].shape
123
+ elif isinstance(clip[0], PIL.Image.Image):
124
+ im_w, im_h = clip[0].size
125
+
126
+ new_w = int(im_w * scaling_factor)
127
+ new_h = int(im_h * scaling_factor)
128
+ new_size = (new_w, new_h)
129
+ resized = resize_clip(
130
+ clip, new_size, interpolation=self.interpolation)
131
+
132
+ return resized
133
+
134
+
135
+ class RandomCrop(object):
136
+ """Extract random crop at the same location for a list of videos
137
+ Args:
138
+ size (sequence or int): Desired output size for the
139
+ crop in format (h, w)
140
+ """
141
+
142
+ def __init__(self, size):
143
+ if isinstance(size, numbers.Number):
144
+ size = (size, size)
145
+
146
+ self.size = size
147
+
148
+ def __call__(self, clip):
149
+ """
150
+ Args:
151
+ img (PIL.Image or numpy.ndarray): List of videos to be cropped
152
+ in format (h, w, c) in numpy.ndarray
153
+ Returns:
154
+ PIL.Image or numpy.ndarray: Cropped list of videos
155
+ """
156
+ h, w = self.size
157
+ if isinstance(clip[0], np.ndarray):
158
+ im_h, im_w, im_c = clip[0].shape
159
+ elif isinstance(clip[0], PIL.Image.Image):
160
+ im_w, im_h = clip[0].size
161
+ else:
162
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
163
+ 'but got list of {0}'.format(type(clip[0])))
164
+
165
+ clip = pad_clip(clip, h, w)
166
+ im_h, im_w = clip.shape[1:3]
167
+ x1 = 0 if h == im_h else random.randint(0, im_w - w)
168
+ y1 = 0 if w == im_w else random.randint(0, im_h - h)
169
+ cropped = crop_clip(clip, y1, x1, h, w)
170
+
171
+ return cropped
172
+
173
+
174
+ class RandomRotation(object):
175
+ """Rotate entire clip randomly by a random angle within
176
+ given bounds
177
+ Args:
178
+ degrees (sequence or int): Range of degrees to select from
179
+ If degrees is a number instead of sequence like (min, max),
180
+ the range of degrees, will be (-degrees, +degrees).
181
+ """
182
+
183
+ def __init__(self, degrees):
184
+ if isinstance(degrees, numbers.Number):
185
+ if degrees < 0:
186
+ raise ValueError('If degrees is a single number,'
187
+ 'must be positive')
188
+ degrees = (-degrees, degrees)
189
+ else:
190
+ if len(degrees) != 2:
191
+ raise ValueError('If degrees is a sequence,'
192
+ 'it must be of len 2.')
193
+
194
+ self.degrees = degrees
195
+
196
+ def __call__(self, clip):
197
+ """
198
+ Args:
199
+ img (PIL.Image or numpy.ndarray): List of videos to be cropped
200
+ in format (h, w, c) in numpy.ndarray
201
+ Returns:
202
+ PIL.Image or numpy.ndarray: Cropped list of videos
203
+ """
204
+ angle = random.uniform(self.degrees[0], self.degrees[1])
205
+ if isinstance(clip[0], np.ndarray):
206
+ rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
207
+ elif isinstance(clip[0], PIL.Image.Image):
208
+ rotated = [img.rotate(angle) for img in clip]
209
+ else:
210
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
211
+ 'but got list of {0}'.format(type(clip[0])))
212
+
213
+ return rotated
214
+
215
+
216
+ class ColorJitter(object):
217
+ """Randomly change the brightness, contrast and saturation and hue of the clip
218
+ Args:
219
+ brightness (float): How much to jitter brightness. brightness_factor
220
+ is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
221
+ contrast (float): How much to jitter contrast. contrast_factor
222
+ is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
223
+ saturation (float): How much to jitter saturation. saturation_factor
224
+ is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
225
+ hue(float): How much to jitter hue. hue_factor is chosen uniformly from
226
+ [-hue, hue]. Should be >=0 and <= 0.5.
227
+ """
228
+
229
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
230
+ self.brightness = brightness
231
+ self.contrast = contrast
232
+ self.saturation = saturation
233
+ self.hue = hue
234
+
235
+ def get_params(self, brightness, contrast, saturation, hue):
236
+ if brightness > 0:
237
+ brightness_factor = random.uniform(
238
+ max(0, 1 - brightness), 1 + brightness)
239
+ else:
240
+ brightness_factor = None
241
+
242
+ if contrast > 0:
243
+ contrast_factor = random.uniform(
244
+ max(0, 1 - contrast), 1 + contrast)
245
+ else:
246
+ contrast_factor = None
247
+
248
+ if saturation > 0:
249
+ saturation_factor = random.uniform(
250
+ max(0, 1 - saturation), 1 + saturation)
251
+ else:
252
+ saturation_factor = None
253
+
254
+ if hue > 0:
255
+ hue_factor = random.uniform(-hue, hue)
256
+ else:
257
+ hue_factor = None
258
+ return brightness_factor, contrast_factor, saturation_factor, hue_factor
259
+
260
+ def __call__(self, clip):
261
+ """
262
+ Args:
263
+ clip (list): list of PIL.Image
264
+ Returns:
265
+ list PIL.Image : list of transformed PIL.Image
266
+ """
267
+ if isinstance(clip[0], np.ndarray):
268
+ brightness, contrast, saturation, hue = self.get_params(
269
+ self.brightness, self.contrast, self.saturation, self.hue)
270
+
271
+ # Create img transform function sequence
272
+ img_transforms = []
273
+ if brightness is not None:
274
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
275
+ if saturation is not None:
276
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
277
+ if hue is not None:
278
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
279
+ if contrast is not None:
280
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
281
+ random.shuffle(img_transforms)
282
+ img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
283
+ img_as_float]
284
+
285
+ with warnings.catch_warnings():
286
+ warnings.simplefilter("ignore")
287
+ jittered_clip = []
288
+ for img in clip:
289
+ jittered_img = img
290
+ for func in img_transforms:
291
+ jittered_img = func(jittered_img)
292
+ jittered_clip.append(jittered_img.astype('float32'))
293
+ elif isinstance(clip[0], PIL.Image.Image):
294
+ brightness, contrast, saturation, hue = self.get_params(
295
+ self.brightness, self.contrast, self.saturation, self.hue)
296
+
297
+ # Create img transform function sequence
298
+ img_transforms = []
299
+ if brightness is not None:
300
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
301
+ if saturation is not None:
302
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
303
+ if hue is not None:
304
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
305
+ if contrast is not None:
306
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
307
+ random.shuffle(img_transforms)
308
+
309
+ # Apply to all videos
310
+ jittered_clip = []
311
+ for img in clip:
312
+ for func in img_transforms:
313
+ jittered_img = func(img)
314
+ jittered_clip.append(jittered_img)
315
+
316
+ else:
317
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
318
+ 'but got list of {0}'.format(type(clip[0])))
319
+ return jittered_clip
320
+
321
+
322
+ class AllAugmentationTransform:
323
+ def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
324
+ self.transforms = []
325
+
326
+ if flip_param is not None:
327
+ self.transforms.append(RandomFlip(**flip_param))
328
+
329
+ if rotation_param is not None:
330
+ self.transforms.append(RandomRotation(**rotation_param))
331
+
332
+ if resize_param is not None:
333
+ self.transforms.append(RandomResize(**resize_param))
334
+
335
+ if crop_param is not None:
336
+ self.transforms.append(RandomCrop(**crop_param))
337
+
338
+ if jitter_param is not None:
339
+ self.transforms.append(ColorJitter(**jitter_param))
340
+
341
+ def __call__(self, clip):
342
+ for t in self.transforms:
343
+ clip = t(clip)
344
+ return clip
TPSMM/cog.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ build:
2
+ cuda: "11.0"
3
+ gpu: true
4
+ python_version: "3.8"
5
+ system_packages:
6
+ - "libgl1-mesa-glx"
7
+ - "libglib2.0-0"
8
+ - "ninja-build"
9
+ python_packages:
10
+ - "ipython==7.21.0"
11
+ - "torch==1.10.1"
12
+ - "torchvision==0.11.2"
13
+ - "cffi==1.14.6"
14
+ - "cycler==0.10.0"
15
+ - "decorator==5.1.0"
16
+ - "face-alignment==1.3.5"
17
+ - "imageio==2.9.0"
18
+ - "imageio-ffmpeg==0.4.5"
19
+ - "kiwisolver==1.3.2"
20
+ - "matplotlib==3.4.3"
21
+ - "networkx==2.6.3"
22
+ - "numpy==1.20.3"
23
+ - "pandas==1.3.3"
24
+ - "Pillow==8.3.2"
25
+ - "pycparser==2.20"
26
+ - "pyparsing==2.4.7"
27
+ - "python-dateutil==2.8.2"
28
+ - "pytz==2021.1"
29
+ - "PyWavelets==1.1.1"
30
+ - "PyYAML==5.4.1"
31
+ - "scikit-image==0.18.3"
32
+ - "scikit-learn==1.0"
33
+ - "scipy==1.7.1"
34
+ - "six==1.16.0"
35
+ - "tqdm==4.62.3"
36
+ - "cmake==3.21.3"
37
+ run:
38
+ - pip install dlib
39
+
40
+ predict: "predict.py:Predictor"
TPSMM/config/mgif-256.yaml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_params:
2
+ root_dir: ../moving-gif
3
+ frame_shape: null
4
+ id_sampling: False
5
+ augmentation_params:
6
+ flip_param:
7
+ horizontal_flip: True
8
+ time_flip: True
9
+ crop_param:
10
+ size: [256, 256]
11
+ resize_param:
12
+ ratio: [0.9, 1.1]
13
+ jitter_param:
14
+ hue: 0.5
15
+
16
+ model_params:
17
+ common_params:
18
+ num_tps: 10
19
+ num_channels: 3
20
+ bg: False
21
+ multi_mask: True
22
+ generator_params:
23
+ block_expansion: 64
24
+ max_features: 512
25
+ num_down_blocks: 3
26
+ dense_motion_params:
27
+ block_expansion: 64
28
+ max_features: 1024
29
+ num_blocks: 5
30
+ scale_factor: 0.25
31
+ avd_network_params:
32
+ id_bottle_size: 128
33
+ pose_bottle_size: 128
34
+
35
+
36
+ train_params:
37
+ num_epochs: 100
38
+ num_repeats: 50
39
+ epoch_milestones: [70, 90]
40
+ lr_generator: 2.0e-4
41
+ batch_size: 28
42
+ scales: [1, 0.5, 0.25, 0.125]
43
+ dataloader_workers: 12
44
+ checkpoint_freq: 50
45
+ dropout_epoch: 35
46
+ dropout_maxp: 0.5
47
+ dropout_startp: 0.2
48
+ dropout_inc_epoch: 10
49
+ bg_start: 0
50
+ transform_params:
51
+ sigma_affine: 0.05
52
+ sigma_tps: 0.005
53
+ points_tps: 5
54
+ loss_weights:
55
+ perceptual: [10, 10, 10, 10, 10]
56
+ equivariance_value: 10
57
+ warp_loss: 10
58
+ bg: 10
59
+
60
+ train_avd_params:
61
+ num_epochs: 100
62
+ num_repeats: 50
63
+ batch_size: 256
64
+ dataloader_workers: 24
65
+ checkpoint_freq: 10
66
+ epoch_milestones: [70, 90]
67
+ lr: 1.0e-3
68
+ lambda_shift: 1
69
+ lambda_affine: 1
70
+ random_scale: 0.25
71
+
72
+ visualizer_params:
73
+ kp_size: 5
74
+ draw_border: True
75
+ colormap: 'gist_rainbow'
TPSMM/config/taichi-256.yaml ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataset parameters
2
+ # Each dataset should contain 2 folders train and test
3
+ # Each video can be represented as:
4
+ # - an image of concatenated frames
5
+ # - '.mp4' or '.gif'
6
+ # - folder with all frames from a specific video
7
+ # In case of Taichi. Same (youtube) video can be splitted in many parts (chunks). Each part has a following
8
+ # format (id)#other#info.mp4. For example '12335#adsbf.mp4' has an id 12335. In case of TaiChi id stands for youtube
9
+ # video id.
10
+ dataset_params:
11
+ # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
12
+ root_dir: ../taichi
13
+ # Image shape, needed for staked .png format.
14
+ frame_shape: null
15
+ # In case of TaiChi single video can be splitted in many chunks, or the maybe several videos for single person.
16
+ # In this case epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False)
17
+ # If the name of the video '12335#adsbf.mp4' the id is assumed to be 12335
18
+ id_sampling: True
19
+ # Augmentation parameters see augmentation.py for all posible augmentations
20
+ augmentation_params:
21
+ flip_param:
22
+ horizontal_flip: True
23
+ time_flip: True
24
+ jitter_param:
25
+ brightness: 0.1
26
+ contrast: 0.1
27
+ saturation: 0.1
28
+ hue: 0.1
29
+
30
+ # Defines model architecture
31
+ model_params:
32
+ common_params:
33
+ # Number of TPS transformation
34
+ num_tps: 10
35
+ # Number of channels per image
36
+ num_channels: 3
37
+ # Whether to estimate affine background transformation
38
+ bg: True
39
+ # Whether to estimate the multi-resolution occlusion masks
40
+ multi_mask: True
41
+ generator_params:
42
+ # Number of features mutliplier
43
+ block_expansion: 64
44
+ # Maximum allowed number of features
45
+ max_features: 512
46
+ # Number of downsampling blocks and Upsampling blocks.
47
+ num_down_blocks: 3
48
+ dense_motion_params:
49
+ # Number of features mutliplier
50
+ block_expansion: 64
51
+ # Maximum allowed number of features
52
+ max_features: 1024
53
+ # Number of block in Unet.
54
+ num_blocks: 5
55
+ # Optical flow is predicted on smaller images for better performance,
56
+ # scale_factor=0.25 means that 256x256 image will be resized to 64x64
57
+ scale_factor: 0.25
58
+ avd_network_params:
59
+ # Bottleneck for identity branch
60
+ id_bottle_size: 128
61
+ # Bottleneck for pose branch
62
+ pose_bottle_size: 128
63
+
64
+ # Parameters of training
65
+ train_params:
66
+ # Number of training epochs
67
+ num_epochs: 100
68
+ # For better i/o performance when number of videos is small number of epochs can be multiplied by this number.
69
+ # Thus effectivlly with num_repeats=100 each epoch is 100 times larger.
70
+ num_repeats: 150
71
+ # Drop learning rate by 10 times after this epochs
72
+ epoch_milestones: [70, 90]
73
+ # Initial learing rate for all modules
74
+ lr_generator: 2.0e-4
75
+ batch_size: 28
76
+ # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,
77
+ # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32.
78
+ scales: [1, 0.5, 0.25, 0.125]
79
+ # Dataset preprocessing cpu workers
80
+ dataloader_workers: 12
81
+ # Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs.
82
+ checkpoint_freq: 50
83
+ # Parameters of dropout
84
+ # The first dropout_epoch training uses dropout operation
85
+ dropout_epoch: 35
86
+ # The probability P will linearly increase from dropout_startp to dropout_maxp in dropout_inc_epoch epochs
87
+ dropout_maxp: 0.7
88
+ dropout_startp: 0.0
89
+ dropout_inc_epoch: 10
90
+ # Estimate affine background transformation from the bg_start epoch.
91
+ bg_start: 0
92
+ # Parameters of random TPS transformation for equivariance loss
93
+ transform_params:
94
+ # Sigma for affine part
95
+ sigma_affine: 0.05
96
+ # Sigma for deformation part
97
+ sigma_tps: 0.005
98
+ # Number of point in the deformation grid
99
+ points_tps: 5
100
+ loss_weights:
101
+ # Weights for perceptual loss.
102
+ perceptual: [10, 10, 10, 10, 10]
103
+ # Weights for value equivariance.
104
+ equivariance_value: 10
105
+ # Weights for warp loss.
106
+ warp_loss: 10
107
+ # Weights for bg loss.
108
+ bg: 10
109
+
110
+ # Parameters of training (animation-via-disentanglement)
111
+ train_avd_params:
112
+ # Number of training epochs, visualization is produced after each epoch.
113
+ num_epochs: 100
114
+ # For better i/o performance when number of videos is small number of epochs can be multiplied by this number.
115
+ # Thus effectively with num_repeats=100 each epoch is 100 times larger.
116
+ num_repeats: 150
117
+ # Batch size.
118
+ batch_size: 256
119
+ # Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs.
120
+ checkpoint_freq: 10
121
+ # Dataset preprocessing cpu workers
122
+ dataloader_workers: 24
123
+ # Drop learning rate 10 times after this epochs
124
+ epoch_milestones: [70, 90]
125
+ # Initial learning rate
126
+ lr: 1.0e-3
127
+ # Weights for equivariance loss.
128
+ lambda_shift: 1
129
+ random_scale: 0.25
130
+
131
+ visualizer_params:
132
+ kp_size: 5
133
+ draw_border: True
134
+ colormap: 'gist_rainbow'
TPSMM/config/ted-384.yaml ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_params:
2
+ root_dir: ../TED384-v2
3
+ frame_shape: null
4
+ id_sampling: True
5
+ augmentation_params:
6
+ flip_param:
7
+ horizontal_flip: True
8
+ time_flip: True
9
+ jitter_param:
10
+ brightness: 0.1
11
+ contrast: 0.1
12
+ saturation: 0.1
13
+ hue: 0.1
14
+
15
+ model_params:
16
+ common_params:
17
+ num_tps: 10
18
+ num_channels: 3
19
+ bg: True
20
+ multi_mask: True
21
+ generator_params:
22
+ block_expansion: 64
23
+ max_features: 512
24
+ num_down_blocks: 3
25
+ dense_motion_params:
26
+ block_expansion: 64
27
+ max_features: 1024
28
+ num_blocks: 5
29
+ scale_factor: 0.25
30
+ avd_network_params:
31
+ id_bottle_size: 128
32
+ pose_bottle_size: 128
33
+
34
+
35
+ train_params:
36
+ num_epochs: 100
37
+ num_repeats: 150
38
+ epoch_milestones: [70, 90]
39
+ lr_generator: 2.0e-4
40
+ batch_size: 12
41
+ scales: [1, 0.5, 0.25, 0.125]
42
+ dataloader_workers: 6
43
+ checkpoint_freq: 50
44
+ dropout_epoch: 35
45
+ dropout_maxp: 0.5
46
+ dropout_startp: 0.0
47
+ dropout_inc_epoch: 10
48
+ bg_start: 0
49
+ transform_params:
50
+ sigma_affine: 0.05
51
+ sigma_tps: 0.005
52
+ points_tps: 5
53
+ loss_weights:
54
+ perceptual: [10, 10, 10, 10, 10]
55
+ equivariance_value: 10
56
+ warp_loss: 10
57
+ bg: 10
58
+
59
+ train_avd_params:
60
+ num_epochs: 30
61
+ num_repeats: 500
62
+ batch_size: 256
63
+ dataloader_workers: 24
64
+ checkpoint_freq: 10
65
+ epoch_milestones: [20, 25]
66
+ lr: 1.0e-3
67
+ lambda_shift: 1
68
+ random_scale: 0.25
69
+
70
+ visualizer_params:
71
+ kp_size: 5
72
+ draw_border: True
73
+ colormap: 'gist_rainbow'
TPSMM/config/vox-256.yaml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_params:
2
+ root_dir: ../vox
3
+ frame_shape: null
4
+ id_sampling: True
5
+ augmentation_params:
6
+ flip_param:
7
+ horizontal_flip: True
8
+ time_flip: True
9
+ jitter_param:
10
+ brightness: 0.1
11
+ contrast: 0.1
12
+ saturation: 0.1
13
+ hue: 0.1
14
+
15
+
16
+ model_params:
17
+ common_params:
18
+ num_tps: 10
19
+ num_channels: 3
20
+ bg: True
21
+ multi_mask: True
22
+ generator_params:
23
+ block_expansion: 64
24
+ max_features: 512
25
+ num_down_blocks: 3
26
+ dense_motion_params:
27
+ block_expansion: 64
28
+ max_features: 1024
29
+ num_blocks: 5
30
+ scale_factor: 0.25
31
+ avd_network_params:
32
+ id_bottle_size: 128
33
+ pose_bottle_size: 128
34
+
35
+
36
+ train_params:
37
+ num_epochs: 100
38
+ num_repeats: 75
39
+ epoch_milestones: [70, 90]
40
+ lr_generator: 2.0e-4
41
+ batch_size: 28
42
+ scales: [1, 0.5, 0.25, 0.125]
43
+ dataloader_workers: 12
44
+ checkpoint_freq: 50
45
+ dropout_epoch: 35
46
+ dropout_maxp: 0.3
47
+ dropout_startp: 0.1
48
+ dropout_inc_epoch: 10
49
+ bg_start: 10
50
+ transform_params:
51
+ sigma_affine: 0.05
52
+ sigma_tps: 0.005
53
+ points_tps: 5
54
+ loss_weights:
55
+ perceptual: [10, 10, 10, 10, 10]
56
+ equivariance_value: 10
57
+ warp_loss: 10
58
+ bg: 10
59
+
60
+ train_avd_params:
61
+ num_epochs: 200
62
+ num_repeats: 300
63
+ batch_size: 256
64
+ dataloader_workers: 24
65
+ checkpoint_freq: 50
66
+ epoch_milestones: [140, 180]
67
+ lr: 1.0e-3
68
+ lambda_shift: 1
69
+ random_scale: 0.25
70
+
71
+ visualizer_params:
72
+ kp_size: 5
73
+ draw_border: True
74
+ colormap: 'gist_rainbow'
TPSMM/demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
TPSMM/demo.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ matplotlib.use('Agg')
3
+ import sys
4
+ import yaml
5
+ from argparse import ArgumentParser
6
+ from tqdm import tqdm
7
+ from scipy.spatial import ConvexHull
8
+ import numpy as np
9
+ import imageio
10
+ from skimage.transform import resize
11
+ from skimage import img_as_ubyte
12
+ import torch
13
+ from modules.inpainting_network import InpaintingNetwork
14
+ from modules.keypoint_detector import KPDetector
15
+ from modules.dense_motion import DenseMotionNetwork
16
+ from modules.avd_network import AVDNetwork
17
+
18
+ if sys.version_info[0] < 3:
19
+ raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")
20
+
21
+ def relative_kp(kp_source, kp_driving, kp_driving_initial):
22
+
23
+ source_area = ConvexHull(kp_source['fg_kp'][0].data.cpu().numpy()).volume
24
+ driving_area = ConvexHull(kp_driving_initial['fg_kp'][0].data.cpu().numpy()).volume
25
+ adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
26
+
27
+ kp_new = {k: v for k, v in kp_driving.items()}
28
+
29
+ kp_value_diff = (kp_driving['fg_kp'] - kp_driving_initial['fg_kp'])
30
+ kp_value_diff *= adapt_movement_scale
31
+ kp_new['fg_kp'] = kp_value_diff + kp_source['fg_kp']
32
+
33
+ return kp_new
34
+
35
+ def load_checkpoints(config_path, checkpoint_path, device):
36
+ with open(config_path) as f:
37
+ config = yaml.full_load(f)
38
+
39
+ inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
40
+ **config['model_params']['common_params'])
41
+ kp_detector = KPDetector(**config['model_params']['common_params'])
42
+ dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
43
+ **config['model_params']['dense_motion_params'])
44
+ avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
45
+ **config['model_params']['avd_network_params'])
46
+ kp_detector.to(device)
47
+ dense_motion_network.to(device)
48
+ inpainting.to(device)
49
+ avd_network.to(device)
50
+
51
+ checkpoint = torch.load(checkpoint_path, map_location=device)
52
+
53
+ inpainting.load_state_dict(checkpoint['inpainting_network'])
54
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
55
+ dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
56
+ if 'avd_network' in checkpoint:
57
+ avd_network.load_state_dict(checkpoint['avd_network'])
58
+
59
+ inpainting.eval()
60
+ kp_detector.eval()
61
+ dense_motion_network.eval()
62
+ avd_network.eval()
63
+
64
+ return inpainting, kp_detector, dense_motion_network, avd_network
65
+
66
+
67
+ def make_animation(source_image, driving_video, inpainting_network, kp_detector, dense_motion_network, avd_network, device, mode = 'relative'):
68
+ assert mode in ['standard', 'relative', 'avd']
69
+ with torch.no_grad():
70
+ predictions = []
71
+ source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
72
+ source = source.to(device)
73
+ driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device)
74
+ kp_source = kp_detector(source)
75
+ kp_driving_initial = kp_detector(driving[:, :, 0])
76
+
77
+ for frame_idx in tqdm(range(driving.shape[2])):
78
+ driving_frame = driving[:, :, frame_idx]
79
+ driving_frame = driving_frame.to(device)
80
+ kp_driving = kp_detector(driving_frame)
81
+ if mode == 'standard':
82
+ kp_norm = kp_driving
83
+ elif mode=='relative':
84
+ kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving,
85
+ kp_driving_initial=kp_driving_initial)
86
+ elif mode == 'avd':
87
+ kp_norm = avd_network(kp_source, kp_driving)
88
+ dense_motion = dense_motion_network(source_image=source, kp_driving=kp_norm,
89
+ kp_source=kp_source, bg_param = None,
90
+ dropout_flag = False)
91
+ out = inpainting_network(source, dense_motion)
92
+
93
+ predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
94
+ return predictions
95
+
96
+
97
+ def find_best_frame(source, driving, cpu):
98
+ import face_alignment
99
+
100
+ def normalize_kp(kp):
101
+ kp = kp - kp.mean(axis=0, keepdims=True)
102
+ area = ConvexHull(kp[:, :2]).volume
103
+ area = np.sqrt(area)
104
+ kp[:, :2] = kp[:, :2] / area
105
+ return kp
106
+
107
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
108
+ device= 'cpu' if cpu else 'cuda')
109
+ kp_source = fa.get_landmarks(255 * source)[0]
110
+ kp_source = normalize_kp(kp_source)
111
+ norm = float('inf')
112
+ frame_num = 0
113
+ for i, image in tqdm(enumerate(driving)):
114
+ try:
115
+ kp_driving = fa.get_landmarks(255 * image)[0]
116
+ kp_driving = normalize_kp(kp_driving)
117
+ new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
118
+ if new_norm < norm:
119
+ norm = new_norm
120
+ frame_num = i
121
+ except:
122
+ pass
123
+ return frame_num
124
+
125
+
126
+ if __name__ == "__main__":
127
+ parser = ArgumentParser()
128
+ parser.add_argument("--config", required=True, help="path to config")
129
+ parser.add_argument("--checkpoint", default='checkpoints/vox.pth.tar', help="path to checkpoint to restore")
130
+
131
+ parser.add_argument("--source_image", default='./assets/source.png', help="path to source image")
132
+ parser.add_argument("--driving_video", default='./assets/driving.mp4', help="path to driving video")
133
+ parser.add_argument("--result_video", default='./result.mp4', help="path to output")
134
+
135
+ parser.add_argument("--img_shape", default="256,256", type=lambda x: list(map(int, x.split(','))),
136
+ help='Shape of image, that the model was trained on.')
137
+
138
+ parser.add_argument("--mode", default='relative', choices=['standard', 'relative', 'avd'], help="Animate mode: ['standard', 'relative', 'avd'], when use the relative mode to animate a face, use '--find_best_frame' can get better quality result")
139
+
140
+ parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
141
+ help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)")
142
+
143
+ parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
144
+
145
+ opt = parser.parse_args()
146
+
147
+ source_image = imageio.imread(opt.source_image)
148
+ reader = imageio.get_reader(opt.driving_video)
149
+ fps = reader.get_meta_data()['fps']
150
+ driving_video = []
151
+ try:
152
+ for im in reader:
153
+ driving_video.append(im)
154
+ except RuntimeError:
155
+ pass
156
+ reader.close()
157
+
158
+ if opt.cpu:
159
+ device = torch.device('cpu')
160
+ else:
161
+ device = torch.device('cuda')
162
+
163
+ source_image = resize(source_image, opt.img_shape)[..., :3]
164
+ driving_video = [resize(frame, opt.img_shape)[..., :3] for frame in driving_video]
165
+ inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = opt.config, checkpoint_path = opt.checkpoint, device = device)
166
+
167
+ if opt.find_best_frame:
168
+ i = find_best_frame(source_image, driving_video, opt.cpu)
169
+ print ("Best frame: " + str(i))
170
+ driving_forward = driving_video[i:]
171
+ driving_backward = driving_video[:(i+1)][::-1]
172
+ predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = opt.mode)
173
+ predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = opt.mode)
174
+ predictions = predictions_backward[::-1] + predictions_forward[1:]
175
+ else:
176
+ predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = opt.mode)
177
+
178
+ predictions =[img_as_ubyte(frame) for frame in predictions]
179
+ imageio.mimsave(opt.result_video, predictions, fps=fps)
180
+
TPSMM/frames_dataset.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from skimage import io, img_as_float32
3
+ from skimage.color import gray2rgb
4
+ from sklearn.model_selection import train_test_split
5
+ from imageio import mimread
6
+ from skimage.transform import resize
7
+ import numpy as np
8
+ from torch.utils.data import Dataset
9
+ from augmentation import AllAugmentationTransform
10
+ import glob
11
+ from functools import partial
12
+
13
+
14
+ def read_video(name, frame_shape):
15
+ """
16
+ Read video which can be:
17
+ - an image of concatenated frames
18
+ - '.mp4' and'.gif'
19
+ - folder with videos
20
+ """
21
+
22
+ if os.path.isdir(name):
23
+ frames = sorted(os.listdir(name))
24
+ num_frames = len(frames)
25
+ video_array = np.array(
26
+ [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
27
+ elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
28
+ image = io.imread(name)
29
+
30
+ if len(image.shape) == 2 or image.shape[2] == 1:
31
+ image = gray2rgb(image)
32
+
33
+ if image.shape[2] == 4:
34
+ image = image[..., :3]
35
+
36
+ image = img_as_float32(image)
37
+
38
+ video_array = np.moveaxis(image, 1, 0)
39
+
40
+ video_array = video_array.reshape((-1,) + frame_shape)
41
+ video_array = np.moveaxis(video_array, 1, 2)
42
+ elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
43
+ video = mimread(name)
44
+ if len(video[0].shape) == 2:
45
+ video = [gray2rgb(frame) for frame in video]
46
+ if frame_shape is not None:
47
+ video = np.array([resize(frame, frame_shape) for frame in video])
48
+ video = np.array(video)
49
+ if video.shape[-1] == 4:
50
+ video = video[..., :3]
51
+ video_array = img_as_float32(video)
52
+ else:
53
+ raise Exception("Unknown file extensions %s" % name)
54
+
55
+ return video_array
56
+
57
+
58
+ class FramesDataset(Dataset):
59
+ """
60
+ Dataset of videos, each video can be represented as:
61
+ - an image of concatenated frames
62
+ - '.mp4' or '.gif'
63
+ - folder with all frames
64
+ """
65
+
66
+ def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
67
+ random_seed=0, pairs_list=None, augmentation_params=None):
68
+ self.root_dir = root_dir
69
+ self.videos = os.listdir(root_dir)
70
+ self.frame_shape = frame_shape
71
+ print(self.frame_shape)
72
+ self.pairs_list = pairs_list
73
+ self.id_sampling = id_sampling
74
+
75
+ if os.path.exists(os.path.join(root_dir, 'train')):
76
+ assert os.path.exists(os.path.join(root_dir, 'test'))
77
+ print("Use predefined train-test split.")
78
+ if id_sampling:
79
+ train_videos = {os.path.basename(video).split('#')[0] for video in
80
+ os.listdir(os.path.join(root_dir, 'train'))}
81
+ train_videos = list(train_videos)
82
+ else:
83
+ train_videos = os.listdir(os.path.join(root_dir, 'train'))
84
+ test_videos = os.listdir(os.path.join(root_dir, 'test'))
85
+ self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
86
+ else:
87
+ print("Use random train-test split.")
88
+ train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
89
+
90
+ if is_train:
91
+ self.videos = train_videos
92
+ else:
93
+ self.videos = test_videos
94
+
95
+ self.is_train = is_train
96
+
97
+ if self.is_train:
98
+ self.transform = AllAugmentationTransform(**augmentation_params)
99
+ else:
100
+ self.transform = None
101
+
102
+ def __len__(self):
103
+ return len(self.videos)
104
+
105
+ def __getitem__(self, idx):
106
+
107
+ if self.is_train and self.id_sampling:
108
+ name = self.videos[idx]
109
+ path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
110
+ else:
111
+ name = self.videos[idx]
112
+ path = os.path.join(self.root_dir, name)
113
+
114
+ video_name = os.path.basename(path)
115
+ if self.is_train and os.path.isdir(path):
116
+
117
+ frames = os.listdir(path)
118
+ num_frames = len(frames)
119
+ frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
120
+
121
+ if self.frame_shape is not None:
122
+ resize_fn = partial(resize, output_shape=self.frame_shape)
123
+ else:
124
+ resize_fn = img_as_float32
125
+
126
+ if type(frames[0]) is bytes:
127
+ video_array = [resize_fn(io.imread(os.path.join(path, frames[idx].decode('utf-8')))) for idx in
128
+ frame_idx]
129
+ else:
130
+ video_array = [resize_fn(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
131
+ else:
132
+
133
+ video_array = read_video(path, frame_shape=self.frame_shape)
134
+
135
+ num_frames = len(video_array)
136
+ frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
137
+ num_frames)
138
+ video_array = video_array[frame_idx]
139
+
140
+
141
+ if self.transform is not None:
142
+ video_array = self.transform(video_array)
143
+
144
+ out = {}
145
+ if self.is_train:
146
+ source = np.array(video_array[0], dtype='float32')
147
+ driving = np.array(video_array[1], dtype='float32')
148
+
149
+ out['driving'] = driving.transpose((2, 0, 1))
150
+ out['source'] = source.transpose((2, 0, 1))
151
+ else:
152
+ video = np.array(video_array, dtype='float32')
153
+ out['video'] = video.transpose((3, 0, 1, 2))
154
+
155
+ out['name'] = video_name
156
+ return out
157
+
158
+
159
+ class DatasetRepeater(Dataset):
160
+ """
161
+ Pass several times over the same dataset for better i/o performance
162
+ """
163
+
164
+ def __init__(self, dataset, num_repeats=100):
165
+ self.dataset = dataset
166
+ self.num_repeats = num_repeats
167
+
168
+ def __len__(self):
169
+ return self.num_repeats * self.dataset.__len__()
170
+
171
+ def __getitem__(self, idx):
172
+ return self.dataset[idx % self.dataset.__len__()]
173
+
TPSMM/logger.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import imageio
5
+
6
+ import os
7
+ from skimage.draw import circle
8
+
9
+ import matplotlib.pyplot as plt
10
+ import collections
11
+
12
+
13
+ class Logger:
14
+ def __init__(self, log_dir, checkpoint_freq=50, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
15
+
16
+ self.loss_list = []
17
+ self.cpk_dir = log_dir
18
+ self.visualizations_dir = os.path.join(log_dir, 'train-vis')
19
+ if not os.path.exists(self.visualizations_dir):
20
+ os.makedirs(self.visualizations_dir)
21
+ self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
22
+ self.zfill_num = zfill_num
23
+ self.visualizer = Visualizer(**visualizer_params)
24
+ self.checkpoint_freq = checkpoint_freq
25
+ self.epoch = 0
26
+ self.best_loss = float('inf')
27
+ self.names = None
28
+
29
+ def log_scores(self, loss_names):
30
+ loss_mean = np.array(self.loss_list).mean(axis=0)
31
+
32
+ loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
33
+ loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
34
+
35
+ print(loss_string, file=self.log_file)
36
+ self.loss_list = []
37
+ self.log_file.flush()
38
+
39
+ def visualize_rec(self, inp, out):
40
+ image = self.visualizer.visualize(inp['driving'], inp['source'], out)
41
+ imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
42
+
43
+ def save_cpk(self, emergent=False):
44
+ cpk = {k: v.state_dict() for k, v in self.models.items()}
45
+ cpk['epoch'] = self.epoch
46
+ cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
47
+ if not (os.path.exists(cpk_path) and emergent):
48
+ torch.save(cpk, cpk_path)
49
+
50
+ @staticmethod
51
+ def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network =None, kp_detector=None,
52
+ bg_predictor=None, avd_network=None, optimizer=None, optimizer_bg_predictor=None,
53
+ optimizer_avd=None):
54
+ checkpoint = torch.load(checkpoint_path)
55
+ if inpainting_network is not None:
56
+ inpainting_network.load_state_dict(checkpoint['inpainting_network'])
57
+ if kp_detector is not None:
58
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
59
+ if bg_predictor is not None and 'bg_predictor' in checkpoint:
60
+ bg_predictor.load_state_dict(checkpoint['bg_predictor'])
61
+ if dense_motion_network is not None:
62
+ dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
63
+ if avd_network is not None:
64
+ if 'avd_network' in checkpoint:
65
+ avd_network.load_state_dict(checkpoint['avd_network'])
66
+ if optimizer_bg_predictor is not None and 'optimizer_bg_predictor' in checkpoint:
67
+ optimizer_bg_predictor.load_state_dict(checkpoint['optimizer_bg_predictor'])
68
+ if optimizer is not None and 'optimizer' in checkpoint:
69
+ optimizer.load_state_dict(checkpoint['optimizer'])
70
+ if optimizer_avd is not None:
71
+ if 'optimizer_avd' in checkpoint:
72
+ optimizer_avd.load_state_dict(checkpoint['optimizer_avd'])
73
+ epoch = -1
74
+ if 'epoch' in checkpoint:
75
+ epoch = checkpoint['epoch']
76
+ return epoch
77
+
78
+ def __enter__(self):
79
+ return self
80
+
81
+ def __exit__(self, exc_type, exc_value, tb):
82
+ if 'models' in self.__dict__:
83
+ self.save_cpk()
84
+ self.log_file.close()
85
+
86
+ def log_iter(self, losses):
87
+ losses = collections.OrderedDict(losses.items())
88
+ self.names = list(losses.keys())
89
+ self.loss_list.append(list(losses.values()))
90
+
91
+ def log_epoch(self, epoch, models, inp, out):
92
+ self.epoch = epoch
93
+ self.models = models
94
+ if (self.epoch + 1) % self.checkpoint_freq == 0:
95
+ self.save_cpk()
96
+ self.log_scores(self.names)
97
+ self.visualize_rec(inp, out)
98
+
99
+
100
+ class Visualizer:
101
+ def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
102
+ self.kp_size = kp_size
103
+ self.draw_border = draw_border
104
+ self.colormap = plt.get_cmap(colormap)
105
+
106
+ def draw_image_with_kp(self, image, kp_array):
107
+ image = np.copy(image)
108
+ spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
109
+ kp_array = spatial_size * (kp_array + 1) / 2
110
+ num_kp = kp_array.shape[0]
111
+ for kp_ind, kp in enumerate(kp_array):
112
+ rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
113
+ image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
114
+ return image
115
+
116
+ def create_image_column_with_kp(self, images, kp):
117
+ image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
118
+ return self.create_image_column(image_array)
119
+
120
+ def create_image_column(self, images):
121
+ if self.draw_border:
122
+ images = np.copy(images)
123
+ images[:, :, [0, -1]] = (1, 1, 1)
124
+ images[:, :, [0, -1]] = (1, 1, 1)
125
+ return np.concatenate(list(images), axis=0)
126
+
127
+ def create_image_grid(self, *args):
128
+ out = []
129
+ for arg in args:
130
+ if type(arg) == tuple:
131
+ out.append(self.create_image_column_with_kp(arg[0], arg[1]))
132
+ else:
133
+ out.append(self.create_image_column(arg))
134
+ return np.concatenate(out, axis=1)
135
+
136
+ def visualize(self, driving, source, out):
137
+ images = []
138
+
139
+ # Source image with keypoints
140
+ source = source.data.cpu()
141
+ kp_source = out['kp_source']['fg_kp'].data.cpu().numpy()
142
+ source = np.transpose(source, [0, 2, 3, 1])
143
+ images.append((source, kp_source))
144
+
145
+ # Equivariance visualization
146
+ if 'transformed_frame' in out:
147
+ transformed = out['transformed_frame'].data.cpu().numpy()
148
+ transformed = np.transpose(transformed, [0, 2, 3, 1])
149
+ transformed_kp = out['transformed_kp']['fg_kp'].data.cpu().numpy()
150
+ images.append((transformed, transformed_kp))
151
+
152
+ # Driving image with keypoints
153
+ kp_driving = out['kp_driving']['fg_kp'].data.cpu().numpy()
154
+ driving = driving.data.cpu().numpy()
155
+ driving = np.transpose(driving, [0, 2, 3, 1])
156
+ images.append((driving, kp_driving))
157
+
158
+ # Deformed image
159
+ if 'deformed' in out:
160
+ deformed = out['deformed'].data.cpu().numpy()
161
+ deformed = np.transpose(deformed, [0, 2, 3, 1])
162
+ images.append(deformed)
163
+
164
+ # Result with and without keypoints
165
+ prediction = out['prediction'].data.cpu().numpy()
166
+ prediction = np.transpose(prediction, [0, 2, 3, 1])
167
+ if 'kp_norm' in out:
168
+ kp_norm = out['kp_norm']['fg_kp'].data.cpu().numpy()
169
+ images.append((prediction, kp_norm))
170
+ images.append(prediction)
171
+
172
+
173
+ ## Occlusion map
174
+ if 'occlusion_map' in out:
175
+ for i in range(len(out['occlusion_map'])):
176
+ occlusion_map = out['occlusion_map'][i].data.cpu().repeat(1, 3, 1, 1)
177
+ occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
178
+ occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
179
+ images.append(occlusion_map)
180
+
181
+ # Deformed images according to each individual transform
182
+ if 'deformed_source' in out:
183
+ full_mask = []
184
+ for i in range(out['deformed_source'].shape[1]):
185
+ image = out['deformed_source'][:, i].data.cpu()
186
+ # import ipdb;ipdb.set_trace()
187
+ image = F.interpolate(image, size=source.shape[1:3])
188
+ mask = out['contribution_maps'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
189
+ mask = F.interpolate(mask, size=source.shape[1:3])
190
+ image = np.transpose(image.numpy(), (0, 2, 3, 1))
191
+ mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
192
+
193
+ if i != 0:
194
+ color = np.array(self.colormap((i - 1) / (out['deformed_source'].shape[1] - 1)))[:3]
195
+ else:
196
+ color = np.array((0, 0, 0))
197
+
198
+ color = color.reshape((1, 1, 1, 3))
199
+
200
+ images.append(image)
201
+ if i != 0:
202
+ images.append(mask * color)
203
+ else:
204
+ images.append(mask)
205
+
206
+ full_mask.append(mask * color)
207
+
208
+ images.append(sum(full_mask))
209
+
210
+ image = self.create_image_grid(*images)
211
+ image = (255 * image).astype(np.uint8)
212
+ return image
TPSMM/modules/avd_network.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class AVDNetwork(nn.Module):
7
+ """
8
+ Animation via Disentanglement network
9
+ """
10
+
11
+ def __init__(self, num_tps, id_bottle_size=64, pose_bottle_size=64):
12
+ super(AVDNetwork, self).__init__()
13
+ input_size = 5*2 * num_tps
14
+ self.num_tps = num_tps
15
+
16
+ self.id_encoder = nn.Sequential(
17
+ nn.Linear(input_size, 256),
18
+ nn.BatchNorm1d(256),
19
+ nn.ReLU(inplace=True),
20
+ nn.Linear(256, 512),
21
+ nn.BatchNorm1d(512),
22
+ nn.ReLU(inplace=True),
23
+ nn.Linear(512, 1024),
24
+ nn.BatchNorm1d(1024),
25
+ nn.ReLU(inplace=True),
26
+ nn.Linear(1024, id_bottle_size)
27
+ )
28
+
29
+ self.pose_encoder = nn.Sequential(
30
+ nn.Linear(input_size, 256),
31
+ nn.BatchNorm1d(256),
32
+ nn.ReLU(inplace=True),
33
+ nn.Linear(256, 512),
34
+ nn.BatchNorm1d(512),
35
+ nn.ReLU(inplace=True),
36
+ nn.Linear(512, 1024),
37
+ nn.BatchNorm1d(1024),
38
+ nn.ReLU(inplace=True),
39
+ nn.Linear(1024, pose_bottle_size)
40
+ )
41
+
42
+ self.decoder = nn.Sequential(
43
+ nn.Linear(pose_bottle_size + id_bottle_size, 1024),
44
+ nn.BatchNorm1d(1024),
45
+ nn.ReLU(),
46
+ nn.Linear(1024, 512),
47
+ nn.BatchNorm1d(512),
48
+ nn.ReLU(),
49
+ nn.Linear(512, 256),
50
+ nn.BatchNorm1d(256),
51
+ nn.ReLU(),
52
+ nn.Linear(256, input_size)
53
+ )
54
+
55
+ def forward(self, kp_source, kp_random):
56
+
57
+ bs = kp_source['fg_kp'].shape[0]
58
+
59
+ pose_emb = self.pose_encoder(kp_random['fg_kp'].view(bs, -1))
60
+ id_emb = self.id_encoder(kp_source['fg_kp'].view(bs, -1))
61
+
62
+ rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1))
63
+
64
+ rec = {'fg_kp': rec.view(bs, self.num_tps*5, -1)}
65
+ return rec
TPSMM/modules/bg_motion_predictor.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ from torchvision import models
4
+
5
+ class BGMotionPredictor(nn.Module):
6
+ """
7
+ Module for background estimation, return single transformation, parametrized as 3x3 matrix. The third row is [0 0 1]
8
+ """
9
+
10
+ def __init__(self):
11
+ super(BGMotionPredictor, self).__init__()
12
+ self.bg_encoder = models.resnet18(pretrained=False)
13
+ self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
14
+ num_features = self.bg_encoder.fc.in_features
15
+ self.bg_encoder.fc = nn.Linear(num_features, 6)
16
+ self.bg_encoder.fc.weight.data.zero_()
17
+ self.bg_encoder.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
18
+
19
+ def forward(self, source_image, driving_image):
20
+ bs = source_image.shape[0]
21
+ out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type())
22
+ prediction = self.bg_encoder(torch.cat([source_image, driving_image], dim=1))
23
+ out[:, :2, :] = prediction.view(bs, 2, 3)
24
+ return out
TPSMM/modules/dense_motion.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
5
+ from modules.util import to_homogeneous, from_homogeneous, UpBlock2d, TPS
6
+ import math
7
+
8
+ class DenseMotionNetwork(nn.Module):
9
+ """
10
+ Module that estimating an optical flow and multi-resolution occlusion masks
11
+ from K TPS transformations and an affine transformation.
12
+ """
13
+
14
+ def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_channels,
15
+ scale_factor=0.25, bg = False, multi_mask = True, kp_variance=0.01):
16
+ super(DenseMotionNetwork, self).__init__()
17
+
18
+ if scale_factor != 1:
19
+ self.down = AntiAliasInterpolation2d(num_channels, scale_factor)
20
+ self.scale_factor = scale_factor
21
+ self.multi_mask = multi_mask
22
+
23
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_channels * (num_tps+1) + num_tps*5+1),
24
+ max_features=max_features, num_blocks=num_blocks)
25
+
26
+ hourglass_output_size = self.hourglass.out_channels
27
+ self.maps = nn.Conv2d(hourglass_output_size[-1], num_tps + 1, kernel_size=(7, 7), padding=(3, 3))
28
+
29
+ if multi_mask:
30
+ up = []
31
+ self.up_nums = int(math.log(1/scale_factor, 2))
32
+ self.occlusion_num = 4
33
+
34
+ channel = [hourglass_output_size[-1]//(2**i) for i in range(self.up_nums)]
35
+ for i in range(self.up_nums):
36
+ up.append(UpBlock2d(channel[i], channel[i]//2, kernel_size=3, padding=1))
37
+ self.up = nn.ModuleList(up)
38
+
39
+ channel = [hourglass_output_size[-i-1] for i in range(self.occlusion_num-self.up_nums)[::-1]]
40
+ for i in range(self.up_nums):
41
+ channel.append(hourglass_output_size[-1]//(2**(i+1)))
42
+ occlusion = []
43
+
44
+ for i in range(self.occlusion_num):
45
+ occlusion.append(nn.Conv2d(channel[i], 1, kernel_size=(7, 7), padding=(3, 3)))
46
+ self.occlusion = nn.ModuleList(occlusion)
47
+ else:
48
+ occlusion = [nn.Conv2d(hourglass_output_size[-1], 1, kernel_size=(7, 7), padding=(3, 3))]
49
+ self.occlusion = nn.ModuleList(occlusion)
50
+
51
+ self.num_tps = num_tps
52
+ self.bg = bg
53
+ self.kp_variance = kp_variance
54
+
55
+
56
+ def create_heatmap_representations(self, source_image, kp_driving, kp_source):
57
+
58
+ spatial_size = source_image.shape[2:]
59
+ gaussian_driving = kp2gaussian(kp_driving['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
60
+ gaussian_source = kp2gaussian(kp_source['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
61
+ heatmap = gaussian_driving - gaussian_source
62
+
63
+ zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()).to(heatmap.device)
64
+ heatmap = torch.cat([zeros, heatmap], dim=1)
65
+
66
+ return heatmap
67
+
68
+ def create_transformations(self, source_image, kp_driving, kp_source, bg_param):
69
+ # K TPS transformaions
70
+ bs, _, h, w = source_image.shape
71
+ kp_1 = kp_driving['fg_kp']
72
+ kp_2 = kp_source['fg_kp']
73
+ kp_1 = kp_1.view(bs, -1, 5, 2)
74
+ kp_2 = kp_2.view(bs, -1, 5, 2)
75
+ trans = TPS(mode = 'kp', bs = bs, kp_1 = kp_1, kp_2 = kp_2)
76
+ driving_to_source = trans.transform_frame(source_image)
77
+
78
+ identity_grid = make_coordinate_grid((h, w), type=kp_1.type()).to(kp_1.device)
79
+ identity_grid = identity_grid.view(1, 1, h, w, 2)
80
+ identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
81
+
82
+ # affine background transformation
83
+ if not (bg_param is None):
84
+ identity_grid = to_homogeneous(identity_grid)
85
+ identity_grid = torch.matmul(bg_param.view(bs, 1, 1, 1, 3, 3), identity_grid.unsqueeze(-1)).squeeze(-1)
86
+ identity_grid = from_homogeneous(identity_grid)
87
+
88
+ transformations = torch.cat([identity_grid, driving_to_source], dim=1)
89
+ return transformations
90
+
91
+ def create_deformed_source_image(self, source_image, transformations):
92
+
93
+ bs, _, h, w = source_image.shape
94
+ source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_tps + 1, 1, 1, 1, 1)
95
+ source_repeat = source_repeat.view(bs * (self.num_tps + 1), -1, h, w)
96
+ transformations = transformations.view((bs * (self.num_tps + 1), h, w, -1))
97
+ deformed = F.grid_sample(source_repeat, transformations, align_corners=True)
98
+ deformed = deformed.view((bs, self.num_tps+1, -1, h, w))
99
+ return deformed
100
+
101
+ def dropout_softmax(self, X, P):
102
+ '''
103
+ Dropout for TPS transformations. Eq(7) and Eq(8) in the paper.
104
+ '''
105
+ drop = (torch.rand(X.shape[0],X.shape[1]) < (1-P)).type(X.type()).to(X.device)
106
+ drop[..., 0] = 1
107
+ drop = drop.repeat(X.shape[2],X.shape[3],1,1).permute(2,3,0,1)
108
+
109
+ maxx = X.max(1).values.unsqueeze_(1)
110
+ X = X - maxx
111
+ X_exp = X.exp()
112
+ X[:,1:,...] /= (1-P)
113
+ mask_bool =(drop == 0)
114
+ X_exp = X_exp.masked_fill(mask_bool, 0)
115
+ partition = X_exp.sum(dim=1, keepdim=True) + 1e-6
116
+ return X_exp / partition
117
+
118
+ def forward(self, source_image, kp_driving, kp_source, bg_param = None, dropout_flag=False, dropout_p = 0):
119
+ if self.scale_factor != 1:
120
+ source_image = self.down(source_image)
121
+
122
+ bs, _, h, w = source_image.shape
123
+
124
+ out_dict = dict()
125
+ heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
126
+ transformations = self.create_transformations(source_image, kp_driving, kp_source, bg_param)
127
+ deformed_source = self.create_deformed_source_image(source_image, transformations)
128
+ out_dict['deformed_source'] = deformed_source
129
+ # out_dict['transformations'] = transformations
130
+ deformed_source = deformed_source.view(bs,-1,h,w)
131
+ input = torch.cat([heatmap_representation, deformed_source], dim=1)
132
+ input = input.view(bs, -1, h, w)
133
+
134
+ prediction = self.hourglass(input, mode = 1)
135
+
136
+ contribution_maps = self.maps(prediction[-1])
137
+ if(dropout_flag):
138
+ contribution_maps = self.dropout_softmax(contribution_maps, dropout_p)
139
+ else:
140
+ contribution_maps = F.softmax(contribution_maps, dim=1)
141
+ out_dict['contribution_maps'] = contribution_maps
142
+
143
+ # Combine the K+1 transformations
144
+ # Eq(6) in the paper
145
+ contribution_maps = contribution_maps.unsqueeze(2)
146
+ transformations = transformations.permute(0, 1, 4, 2, 3)
147
+ deformation = (transformations * contribution_maps).sum(dim=1)
148
+ deformation = deformation.permute(0, 2, 3, 1)
149
+
150
+ out_dict['deformation'] = deformation # Optical Flow
151
+
152
+ occlusion_map = []
153
+ if self.multi_mask:
154
+ for i in range(self.occlusion_num-self.up_nums):
155
+ occlusion_map.append(torch.sigmoid(self.occlusion[i](prediction[self.up_nums-self.occlusion_num+i])))
156
+ prediction = prediction[-1]
157
+ for i in range(self.up_nums):
158
+ prediction = self.up[i](prediction)
159
+ occlusion_map.append(torch.sigmoid(self.occlusion[i+self.occlusion_num-self.up_nums](prediction)))
160
+ else:
161
+ occlusion_map.append(torch.sigmoid(self.occlusion[0](prediction[-1])))
162
+
163
+ out_dict['occlusion_map'] = occlusion_map # Multi-resolution Occlusion Masks
164
+ return out_dict
TPSMM/modules/inpainting_network.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
5
+ from modules.dense_motion import DenseMotionNetwork
6
+
7
+
8
+ class InpaintingNetwork(nn.Module):
9
+ """
10
+ Inpaint the missing regions and reconstruct the Driving image.
11
+ """
12
+ def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, multi_mask = True, **kwargs):
13
+ super(InpaintingNetwork, self).__init__()
14
+
15
+ self.num_down_blocks = num_down_blocks
16
+ self.multi_mask = multi_mask
17
+ self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
18
+
19
+ down_blocks = []
20
+ up_blocks = []
21
+ resblock = []
22
+ for i in range(num_down_blocks):
23
+ in_features = min(max_features, block_expansion * (2 ** i))
24
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
25
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
26
+ decoder_in_feature = out_features * 2
27
+ if i==num_down_blocks-1:
28
+ decoder_in_feature = out_features
29
+ up_blocks.append(UpBlock2d(decoder_in_feature, in_features, kernel_size=(3, 3), padding=(1, 1)))
30
+ resblock.append(ResBlock2d(decoder_in_feature, kernel_size=(3, 3), padding=(1, 1)))
31
+ resblock.append(ResBlock2d(decoder_in_feature, kernel_size=(3, 3), padding=(1, 1)))
32
+ self.down_blocks = nn.ModuleList(down_blocks)
33
+ self.up_blocks = nn.ModuleList(up_blocks[::-1])
34
+ self.resblock = nn.ModuleList(resblock[::-1])
35
+
36
+ self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
37
+ self.num_channels = num_channels
38
+
39
+ def deform_input(self, inp, deformation):
40
+ _, h_old, w_old, _ = deformation.shape
41
+ _, _, h, w = inp.shape
42
+ if h_old != h or w_old != w:
43
+ deformation = deformation.permute(0, 3, 1, 2)
44
+ deformation = F.interpolate(deformation, size=(h, w), mode='bilinear', align_corners=True)
45
+ deformation = deformation.permute(0, 2, 3, 1)
46
+ return F.grid_sample(inp, deformation,align_corners=True)
47
+
48
+ def occlude_input(self, inp, occlusion_map):
49
+ if not self.multi_mask:
50
+ if inp.shape[2] != occlusion_map.shape[2] or inp.shape[3] != occlusion_map.shape[3]:
51
+ occlusion_map = F.interpolate(occlusion_map, size=inp.shape[2:], mode='bilinear',align_corners=True)
52
+ out = inp * occlusion_map
53
+ return out
54
+
55
+ def forward(self, source_image, dense_motion):
56
+ out = self.first(source_image)
57
+ encoder_map = [out]
58
+ for i in range(len(self.down_blocks)):
59
+ out = self.down_blocks[i](out)
60
+ encoder_map.append(out)
61
+
62
+ output_dict = {}
63
+ output_dict['contribution_maps'] = dense_motion['contribution_maps']
64
+ output_dict['deformed_source'] = dense_motion['deformed_source']
65
+
66
+ occlusion_map = dense_motion['occlusion_map']
67
+ output_dict['occlusion_map'] = occlusion_map
68
+
69
+ deformation = dense_motion['deformation']
70
+ out_ij = self.deform_input(out.detach(), deformation)
71
+ out = self.deform_input(out, deformation)
72
+
73
+ out_ij = self.occlude_input(out_ij, occlusion_map[0].detach())
74
+ out = self.occlude_input(out, occlusion_map[0])
75
+
76
+ warped_encoder_maps = []
77
+ warped_encoder_maps.append(out_ij)
78
+
79
+ for i in range(self.num_down_blocks):
80
+
81
+ out = self.resblock[2*i](out)
82
+ out = self.resblock[2*i+1](out)
83
+ out = self.up_blocks[i](out)
84
+
85
+ encode_i = encoder_map[-(i+2)]
86
+ encode_ij = self.deform_input(encode_i.detach(), deformation)
87
+ encode_i = self.deform_input(encode_i, deformation)
88
+
89
+ occlusion_ind = 0
90
+ if self.multi_mask:
91
+ occlusion_ind = i+1
92
+ encode_ij = self.occlude_input(encode_ij, occlusion_map[occlusion_ind].detach())
93
+ encode_i = self.occlude_input(encode_i, occlusion_map[occlusion_ind])
94
+ warped_encoder_maps.append(encode_ij)
95
+
96
+ if(i==self.num_down_blocks-1):
97
+ break
98
+
99
+ out = torch.cat([out, encode_i], 1)
100
+
101
+ deformed_source = self.deform_input(source_image, deformation)
102
+ output_dict["deformed"] = deformed_source
103
+ output_dict["warped_encoder_maps"] = warped_encoder_maps
104
+
105
+ occlusion_last = occlusion_map[-1]
106
+ if not self.multi_mask:
107
+ occlusion_last = F.interpolate(occlusion_last, size=out.shape[2:], mode='bilinear',align_corners=True)
108
+
109
+ out = out * (1 - occlusion_last) + encode_i
110
+ out = self.final(out)
111
+ out = torch.sigmoid(out)
112
+ out = out * (1 - occlusion_last) + deformed_source * occlusion_last
113
+ output_dict["prediction"] = out
114
+
115
+ return output_dict
116
+
117
+ def get_encode(self, driver_image, occlusion_map):
118
+ out = self.first(driver_image)
119
+ encoder_map = []
120
+ encoder_map.append(self.occlude_input(out.detach(), occlusion_map[-1].detach()))
121
+ for i in range(len(self.down_blocks)):
122
+ out = self.down_blocks[i](out.detach())
123
+ out_mask = self.occlude_input(out.detach(), occlusion_map[2-i].detach())
124
+ encoder_map.append(out_mask.detach())
125
+
126
+ return encoder_map
127
+
TPSMM/modules/keypoint_detector.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ from torchvision import models
4
+
5
+ class KPDetector(nn.Module):
6
+ """
7
+ Predict K*5 keypoints.
8
+ """
9
+
10
+ def __init__(self, num_tps, **kwargs):
11
+ super(KPDetector, self).__init__()
12
+ self.num_tps = num_tps
13
+
14
+ self.fg_encoder = models.resnet18(pretrained=False)
15
+ num_features = self.fg_encoder.fc.in_features
16
+ self.fg_encoder.fc = nn.Linear(num_features, num_tps*5*2)
17
+
18
+
19
+ def forward(self, image):
20
+
21
+ fg_kp = self.fg_encoder(image)
22
+ bs, _, = fg_kp.shape
23
+ fg_kp = torch.sigmoid(fg_kp)
24
+ fg_kp = fg_kp * 2 - 1
25
+ out = {'fg_kp': fg_kp.view(bs, self.num_tps*5, -1)}
26
+
27
+ return out
TPSMM/modules/model.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from modules.util import AntiAliasInterpolation2d, TPS
5
+ from torchvision import models
6
+ import numpy as np
7
+
8
+
9
+ class Vgg19(torch.nn.Module):
10
+ """
11
+ Vgg19 network for perceptual loss. See Sec 3.3.
12
+ """
13
+ def __init__(self, requires_grad=False):
14
+ super(Vgg19, self).__init__()
15
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
16
+ self.slice1 = torch.nn.Sequential()
17
+ self.slice2 = torch.nn.Sequential()
18
+ self.slice3 = torch.nn.Sequential()
19
+ self.slice4 = torch.nn.Sequential()
20
+ self.slice5 = torch.nn.Sequential()
21
+ for x in range(2):
22
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
23
+ for x in range(2, 7):
24
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
25
+ for x in range(7, 12):
26
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
27
+ for x in range(12, 21):
28
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
29
+ for x in range(21, 30):
30
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
31
+
32
+ self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
33
+ requires_grad=False)
34
+ self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
35
+ requires_grad=False)
36
+
37
+ if not requires_grad:
38
+ for param in self.parameters():
39
+ param.requires_grad = False
40
+
41
+ def forward(self, X):
42
+ X = (X - self.mean) / self.std
43
+ h_relu1 = self.slice1(X)
44
+ h_relu2 = self.slice2(h_relu1)
45
+ h_relu3 = self.slice3(h_relu2)
46
+ h_relu4 = self.slice4(h_relu3)
47
+ h_relu5 = self.slice5(h_relu4)
48
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
49
+ return out
50
+
51
+
52
+ class ImagePyramide(torch.nn.Module):
53
+ """
54
+ Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
55
+ """
56
+ def __init__(self, scales, num_channels):
57
+ super(ImagePyramide, self).__init__()
58
+ downs = {}
59
+ for scale in scales:
60
+ downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
61
+ self.downs = nn.ModuleDict(downs)
62
+
63
+ def forward(self, x):
64
+ out_dict = {}
65
+ for scale, down_module in self.downs.items():
66
+ out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
67
+ return out_dict
68
+
69
+
70
+ def detach_kp(kp):
71
+ return {key: value.detach() for key, value in kp.items()}
72
+
73
+
74
+ class GeneratorFullModel(torch.nn.Module):
75
+ """
76
+ Merge all generator related updates into single model for better multi-gpu usage
77
+ """
78
+
79
+ def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, *kwargs):
80
+ super(GeneratorFullModel, self).__init__()
81
+ self.kp_extractor = kp_extractor
82
+ self.inpainting_network = inpainting_network
83
+ self.dense_motion_network = dense_motion_network
84
+
85
+ self.bg_predictor = None
86
+ if bg_predictor:
87
+ self.bg_predictor = bg_predictor
88
+ self.bg_start = train_params['bg_start']
89
+
90
+ self.train_params = train_params
91
+ self.scales = train_params['scales']
92
+
93
+ self.pyramid = ImagePyramide(self.scales, inpainting_network.num_channels)
94
+ if torch.cuda.is_available():
95
+ self.pyramid = self.pyramid.cuda()
96
+
97
+ self.loss_weights = train_params['loss_weights']
98
+ self.dropout_epoch = train_params['dropout_epoch']
99
+ self.dropout_maxp = train_params['dropout_maxp']
100
+ self.dropout_inc_epoch = train_params['dropout_inc_epoch']
101
+ self.dropout_startp =train_params['dropout_startp']
102
+
103
+ if sum(self.loss_weights['perceptual']) != 0:
104
+ self.vgg = Vgg19()
105
+ if torch.cuda.is_available():
106
+ self.vgg = self.vgg.cuda()
107
+
108
+
109
+ def forward(self, x, epoch):
110
+ kp_source = self.kp_extractor(x['source'])
111
+ kp_driving = self.kp_extractor(x['driving'])
112
+ bg_param = None
113
+ if self.bg_predictor:
114
+ if(epoch>=self.bg_start):
115
+ bg_param = self.bg_predictor(x['source'], x['driving'])
116
+
117
+ if(epoch>=self.dropout_epoch):
118
+ dropout_flag = False
119
+ dropout_p = 0
120
+ else:
121
+ # dropout_p will linearly increase from dropout_startp to dropout_maxp
122
+ dropout_flag = True
123
+ dropout_p = min(epoch/self.dropout_inc_epoch * self.dropout_maxp + self.dropout_startp, self.dropout_maxp)
124
+
125
+ dense_motion = self.dense_motion_network(source_image=x['source'], kp_driving=kp_driving,
126
+ kp_source=kp_source, bg_param = bg_param,
127
+ dropout_flag = dropout_flag, dropout_p = dropout_p)
128
+ generated = self.inpainting_network(x['source'], dense_motion)
129
+ generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
130
+
131
+ loss_values = {}
132
+
133
+ pyramide_real = self.pyramid(x['driving'])
134
+ pyramide_generated = self.pyramid(generated['prediction'])
135
+
136
+ # reconstruction loss
137
+ if sum(self.loss_weights['perceptual']) != 0:
138
+ value_total = 0
139
+ for scale in self.scales:
140
+ x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
141
+ y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
142
+
143
+ for i, weight in enumerate(self.loss_weights['perceptual']):
144
+ value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
145
+ value_total += self.loss_weights['perceptual'][i] * value
146
+ loss_values['perceptual'] = value_total
147
+
148
+ # equivariance loss
149
+ if self.loss_weights['equivariance_value'] != 0:
150
+ transform_random = TPS(mode = 'random', bs = x['driving'].shape[0], **self.train_params['transform_params'])
151
+ transform_grid = transform_random.transform_frame(x['driving'])
152
+ transformed_frame = F.grid_sample(x['driving'], transform_grid, padding_mode="reflection",align_corners=True)
153
+ transformed_kp = self.kp_extractor(transformed_frame)
154
+
155
+ generated['transformed_frame'] = transformed_frame
156
+ generated['transformed_kp'] = transformed_kp
157
+
158
+ warped = transform_random.warp_coordinates(transformed_kp['fg_kp'])
159
+ kp_d = kp_driving['fg_kp']
160
+ value = torch.abs(kp_d - warped).mean()
161
+ loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
162
+
163
+ # warp loss
164
+ if self.loss_weights['warp_loss'] != 0:
165
+ occlusion_map = generated['occlusion_map']
166
+ encode_map = self.inpainting_network.get_encode(x['driving'], occlusion_map)
167
+ decode_map = generated['warped_encoder_maps']
168
+ value = 0
169
+ for i in range(len(encode_map)):
170
+ value += torch.abs(encode_map[i]-decode_map[-i-1]).mean()
171
+
172
+ loss_values['warp_loss'] = self.loss_weights['warp_loss'] * value
173
+
174
+ # bg loss
175
+ if self.bg_predictor and epoch >= self.bg_start and self.loss_weights['bg'] != 0:
176
+ bg_param_reverse = self.bg_predictor(x['driving'], x['source'])
177
+ value = torch.matmul(bg_param, bg_param_reverse)
178
+ eye = torch.eye(3).view(1, 1, 3, 3).type(value.type())
179
+ value = torch.abs(eye - value).mean()
180
+ loss_values['bg'] = self.loss_weights['bg'] * value
181
+
182
+ return loss_values, generated
TPSMM/modules/util.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+
5
+
6
+ class TPS:
7
+ '''
8
+ TPS transformation, mode 'kp' for Eq(2) in the paper, mode 'random' for equivariance loss.
9
+ '''
10
+ def __init__(self, mode, bs, **kwargs):
11
+ self.bs = bs
12
+ self.mode = mode
13
+ if mode == 'random':
14
+ noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
15
+ self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
16
+ self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
17
+ self.control_points = self.control_points.unsqueeze(0)
18
+ self.control_params = torch.normal(mean=0,
19
+ std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
20
+ elif mode == 'kp':
21
+ kp_1 = kwargs["kp_1"]
22
+ kp_2 = kwargs["kp_2"]
23
+ device = kp_1.device
24
+ kp_type = kp_1.type()
25
+ self.gs = kp_1.shape[1]
26
+ n = kp_1.shape[2]
27
+ K = torch.norm(kp_1[:,:,:, None]-kp_1[:,:, None, :], dim=4, p=2)
28
+ K = K**2
29
+ K = K * torch.log(K+1e-9)
30
+
31
+ one1 = torch.ones(self.bs, kp_1.shape[1], kp_1.shape[2], 1).to(device).type(kp_type)
32
+ kp_1p = torch.cat([kp_1,one1], 3)
33
+
34
+ zero = torch.zeros(self.bs, kp_1.shape[1], 3, 3).to(device).type(kp_type)
35
+ P = torch.cat([kp_1p, zero],2)
36
+ L = torch.cat([K,kp_1p.permute(0,1,3,2)],2)
37
+ L = torch.cat([L,P],3)
38
+
39
+ zero = torch.zeros(self.bs, kp_1.shape[1], 3, 2).to(device).type(kp_type)
40
+ Y = torch.cat([kp_2, zero], 2)
41
+ one = torch.eye(L.shape[2]).expand(L.shape).to(device).type(kp_type)*0.01
42
+ L = L + one
43
+
44
+ param = torch.matmul(torch.inverse(L),Y)
45
+ self.theta = param[:,:,n:,:].permute(0,1,3,2)
46
+
47
+ self.control_points = kp_1
48
+ self.control_params = param[:,:,:n,:]
49
+ else:
50
+ raise Exception("Error TPS mode")
51
+
52
+ def transform_frame(self, frame):
53
+ grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0).to(frame.device)
54
+ grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
55
+ shape = [self.bs, frame.shape[2], frame.shape[3], 2]
56
+ if self.mode == 'kp':
57
+ shape.insert(1, self.gs)
58
+ grid = self.warp_coordinates(grid).view(*shape)
59
+ return grid
60
+
61
+ def warp_coordinates(self, coordinates):
62
+ theta = self.theta.type(coordinates.type()).to(coordinates.device)
63
+ control_points = self.control_points.type(coordinates.type()).to(coordinates.device)
64
+ control_params = self.control_params.type(coordinates.type()).to(coordinates.device)
65
+
66
+ if self.mode == 'kp':
67
+ transformed = torch.matmul(theta[:, :, :, :2], coordinates.permute(0, 2, 1)) + theta[:, :, :, 2:]
68
+
69
+ distances = coordinates.view(coordinates.shape[0], 1, 1, -1, 2) - control_points.view(self.bs, control_points.shape[1], -1, 1, 2)
70
+
71
+ distances = distances ** 2
72
+ result = distances.sum(-1)
73
+ result = result * torch.log(result + 1e-9)
74
+ result = torch.matmul(result.permute(0, 1, 3, 2), control_params)
75
+ transformed = transformed.permute(0, 1, 3, 2) + result
76
+
77
+ elif self.mode == 'random':
78
+ theta = theta.unsqueeze(1)
79
+ transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
80
+ transformed = transformed.squeeze(-1)
81
+ ances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
82
+ distances = ances ** 2
83
+
84
+ result = distances.sum(-1)
85
+ result = result * torch.log(result + 1e-9)
86
+ result = result * control_params
87
+ result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
88
+ transformed = transformed + result
89
+ else:
90
+ raise Exception("Error TPS mode")
91
+
92
+ return transformed
93
+
94
+
95
+ def kp2gaussian(kp, spatial_size, kp_variance):
96
+ """
97
+ Transform a keypoint into gaussian like representation
98
+ """
99
+
100
+ coordinate_grid = make_coordinate_grid(spatial_size, kp.type()).to(kp.device)
101
+ number_of_leading_dimensions = len(kp.shape) - 1
102
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
103
+ coordinate_grid = coordinate_grid.view(*shape)
104
+ repeats = kp.shape[:number_of_leading_dimensions] + (1, 1, 1)
105
+ coordinate_grid = coordinate_grid.repeat(*repeats)
106
+
107
+ # Preprocess kp shape
108
+ shape = kp.shape[:number_of_leading_dimensions] + (1, 1, 2)
109
+ kp = kp.view(*shape)
110
+
111
+ mean_sub = (coordinate_grid - kp)
112
+
113
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
114
+
115
+ return out
116
+
117
+
118
+ def make_coordinate_grid(spatial_size, type):
119
+ """
120
+ Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
121
+ """
122
+ h, w = spatial_size
123
+ x = torch.arange(w).type(type)
124
+ y = torch.arange(h).type(type)
125
+
126
+ x = (2 * (x / (w - 1)) - 1)
127
+ y = (2 * (y / (h - 1)) - 1)
128
+
129
+ yy = y.view(-1, 1).repeat(1, w)
130
+ xx = x.view(1, -1).repeat(h, 1)
131
+
132
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
133
+
134
+ return meshed
135
+
136
+
137
+ class ResBlock2d(nn.Module):
138
+ """
139
+ Res block, preserve spatial resolution.
140
+ """
141
+
142
+ def __init__(self, in_features, kernel_size, padding):
143
+ super(ResBlock2d, self).__init__()
144
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
145
+ padding=padding)
146
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
147
+ padding=padding)
148
+ self.norm1 = nn.InstanceNorm2d(in_features, affine=True)
149
+ self.norm2 = nn.InstanceNorm2d(in_features, affine=True)
150
+
151
+ def forward(self, x):
152
+ out = self.norm1(x)
153
+ out = F.relu(out)
154
+ out = self.conv1(out)
155
+ out = self.norm2(out)
156
+ out = F.relu(out)
157
+ out = self.conv2(out)
158
+ out += x
159
+ return out
160
+
161
+
162
+ class UpBlock2d(nn.Module):
163
+ """
164
+ Upsampling block for use in decoder.
165
+ """
166
+
167
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
168
+ super(UpBlock2d, self).__init__()
169
+
170
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
171
+ padding=padding, groups=groups)
172
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
173
+
174
+ def forward(self, x):
175
+ out = F.interpolate(x, scale_factor=2)
176
+ out = self.conv(out)
177
+ out = self.norm(out)
178
+ out = F.relu(out)
179
+ return out
180
+
181
+
182
+ class DownBlock2d(nn.Module):
183
+ """
184
+ Downsampling block for use in encoder.
185
+ """
186
+
187
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
188
+ super(DownBlock2d, self).__init__()
189
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
190
+ padding=padding, groups=groups)
191
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
192
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
193
+
194
+ def forward(self, x):
195
+ out = self.conv(x)
196
+ out = self.norm(out)
197
+ out = F.relu(out)
198
+ out = self.pool(out)
199
+ return out
200
+
201
+
202
+ class SameBlock2d(nn.Module):
203
+ """
204
+ Simple block, preserve spatial resolution.
205
+ """
206
+
207
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
208
+ super(SameBlock2d, self).__init__()
209
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
210
+ kernel_size=kernel_size, padding=padding, groups=groups)
211
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
212
+
213
+ def forward(self, x):
214
+ out = self.conv(x)
215
+ out = self.norm(out)
216
+ out = F.relu(out)
217
+ return out
218
+
219
+
220
+ class Encoder(nn.Module):
221
+ """
222
+ Hourglass Encoder
223
+ """
224
+
225
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
226
+ super(Encoder, self).__init__()
227
+
228
+ down_blocks = []
229
+ for i in range(num_blocks):
230
+ down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
231
+ min(max_features, block_expansion * (2 ** (i + 1))),
232
+ kernel_size=3, padding=1))
233
+ self.down_blocks = nn.ModuleList(down_blocks)
234
+
235
+ def forward(self, x):
236
+ outs = [x]
237
+ #print('encoder:' ,outs[-1].shape)
238
+ for down_block in self.down_blocks:
239
+ outs.append(down_block(outs[-1]))
240
+ #print('encoder:' ,outs[-1].shape)
241
+ return outs
242
+
243
+
244
+ class Decoder(nn.Module):
245
+ """
246
+ Hourglass Decoder
247
+ """
248
+
249
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
250
+ super(Decoder, self).__init__()
251
+
252
+ up_blocks = []
253
+ self.out_channels = []
254
+ for i in range(num_blocks)[::-1]:
255
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
256
+ self.out_channels.append(in_filters)
257
+ out_filters = min(max_features, block_expansion * (2 ** i))
258
+ up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
259
+
260
+ self.up_blocks = nn.ModuleList(up_blocks)
261
+ self.out_channels.append(block_expansion + in_features)
262
+ # self.out_filters = block_expansion + in_features
263
+
264
+ def forward(self, x, mode = 0):
265
+ out = x.pop()
266
+ outs = []
267
+ for up_block in self.up_blocks:
268
+ out = up_block(out)
269
+ skip = x.pop()
270
+ out = torch.cat([out, skip], dim=1)
271
+ outs.append(out)
272
+ if(mode == 0):
273
+ return out
274
+ else:
275
+ return outs
276
+
277
+
278
+ class Hourglass(nn.Module):
279
+ """
280
+ Hourglass architecture.
281
+ """
282
+
283
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
284
+ super(Hourglass, self).__init__()
285
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
286
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
287
+ self.out_channels = self.decoder.out_channels
288
+ # self.out_filters = self.decoder.out_filters
289
+
290
+ def forward(self, x, mode = 0):
291
+ return self.decoder(self.encoder(x), mode)
292
+
293
+
294
+ class AntiAliasInterpolation2d(nn.Module):
295
+ """
296
+ Band-limited downsampling, for better preservation of the input signal.
297
+ """
298
+ def __init__(self, channels, scale):
299
+ super(AntiAliasInterpolation2d, self).__init__()
300
+ sigma = (1 / scale - 1) / 2
301
+ kernel_size = 2 * round(sigma * 4) + 1
302
+ self.ka = kernel_size // 2
303
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
304
+
305
+ kernel_size = [kernel_size, kernel_size]
306
+ sigma = [sigma, sigma]
307
+ # The gaussian kernel is the product of the
308
+ # gaussian function of each dimension.
309
+ kernel = 1
310
+ meshgrids = torch.meshgrid(
311
+ [
312
+ torch.arange(size, dtype=torch.float32)
313
+ for size in kernel_size
314
+ ]
315
+ )
316
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
317
+ mean = (size - 1) / 2
318
+ kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
319
+
320
+ # Make sure sum of values in gaussian kernel equals 1.
321
+ kernel = kernel / torch.sum(kernel)
322
+ # Reshape to depthwise convolutional weight
323
+ kernel = kernel.view(1, 1, *kernel.size())
324
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
325
+
326
+ self.register_buffer('weight', kernel)
327
+ self.groups = channels
328
+ self.scale = scale
329
+
330
+ def forward(self, input):
331
+ if self.scale == 1.0:
332
+ return input
333
+
334
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
335
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
336
+ out = F.interpolate(out, scale_factor=(self.scale, self.scale))
337
+
338
+ return out
339
+
340
+
341
+ def to_homogeneous(coordinates):
342
+ ones_shape = list(coordinates.shape)
343
+ ones_shape[-1] = 1
344
+ ones = torch.ones(ones_shape).type(coordinates.type())
345
+
346
+ return torch.cat([coordinates, ones], dim=-1)
347
+
348
+ def from_homogeneous(coordinates):
349
+ return coordinates[..., :2] / coordinates[..., 2:3]
TPSMM/pkgs/tpsmm.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import cv2
5
+ import numpy as np
6
+ from skimage import img_as_ubyte
7
+ from skimage.transform import resize
8
+ pwd = os.path.dirname(os.path.realpath(__file__))
9
+ sys.path.insert(1, os.path.join(pwd, ".."))
10
+
11
+ from demo import relative_kp, load_checkpoints
12
+
13
+
14
+ class TPSMM:
15
+ def __init__(self):
16
+ self.device = torch.device("cuda")
17
+ self.inpainting, self.kp_detector, self.dense_motion_network, self.avd_network = load_checkpoints(
18
+ config_path=os.path.join(pwd, "../config/vox-256.yaml"),
19
+ checkpoint_path=os.path.join(pwd, "../pretrained/vox.pth.tar"),
20
+ device=self.device
21
+ )
22
+ self.kp_driving_initial = None
23
+
24
+ def process_source(self, src_img):
25
+ with torch.no_grad():
26
+ src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
27
+ src_img = resize(src_img, (256, 256))
28
+ source_tensor = torch.tensor(src_img[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(self.device)
29
+ kp_source = self.kp_detector(source_tensor)
30
+
31
+ return source_tensor, kp_source
32
+
33
+
34
+ def gen_image(self, driving_img, source_tensor, kp_source):
35
+ with torch.no_grad():
36
+ driving_img = cv2.cvtColor(driving_img, cv2.COLOR_BGR2RGB)
37
+ driving_img = resize(driving_img, (256, 256))
38
+ driving_frame = torch.tensor(driving_img[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(self.device)
39
+
40
+ kp_driving = self.kp_detector(driving_frame)
41
+ if self.kp_driving_initial is None:
42
+ self.kp_driving_initial = kp_driving
43
+ kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving,
44
+ kp_driving_initial=self.kp_driving_initial)
45
+ dense_motion = self.dense_motion_network(source_image=source_tensor,
46
+ kp_driving=kp_norm,
47
+ kp_source=kp_source, bg_param=None,
48
+ dropout_flag=False)
49
+ out = self.inpainting(source_tensor, dense_motion)
50
+ out = np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]
51
+ out = img_as_ubyte(out)
52
+ out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
53
+
54
+ return out
55
+
56
+
57
+ if __name__ == "__main__":
58
+ tpsmm = TPSMM()
59
+ source_image = cv2.imread(os.path.join(pwd, "../assets/source1.png"))
60
+ cap = cv2.VideoCapture("/research/GAN/git/CVPR2022-DaGAN/assets/video1.mp4")
61
+
62
+ source_tensor, kp_source = tpsmm.process_source(source_image)
63
+
64
+ while True:
65
+ ret, frame = cap.read()
66
+ if frame is None:
67
+ break
68
+
69
+ output = tpsmm.gen_image(frame, source_tensor, kp_source)
70
+ cv2.imshow("output", output)
71
+ key = cv2.waitKey(1) & 0xFF
72
+ if key == ord("q"):
73
+ break
74
+ # cv2.imwrite("./tmp.jpg", output)
75
+
76
+ cv2.destroyAllWindows()
77
+
78
+
79
+
80
+
TPSMM/predict.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.insert(0, "stylegan-encoder")
4
+ import tempfile
5
+ import warnings
6
+ import imageio
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib.animation as animation
10
+ from skimage.transform import resize
11
+ from skimage import img_as_ubyte
12
+ import torch
13
+ import torchvision.transforms as transforms
14
+ import dlib
15
+ from cog import BasePredictor, Path, Input
16
+
17
+ from demo import load_checkpoints
18
+ from demo import make_animation
19
+ from ffhq_dataset.face_alignment import image_align
20
+ from ffhq_dataset.landmarks_detector import LandmarksDetector
21
+
22
+
23
+ warnings.filterwarnings("ignore")
24
+
25
+
26
+ PREDICTOR = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
27
+ LANDMARKS_DETECTOR = LandmarksDetector("shape_predictor_68_face_landmarks.dat")
28
+
29
+
30
+ class Predictor(BasePredictor):
31
+ def setup(self):
32
+
33
+ self.device = torch.device("cuda:0")
34
+ datasets = ["vox", "taichi", "ted", "mgif"]
35
+ (
36
+ self.inpainting,
37
+ self.kp_detector,
38
+ self.dense_motion_network,
39
+ self.avd_network,
40
+ ) = ({}, {}, {}, {})
41
+ for d in datasets:
42
+ (
43
+ self.inpainting[d],
44
+ self.kp_detector[d],
45
+ self.dense_motion_network[d],
46
+ self.avd_network[d],
47
+ ) = load_checkpoints(
48
+ config_path=f"config/{d}-384.yaml"
49
+ if d == "ted"
50
+ else f"config/{d}-256.yaml",
51
+ checkpoint_path=f"checkpoints/{d}.pth.tar",
52
+ device=self.device,
53
+ )
54
+
55
+ def predict(
56
+ self,
57
+ source_image: Path = Input(
58
+ description="Input source image.",
59
+ ),
60
+ driving_video: Path = Input(
61
+ description="Choose a micromotion.",
62
+ ),
63
+ dataset_name: str = Input(
64
+ choices=["vox", "taichi", "ted", "mgif"],
65
+ default="vox",
66
+ description="Choose a dataset.",
67
+ ),
68
+ ) -> Path:
69
+
70
+ predict_mode = "relative" # ['standard', 'relative', 'avd']
71
+ # find_best_frame = False
72
+
73
+ pixel = 384 if dataset_name == "ted" else 256
74
+
75
+ if dataset_name == "vox":
76
+ # first run face alignment
77
+ align_image(str(source_image), 'aligned.png')
78
+ source_image = imageio.imread('aligned.png')
79
+ else:
80
+ source_image = imageio.imread(str(source_image))
81
+ reader = imageio.get_reader(str(driving_video))
82
+ fps = reader.get_meta_data()["fps"]
83
+ source_image = resize(source_image, (pixel, pixel))[..., :3]
84
+
85
+ driving_video = []
86
+ try:
87
+ for im in reader:
88
+ driving_video.append(im)
89
+ except RuntimeError:
90
+ pass
91
+ reader.close()
92
+
93
+ driving_video = [
94
+ resize(frame, (pixel, pixel))[..., :3] for frame in driving_video
95
+ ]
96
+
97
+ inpainting, kp_detector, dense_motion_network, avd_network = (
98
+ self.inpainting[dataset_name],
99
+ self.kp_detector[dataset_name],
100
+ self.dense_motion_network[dataset_name],
101
+ self.avd_network[dataset_name],
102
+ )
103
+
104
+ predictions = make_animation(
105
+ source_image,
106
+ driving_video,
107
+ inpainting,
108
+ kp_detector,
109
+ dense_motion_network,
110
+ avd_network,
111
+ device="cuda:0",
112
+ mode=predict_mode,
113
+ )
114
+
115
+ # save resulting video
116
+ out_path = Path(tempfile.mkdtemp()) / "output.mp4"
117
+ imageio.mimsave(
118
+ str(out_path), [img_as_ubyte(frame) for frame in predictions], fps=fps
119
+ )
120
+ return out_path
121
+
122
+
123
+ def align_image(raw_img_path, aligned_face_path):
124
+ for i, face_landmarks in enumerate(LANDMARKS_DETECTOR.get_landmarks(raw_img_path), start=1):
125
+ image_align(raw_img_path, aligned_face_path, face_landmarks)
TPSMM/pretrained/vox.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52ad8c848e2a1d91b621de96fea83faf57ce3b8c1c06424e317f4df1d3998204
3
+ size 350993469
TPSMM/reconstruction.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tqdm import tqdm
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ from logger import Logger, Visualizer
6
+ import numpy as np
7
+ import imageio
8
+
9
+
10
+ def reconstruction(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
11
+ png_dir = os.path.join(log_dir, 'reconstruction/png')
12
+ log_dir = os.path.join(log_dir, 'reconstruction')
13
+
14
+ if checkpoint is not None:
15
+ Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
16
+ bg_predictor=bg_predictor, dense_motion_network=dense_motion_network)
17
+ else:
18
+ raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
19
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
20
+
21
+ if not os.path.exists(log_dir):
22
+ os.makedirs(log_dir)
23
+
24
+ if not os.path.exists(png_dir):
25
+ os.makedirs(png_dir)
26
+
27
+ loss_list = []
28
+
29
+ inpainting_network.eval()
30
+ kp_detector.eval()
31
+ dense_motion_network.eval()
32
+ if bg_predictor:
33
+ bg_predictor.eval()
34
+
35
+ for it, x in tqdm(enumerate(dataloader)):
36
+ with torch.no_grad():
37
+ predictions = []
38
+ visualizations = []
39
+ if torch.cuda.is_available():
40
+ x['video'] = x['video'].cuda()
41
+ kp_source = kp_detector(x['video'][:, :, 0])
42
+ for frame_idx in range(x['video'].shape[2]):
43
+ source = x['video'][:, :, 0]
44
+ driving = x['video'][:, :, frame_idx]
45
+ kp_driving = kp_detector(driving)
46
+ bg_params = None
47
+ if bg_predictor:
48
+ bg_params = bg_predictor(source, driving)
49
+
50
+ dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
51
+ kp_source=kp_source, bg_param = bg_params,
52
+ dropout_flag = False)
53
+ out = inpainting_network(source, dense_motion)
54
+ out['kp_source'] = kp_source
55
+ out['kp_driving'] = kp_driving
56
+
57
+ predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
58
+
59
+ visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
60
+ driving=driving, out=out)
61
+ visualizations.append(visualization)
62
+ loss = torch.abs(out['prediction'] - driving).mean().cpu().numpy()
63
+
64
+ loss_list.append(loss)
65
+ # print(np.mean(loss_list))
66
+ predictions = np.concatenate(predictions, axis=1)
67
+ imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
68
+
69
+ print("Reconstruction loss: %s" % np.mean(loss_list))
TPSMM/requirements.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cffi==1.14.6
2
+ cycler==0.10.0
3
+ decorator==5.1.0
4
+ face-alignment==1.3.5
5
+ imageio==2.9.0
6
+ imageio-ffmpeg==0.4.5
7
+ kiwisolver==1.3.2
8
+ matplotlib==3.4.3
9
+ networkx==2.6.3
10
+ numpy==1.20.3
11
+ pandas==1.3.3
12
+ Pillow==8.3.2
13
+ pycparser==2.20
14
+ pyparsing==2.4.7
15
+ python-dateutil==2.8.2
16
+ pytz==2021.1
17
+ PyWavelets==1.1.1
18
+ PyYAML==5.4.1
19
+ scikit-image==0.18.3
20
+ scikit-learn==1.0
21
+ scipy==1.7.1
22
+ six==1.16.0
23
+ torch==1.10.0+cu113
24
+ torchvision==0.11.0+cu113
25
+ tqdm==4.62.3
TPSMM/run.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ matplotlib.use('Agg')
3
+
4
+ import os, sys
5
+ import yaml
6
+ from argparse import ArgumentParser
7
+ from time import gmtime, strftime
8
+ from shutil import copy
9
+ from frames_dataset import FramesDataset
10
+
11
+ from modules.inpainting_network import InpaintingNetwork
12
+ from modules.keypoint_detector import KPDetector
13
+ from modules.bg_motion_predictor import BGMotionPredictor
14
+ from modules.dense_motion import DenseMotionNetwork
15
+ from modules.avd_network import AVDNetwork
16
+ import torch
17
+ from train import train
18
+ from train_avd import train_avd
19
+ from reconstruction import reconstruction
20
+ import os
21
+
22
+
23
+ if __name__ == "__main__":
24
+
25
+ if sys.version_info[0] < 3:
26
+ raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")
27
+
28
+ parser = ArgumentParser()
29
+ parser.add_argument("--config", default="config/vox-256.yaml", help="path to config")
30
+ parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"])
31
+ parser.add_argument("--log_dir", default='log', help="path to log into")
32
+ parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
33
+ parser.add_argument("--device_ids", default="0,1", type=lambda x: list(map(int, x.split(','))),
34
+ help="Names of the devices comma separated.")
35
+
36
+ opt = parser.parse_args()
37
+ with open(opt.config) as f:
38
+ config = yaml.load(f)
39
+
40
+ if opt.checkpoint is not None:
41
+ log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
42
+ else:
43
+ log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
44
+ log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())
45
+
46
+ inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
47
+ **config['model_params']['common_params'])
48
+
49
+ if torch.cuda.is_available():
50
+ cuda_device = torch.device('cuda:'+str(opt.device_ids[0]))
51
+ inpainting.to(cuda_device)
52
+
53
+ kp_detector = KPDetector(**config['model_params']['common_params'])
54
+ dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
55
+ **config['model_params']['dense_motion_params'])
56
+
57
+ if torch.cuda.is_available():
58
+ kp_detector.to(opt.device_ids[0])
59
+ dense_motion_network.to(opt.device_ids[0])
60
+
61
+ bg_predictor = None
62
+ if (config['model_params']['common_params']['bg']):
63
+ bg_predictor = BGMotionPredictor()
64
+ if torch.cuda.is_available():
65
+ bg_predictor.to(opt.device_ids[0])
66
+
67
+ avd_network = None
68
+ if opt.mode == "train_avd":
69
+ avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
70
+ **config['model_params']['avd_network_params'])
71
+ if torch.cuda.is_available():
72
+ avd_network.to(opt.device_ids[0])
73
+
74
+ dataset = FramesDataset(is_train=(opt.mode.startswith('train')), **config['dataset_params'])
75
+
76
+ if not os.path.exists(log_dir):
77
+ os.makedirs(log_dir)
78
+ if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
79
+ copy(opt.config, log_dir)
80
+
81
+ if opt.mode == 'train':
82
+ print("Training...")
83
+ train(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset)
84
+ elif opt.mode == 'train_avd':
85
+ print("Training Animation via Disentaglement...")
86
+ train_avd(config, inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network, opt.checkpoint, log_dir, dataset)
87
+ elif opt.mode == 'reconstruction':
88
+ print("Reconstruction...")
89
+ reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset)
TPSMM/tmp.jpg ADDED
TPSMM/tmp.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+
4
+ cap = cv2.VideoCapture("/research/GAN/git/CVPR2022-DaGAN/assets/video1.mp4")
5
+ while True:
6
+ ret, frame = cap.read()
7
+ if frame is None:
8
+ break
9
+ cv2.imshow("output", frame)
10
+ key = cv2.waitKey(1) & 0xff
11
+ if key == ord("q"):
12
+ break
13
+
14
+ cv2.destroyAllWindows()
TPSMM/train.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import trange
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ from logger import Logger
5
+ from modules.model import GeneratorFullModel
6
+ from torch.optim.lr_scheduler import MultiStepLR
7
+ from torch.nn.utils import clip_grad_norm_
8
+ from frames_dataset import DatasetRepeater
9
+ import math
10
+
11
+ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
12
+ train_params = config['train_params']
13
+ optimizer = torch.optim.Adam(
14
+ [{'params': list(inpainting_network.parameters()) +
15
+ list(dense_motion_network.parameters()) +
16
+ list(kp_detector.parameters()), 'initial_lr': train_params['lr_generator']}],lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4)
17
+
18
+ optimizer_bg_predictor = None
19
+ if bg_predictor:
20
+ optimizer_bg_predictor = torch.optim.Adam(
21
+ [{'params':bg_predictor.parameters(),'initial_lr': train_params['lr_generator']}],
22
+ lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4)
23
+
24
+ if checkpoint is not None:
25
+ start_epoch = Logger.load_cpk(
26
+ checkpoint, inpainting_network = inpainting_network, dense_motion_network = dense_motion_network,
27
+ kp_detector = kp_detector, bg_predictor = bg_predictor,
28
+ optimizer = optimizer, optimizer_bg_predictor = optimizer_bg_predictor)
29
+ print('load success:', start_epoch)
30
+ start_epoch += 1
31
+ else:
32
+ start_epoch = 0
33
+
34
+ scheduler_optimizer = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1,
35
+ last_epoch=start_epoch - 1)
36
+ if bg_predictor:
37
+ scheduler_bg_predictor = MultiStepLR(optimizer_bg_predictor, train_params['epoch_milestones'],
38
+ gamma=0.1, last_epoch=start_epoch - 1)
39
+
40
+ if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
41
+ dataset = DatasetRepeater(dataset, train_params['num_repeats'])
42
+ dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True,
43
+ num_workers=train_params['dataloader_workers'], drop_last=True)
44
+
45
+ generator_full = GeneratorFullModel(kp_detector, bg_predictor, dense_motion_network, inpainting_network, train_params)
46
+
47
+ if torch.cuda.is_available():
48
+ generator_full = torch.nn.DataParallel(generator_full).cuda()
49
+
50
+ bg_start = train_params['bg_start']
51
+
52
+ with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'],
53
+ checkpoint_freq=train_params['checkpoint_freq']) as logger:
54
+ for epoch in trange(start_epoch, train_params['num_epochs']):
55
+ for x in dataloader:
56
+ if(torch.cuda.is_available()):
57
+ x['driving'] = x['driving'].cuda()
58
+ x['source'] = x['source'].cuda()
59
+
60
+ losses_generator, generated = generator_full(x, epoch)
61
+ loss_values = [val.mean() for val in losses_generator.values()]
62
+ loss = sum(loss_values)
63
+ loss.backward()
64
+
65
+ clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type = math.inf)
66
+ clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type = math.inf)
67
+ if bg_predictor and epoch>=bg_start:
68
+ clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type = math.inf)
69
+
70
+ optimizer.step()
71
+ optimizer.zero_grad()
72
+ if bg_predictor and epoch>=bg_start:
73
+ optimizer_bg_predictor.step()
74
+ optimizer_bg_predictor.zero_grad()
75
+
76
+ losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
77
+ logger.log_iter(losses=losses)
78
+
79
+ scheduler_optimizer.step()
80
+ if bg_predictor:
81
+ scheduler_bg_predictor.step()
82
+
83
+ model_save = {
84
+ 'inpainting_network': inpainting_network,
85
+ 'dense_motion_network': dense_motion_network,
86
+ 'kp_detector': kp_detector,
87
+ 'optimizer': optimizer,
88
+ }
89
+ if bg_predictor and epoch>=bg_start:
90
+ model_save['bg_predictor'] = bg_predictor
91
+ model_save['optimizer_bg_predictor'] = optimizer_bg_predictor
92
+
93
+ logger.log_epoch(epoch, model_save, inp=x, out=generated)
94
+
TPSMM/train_avd.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import trange
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ from logger import Logger
5
+ from torch.optim.lr_scheduler import MultiStepLR
6
+ from frames_dataset import DatasetRepeater
7
+
8
+
9
+ def random_scale(kp_params, scale):
10
+ theta = torch.rand(kp_params['fg_kp'].shape[0], 2) * (2 * scale) + (1 - scale)
11
+ theta = torch.diag_embed(theta).unsqueeze(1).type(kp_params['fg_kp'].type())
12
+ new_kp_params = {'fg_kp': torch.matmul(theta, kp_params['fg_kp'].unsqueeze(-1)).squeeze(-1)}
13
+ return new_kp_params
14
+
15
+
16
+ def train_avd(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network,
17
+ avd_network, checkpoint, log_dir, dataset):
18
+ train_params = config['train_avd_params']
19
+
20
+ optimizer = torch.optim.Adam(avd_network.parameters(), lr=train_params['lr'], betas=(0.5, 0.999))
21
+
22
+ if checkpoint is not None:
23
+ Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
24
+ bg_predictor=bg_predictor, avd_network=avd_network,
25
+ dense_motion_network= dense_motion_network,optimizer_avd=optimizer)
26
+ start_epoch = 0
27
+ else:
28
+ raise AttributeError("Checkpoint should be specified for mode='train_avd'.")
29
+
30
+ scheduler = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1)
31
+
32
+ if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
33
+ dataset = DatasetRepeater(dataset, train_params['num_repeats'])
34
+
35
+ dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True,
36
+ num_workers=train_params['dataloader_workers'], drop_last=True)
37
+
38
+ with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'],
39
+ checkpoint_freq=train_params['checkpoint_freq']) as logger:
40
+ for epoch in trange(start_epoch, train_params['num_epochs']):
41
+ avd_network.train()
42
+ for x in dataloader:
43
+ with torch.no_grad():
44
+ kp_source = kp_detector(x['source'].cuda())
45
+ kp_driving_gt = kp_detector(x['driving'].cuda())
46
+ kp_driving_random = random_scale(kp_driving_gt, scale=train_params['random_scale'])
47
+ rec = avd_network(kp_source, kp_driving_random)
48
+
49
+ reconstruction_kp = train_params['lambda_shift'] * \
50
+ torch.abs(kp_driving_gt['fg_kp'] - rec['fg_kp']).mean()
51
+
52
+ loss_dict = {'rec_kp': reconstruction_kp}
53
+ loss = reconstruction_kp
54
+
55
+ loss.backward()
56
+ optimizer.step()
57
+ optimizer.zero_grad()
58
+
59
+ losses = {key: value.mean().detach().data.cpu().numpy() for key, value in loss_dict.items()}
60
+ logger.log_iter(losses=losses)
61
+
62
+ # Visualization
63
+ avd_network.eval()
64
+ with torch.no_grad():
65
+ source = x['source'][:6].cuda()
66
+ driving = torch.cat([x['driving'][[0, 1]].cuda(), source[[2, 3, 2, 1]]], dim=0)
67
+ kp_source = kp_detector(source)
68
+ kp_driving = kp_detector(driving)
69
+
70
+ out = avd_network(kp_source, kp_driving)
71
+ kp_driving = out
72
+ dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
73
+ kp_source=kp_source)
74
+ generated = inpainting_network(source, dense_motion)
75
+
76
+ generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
77
+
78
+ scheduler.step(epoch)
79
+ model_save = {
80
+ 'inpainting_network': inpainting_network,
81
+ 'dense_motion_network': dense_motion_network,
82
+ 'kp_detector': kp_detector,
83
+ 'avd_network': avd_network,
84
+ 'optimizer_avd': optimizer
85
+ }
86
+ if bg_predictor :
87
+ model_save['bg_predictor'] = bg_predictor
88
+
89
+ logger.log_epoch(epoch, model_save,
90
+ inp={'source': source, 'driving': driving},
91
+ out=generated)
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import av
2
+ import sys
3
+ import numpy as np
4
+ import cv2
5
+ import streamlit as st
6
+ from PIL import Image
7
+ from streamlit_webrtc import WebRtcMode, webrtc_streamer
8
+
9
+ sys.path.insert(1, "./retinaface")
10
+ sys.path.insert(1, "./TPSMM/pkgs")
11
+ from tpsmm import TPSMM
12
+ from detect import Detect
13
+ from turn import get_ice_servers
14
+
15
+
16
+ def parse_roi_box_from_bbox(bbox, shape):
17
+ img_h, img_w = shape[:2]
18
+ left, top, right, bottom = bbox[:4]
19
+ old_size = (right - left + bottom - top) / 2
20
+ center_x = right - (right - left) / 2.0
21
+ center_y = bottom - (bottom - top) / 2.0 + old_size * 0.14
22
+
23
+ size = int(min((old_size * 2.0) / 2, center_x, img_w-center_x, center_y, img_h-center_y) * 2.0)
24
+
25
+ roi_box = [0] * 4
26
+ roi_box[0] = center_x - size / 2
27
+ roi_box[1] = center_y - size / 2
28
+ roi_box[2] = roi_box[0] + size
29
+ roi_box[3] = roi_box[1] + size
30
+
31
+ return roi_box
32
+
33
+ cache_key = "retinaface"
34
+ if cache_key in st.session_state:
35
+ detector = st.session_state[cache_key]
36
+ else:
37
+ detector = Detect("./retinaface/weights/mobilenet0.25_epoch_842.pth", net_inshape=(486, 864))
38
+ st.session_state[cache_key] = detector
39
+
40
+ cache_key = "tpsmm"
41
+ if cache_key in st.session_state:
42
+ generator = st.session_state[cache_key]
43
+ else:
44
+ generator = TPSMM()
45
+ st.session_state[cache_key] = generator
46
+
47
+
48
+ @st.cache_resource # type: ignore
49
+ def get_images():
50
+ images = [
51
+ cv2.imread("assets/0.jpg"),
52
+ cv2.imread("assets/1.jpg"),
53
+ cv2.imread("assets/2.jpg"),
54
+ cv2.imread("assets/3.jpg"),
55
+ ]
56
+ item_list = [str(i) for i in range(len(images))]
57
+ images = [generator.process_source(src_img) for src_img in images]
58
+
59
+ return dict(zip(item_list, images))
60
+ images = get_images()
61
+ user_option = st.selectbox("Choose an item", list(images.keys()))
62
+
63
+ uploaded_file = st.file_uploader("Or upload your file here...", type=['png', 'jpeg', 'jpg'])
64
+ @st.cache_resource
65
+ def process_file(uploaded_file):
66
+ img = Image.open(uploaded_file)
67
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
68
+ dets = detector(img)
69
+ for i, b in enumerate(dets):
70
+ bbox = parse_roi_box_from_bbox(b[:4], img.shape)
71
+ bbox = [int(i) for i in bbox]
72
+
73
+ face_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy()
74
+ # cv2.imwrite("./tmp.jpg", face_img)
75
+ return generator.process_source(face_img)
76
+
77
+ return None
78
+ if uploaded_file is not None:
79
+ uploaded_file = process_file(uploaded_file)
80
+
81
+ def callback(frame: av.VideoFrame) -> av.VideoFrame:
82
+ img = frame.to_ndarray(format="bgr24")
83
+
84
+ try:
85
+ dets = detector(img)
86
+ output = None
87
+ for i, b in enumerate(dets):
88
+ text = "{:.4f}".format(b[4])
89
+ b = b.astype(np.int32)
90
+ cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
91
+ bbox = parse_roi_box_from_bbox(b[:4], img.shape)
92
+ bbox = [int(i) for i in bbox]
93
+ cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255, 0, 0), 2)
94
+
95
+ face_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy()
96
+ if uploaded_file is None:
97
+ source_tensor, kp_source = images[user_option]
98
+ else:
99
+ source_tensor, kp_source = uploaded_file
100
+ output = generator.gen_image(face_img, source_tensor, kp_source)
101
+
102
+ landm = b[5:15]
103
+ landm = landm.reshape((5, 2))
104
+ cv2.circle(img, tuple(landm[0]), 1, (0, 0, 255), 2)
105
+ cv2.circle(img, tuple(landm[1]), 1, (0, 255, 255), 2)
106
+ cv2.circle(img, tuple(landm[2]), 1, (255, 0, 255), 2)
107
+ cv2.circle(img, tuple(landm[3]), 1, (0, 255, 0), 2)
108
+ cv2.circle(img, tuple(landm[4]), 1, (255, 0, 0), 2)
109
+
110
+ if output is not None:
111
+ img[:256, :256] = output
112
+ except Exception as e:
113
+ print(e)
114
+
115
+ return av.VideoFrame.from_ndarray(img, format="bgr24")
116
+
117
+ webrtc_streamer(
118
+ key="sample",
119
+ rtc_configuration={"iceServers": get_ice_servers()},
120
+ video_frame_callback=callback,
121
+ media_stream_constraints={"video": True, "audio": False},
122
+ )
assets/0.jpg ADDED
assets/1.jpg ADDED
assets/2.jpg ADDED
assets/3.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit-webrtc
2
+ twilio
3
+ altair<5
4
+ numpy==1.23.1
5
+ opencv-python==4.8.0.74
6
+ imutils
7
+ scikit-image==0.21.0
8
+ matplotlib==3.7.1
9
+ pyaml==23.5.9
10
+ tqdm
11
+ torch
12
+ torchvision
retinaface/change_batch_onnx.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnx
2
+
3
+
4
+ model = onnx.load('weights/faceDetector_243_432_b1_sim.onnx')
5
+
6
+ # # for fixed batchsize
7
+ # model.graph.input[0].type.tensor_type.shape.dim[0].dim_value = 32
8
+ # model.graph.output[0].type.tensor_type.shape.dim[0].dim_value = 32
9
+ # model.graph.output[1].type.tensor_type.shape.dim[0].dim_value = 32
10
+ # model.graph.output[2].type.tensor_type.shape.dim[0].dim_value = 32
11
+ # onnx.save(model, 'faceDetector_640_b32.onnx')
12
+
13
+
14
+ model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'batch' # for dynamic batchsize
15
+ model.graph.output[0].type.tensor_type.shape.dim[0].dim_param = 'batch'
16
+ model.graph.output[1].type.tensor_type.shape.dim[0].dim_param = 'batch'
17
+ model.graph.output[2].type.tensor_type.shape.dim[0].dim_param = 'batch'
18
+ model.graph.output[3].type.tensor_type.shape.dim[0].dim_param = 'batch'
19
+ model.graph.output[4].type.tensor_type.shape.dim[0].dim_param = 'batch'
20
+ model.graph.output[5].type.tensor_type.shape.dim[0].dim_param = 'batch'
21
+ model.graph.output[6].type.tensor_type.shape.dim[0].dim_param = 'batch'
22
+ model.graph.output[7].type.tensor_type.shape.dim[0].dim_param = 'batch'
23
+ model.graph.output[8].type.tensor_type.shape.dim[0].dim_param = 'batch'
24
+ onnx.save(model, 'weights/faceDetector_243_432_batch_sim.onnx')
25
+
26
+
27
+ ####################################################
28
+ # SHow model onnx
29
+
30
+ # import onnxruntime as rt
31
+ # ort_session = rt.InferenceSession("faceDetector_180_320_batch_sim.onnx")
32
+ # print("====INPUT====")
33
+ # for i in ort_session.get_inputs():
34
+ # print("Name: {}, Shape: {}, Dtype: {}".format(i.name, i.shape, i.type))
35
+ # print("====OUTPUT====")
36
+ # for i in ort_session.get_outputs():
37
+ # print("Name: {}, Shape: {}, Dtype: {}".format(i.name, i.shape, i.type))
38
+
39
+ # import numpy as np
40
+ # input_name = ort_session.get_inputs()[0].name
41
+ # img = np.random.randn(4, 3, 180, 320).astype(np.float32)
42
+ # data = ort_session.run(None, {input_name: img})
43
+ # print("Done")
retinaface/convert_to_onnx.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python convert_to_onnx.py --network mobile0.25 --trained_model weights/mobilenet0.25_Final.pth
2
+ from __future__ import print_function
3
+ import os
4
+ import argparse
5
+ import torch
6
+ import torch.backends.cudnn as cudnn
7
+ import numpy as np
8
+ from data import cfg_mnet, cfg_slim, cfg_rfb
9
+ from layers.functions.prior_box import PriorBox
10
+ from utils.nms.py_cpu_nms import py_cpu_nms
11
+ import cv2
12
+ from models.retinaface import RetinaFace
13
+ from models.net_slim import Slim
14
+ from models.net_rfb import RFB
15
+ from utils.box_utils import decode, decode_landm
16
+ from utils.timer import Timer
17
+
18
+
19
+ parser = argparse.ArgumentParser(description='Test')
20
+ parser.add_argument('-m', '--trained_model', default='./weights/RBF_Final.pth',
21
+ type=str, help='Trained state_dict file path to open')
22
+ parser.add_argument('--network', default='RFB', help='Backbone network mobile0.25 or slim or RFB')
23
+ parser.add_argument('--long_side', default=320, help='when origin_size is false, long_side is scaled size(320 or 640 for long side)')
24
+ parser.add_argument('--cpu', action="store_true", help='Use cpu inference')
25
+
26
+ args = parser.parse_args()
27
+
28
+
29
+ def check_keys(model, pretrained_state_dict):
30
+ ckpt_keys = set(pretrained_state_dict.keys())
31
+ model_keys = set(model.state_dict().keys())
32
+ used_pretrained_keys = model_keys & ckpt_keys
33
+ unused_pretrained_keys = ckpt_keys - model_keys
34
+ missing_keys = model_keys - ckpt_keys
35
+ print('Missing keys:{}'.format(len(missing_keys)))
36
+ print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
37
+ print('Used keys:{}'.format(len(used_pretrained_keys)))
38
+ assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
39
+ return True
40
+
41
+
42
+ def remove_prefix(state_dict, prefix):
43
+ ''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
44
+ print('remove prefix \'{}\''.format(prefix))
45
+ f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
46
+ return {f(key): value for key, value in state_dict.items()}
47
+
48
+
49
+ def load_model(model, pretrained_path, load_to_cpu):
50
+ print('Loading pretrained model from {}'.format(pretrained_path))
51
+ if load_to_cpu:
52
+ pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
53
+ else:
54
+ device = torch.cuda.current_device()
55
+ pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
56
+ if "state_dict" in pretrained_dict.keys():
57
+ pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
58
+ else:
59
+ pretrained_dict = remove_prefix(pretrained_dict, 'module.')
60
+ check_keys(model, pretrained_dict)
61
+ model.load_state_dict(pretrained_dict, strict=False)
62
+ return model
63
+
64
+
65
+ if __name__ == '__main__':
66
+ torch.set_grad_enabled(False)
67
+
68
+ cfg = None
69
+ net = None
70
+ # long_side = int(args.long_side)
71
+ net_inshape = (243, 432)
72
+ device = torch.device("cpu" if args.cpu else "cuda")
73
+ print(device)
74
+ if args.network == "mobile0.25":
75
+ cfg = cfg_mnet
76
+ # net_inshape = (long_side, long_side) # h, w
77
+ priorbox = PriorBox(cfg, image_size=net_inshape)
78
+ priors = priorbox.forward()
79
+ prior_data = priors.to(device)
80
+ net = RetinaFace(cfg=cfg, phase='test')
81
+ elif args.network == "slim":
82
+ cfg = cfg_slim
83
+ net = Slim(cfg = cfg, phase = 'test')
84
+ elif args.network == "RFB":
85
+ cfg = cfg_rfb
86
+ net = RFB(cfg = cfg, phase = 'test')
87
+ else:
88
+ print("Don't support network!")
89
+ exit(0)
90
+
91
+ # load weight
92
+ net = load_model(net, args.trained_model, args.cpu)
93
+ net.eval()
94
+ print('Finished loading model!')
95
+ print(net)
96
+ net = net.to(device)
97
+
98
+ ##################export###############
99
+ output_onnx = f'weights/faceDetector_{net_inshape[0]}_{net_inshape[1]}_b1.onnx'
100
+ print("==> Exporting model to ONNX format at '{}'".format(output_onnx))
101
+ input_names = ['input_1']
102
+ output_names = ['box_1', 'box_2', 'box_3']
103
+
104
+ # import torch.onnx.symbolic_opset9 as onnx_symbolic
105
+ # def upsample_nearest2d(g, input, output_size, *args):
106
+ # # Currently, TRT 5.1/6.0/7.0 ONNX Parser does not support all ONNX ops
107
+ # # needed to support dynamic upsampling ONNX forumlation
108
+ # # Here we hardcode scale=2 as a temporary workaround
109
+ # scales = g.op("Constant", value_t=torch.tensor([1., 1., 2., 2.]))
110
+ # return g.op("Resize", input, scales, mode_s="nearest")
111
+
112
+
113
+ # onnx_symbolic.upsample_nearest2d = upsample_nearest2d
114
+
115
+ # import io
116
+ # onnx_bytes = io.BytesIO()
117
+ # zero_input = torch.zeros([1, 3, net_inshape[0], net_inshape[1]]).cuda()
118
+ # dynamic_axes = {input_names[0]: {0:'batch'}}
119
+ # for _, name in enumerate(output_names):
120
+ # dynamic_axes[name] = dynamic_axes[input_names[0]]
121
+ # extra_args = {'opset_version': 10, 'verbose': False,
122
+ # 'input_names': input_names, 'output_names': output_names,
123
+ # 'dynamic_axes': dynamic_axes}
124
+ # torch.onnx.export(net, zero_input, onnx_bytes, **extra_args)
125
+ # with open(output_onnx, 'wb') as out:
126
+ # out.write(onnx_bytes.getvalue())
127
+
128
+ inputs = torch.randn(1, 3, net_inshape[0], net_inshape[1]).to(device)
129
+ torch_out = torch.onnx._export(net, inputs, output_onnx, export_params=True, verbose=False, opset_version=9,
130
+ input_names=input_names, output_names=output_names)
131
+ ################end###############
132
+
133
+
134
+
135
+
retinaface/data/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .wider_face import WiderFaceDetection, detection_collate
2
+ from .data_augment import *
3
+ from .config import *
retinaface/data/config.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # config.py
2
+ cfg_mnet = {
3
+ 'name': 'mobilenet0.25',
4
+ 'min_sizes': [[10, 20], [32, 64], [128, 256]],
5
+ 'steps': [8, 16, 32],
6
+ 'variance': [0.1, 0.2],
7
+ 'clip': False,
8
+ 'loc_weight': 2.0,
9
+ 'gpu_train': True,
10
+ 'batch_size': 32,
11
+ 'ngpu': 1,
12
+ 'epoch': 250,
13
+ 'decay1': 190,
14
+ 'decay2': 220,
15
+ 'image_size': 300,
16
+ "net_inshape": (320, 320), # h, w
17
+ 'pretrain': False,
18
+ 'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
19
+ 'in_channel': 32,
20
+ 'out_channel': 64
21
+ }
22
+
23
+ cfg_slim = {
24
+ 'name': 'slim',
25
+ 'min_sizes': [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]],
26
+ 'steps': [8, 16, 32, 64],
27
+ 'variance': [0.1, 0.2],
28
+ 'clip': False,
29
+ 'loc_weight': 2.0,
30
+ 'gpu_train': True,
31
+ 'batch_size': 32,
32
+ 'ngpu': 1,
33
+ 'epoch': 250,
34
+ 'decay1': 190,
35
+ 'decay2': 220,
36
+ 'image_size': 300
37
+ }
38
+
39
+ cfg_rfb = {
40
+ 'name': 'RFB',
41
+ 'min_sizes': [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]],
42
+ 'steps': [8, 16, 32, 64],
43
+ 'variance': [0.1, 0.2],
44
+ 'clip': False,
45
+ 'loc_weight': 2.0,
46
+ 'gpu_train': True,
47
+ 'batch_size': 32,
48
+ 'ngpu': 1,
49
+ 'epoch': 250,
50
+ 'decay1': 190,
51
+ 'decay2': 220,
52
+ 'image_size': 300
53
+ }
54
+
55
+
retinaface/data/data_augment.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import random
4
+ from utils.box_utils import matrix_iof
5
+
6
+
7
+ def _crop(image, boxes, labels, landm, img_dim):
8
+ height, width, _ = image.shape
9
+ pad_image_flag = True
10
+
11
+ for _ in range(250):
12
+ if random.uniform(0, 1) <= 0.2:
13
+ scale = 1.0
14
+ else:
15
+ scale = random.uniform(0.3, 1.0)
16
+ # PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0]
17
+ # scale = random.choice(PRE_SCALES)
18
+ short_side = min(width, height)
19
+ w = int(scale * short_side)
20
+ h = w
21
+
22
+ if width == w:
23
+ l = 0
24
+ else:
25
+ l = random.randrange(width - w)
26
+ if height == h:
27
+ t = 0
28
+ else:
29
+ t = random.randrange(height - h)
30
+ roi = np.array((l, t, l + w, t + h))
31
+
32
+ value = matrix_iof(boxes, roi[np.newaxis])
33
+ flag = (value >= 1)
34
+ if not flag.any():
35
+ continue
36
+
37
+ centers = (boxes[:, :2] + boxes[:, 2:]) / 2
38
+ mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
39
+ boxes_t = boxes[mask_a].copy()
40
+ labels_t = labels[mask_a].copy()
41
+ landms_t = landm[mask_a].copy()
42
+ landms_t = landms_t.reshape([-1, 5, 2])
43
+
44
+ if boxes_t.shape[0] == 0:
45
+ continue
46
+
47
+ image_t = image[roi[1]:roi[3], roi[0]:roi[2]]
48
+
49
+ boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
50
+ boxes_t[:, :2] -= roi[:2]
51
+ boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
52
+ boxes_t[:, 2:] -= roi[:2]
53
+
54
+ # landm
55
+ landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2]
56
+ landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0]))
57
+ landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2])
58
+ landms_t = landms_t.reshape([-1, 10])
59
+
60
+
61
+ # make sure that the cropped image contains at least one face > 16 pixel at training image scale
62
+ b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
63
+ b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
64
+ mask_b = np.minimum(b_w_t, b_h_t) > 5
65
+ boxes_t = boxes_t[mask_b]
66
+ labels_t = labels_t[mask_b]
67
+ landms_t = landms_t[mask_b]
68
+
69
+ if boxes_t.shape[0] == 0:
70
+ continue
71
+
72
+ pad_image_flag = False
73
+
74
+ return image_t, boxes_t, labels_t, landms_t, pad_image_flag
75
+ return image, boxes, labels, landm, pad_image_flag
76
+
77
+
78
+ def _distort(image):
79
+
80
+ def _convert(image, alpha=1, beta=0):
81
+ tmp = image.astype(float) * alpha + beta
82
+ tmp[tmp < 0] = 0
83
+ tmp[tmp > 255] = 255
84
+ image[:] = tmp
85
+
86
+ image = image.copy()
87
+
88
+ if random.randrange(2):
89
+
90
+ #brightness distortion
91
+ if random.randrange(2):
92
+ _convert(image, beta=random.uniform(-32, 32))
93
+
94
+ #contrast distortion
95
+ if random.randrange(2):
96
+ _convert(image, alpha=random.uniform(0.5, 1.5))
97
+
98
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
99
+
100
+ #saturation distortion
101
+ if random.randrange(2):
102
+ _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
103
+
104
+ #hue distortion
105
+ if random.randrange(2):
106
+ tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
107
+ tmp %= 180
108
+ image[:, :, 0] = tmp
109
+
110
+ image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
111
+
112
+ else:
113
+
114
+ #brightness distortion
115
+ if random.randrange(2):
116
+ _convert(image, beta=random.uniform(-32, 32))
117
+
118
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
119
+
120
+ #saturation distortion
121
+ if random.randrange(2):
122
+ _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
123
+
124
+ #hue distortion
125
+ if random.randrange(2):
126
+ tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
127
+ tmp %= 180
128
+ image[:, :, 0] = tmp
129
+
130
+ image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
131
+
132
+ #contrast distortion
133
+ if random.randrange(2):
134
+ _convert(image, alpha=random.uniform(0.5, 1.5))
135
+
136
+ return image
137
+
138
+
139
+ def _expand(image, boxes, fill, p):
140
+ if random.randrange(2):
141
+ return image, boxes
142
+
143
+ height, width, depth = image.shape
144
+
145
+ scale = random.uniform(1, p)
146
+ w = int(scale * width)
147
+ h = int(scale * height)
148
+
149
+ left = random.randint(0, w - width)
150
+ top = random.randint(0, h - height)
151
+
152
+ boxes_t = boxes.copy()
153
+ boxes_t[:, :2] += (left, top)
154
+ boxes_t[:, 2:] += (left, top)
155
+ expand_image = np.empty(
156
+ (h, w, depth),
157
+ dtype=image.dtype)
158
+ expand_image[:, :] = fill
159
+ expand_image[top:top + height, left:left + width] = image
160
+ image = expand_image
161
+
162
+ return image, boxes_t
163
+
164
+
165
+ def _mirror(image, boxes, landms):
166
+ _, width, _ = image.shape
167
+ if random.randrange(2):
168
+ image = image[:, ::-1]
169
+ boxes = boxes.copy()
170
+ boxes[:, 0::2] = width - boxes[:, 2::-2]
171
+
172
+ # landm
173
+ landms = landms.copy()
174
+ landms = landms.reshape([-1, 5, 2])
175
+ landms[:, :, 0] = width - landms[:, :, 0]
176
+ tmp = landms[:, 1, :].copy()
177
+ landms[:, 1, :] = landms[:, 0, :]
178
+ landms[:, 0, :] = tmp
179
+ tmp1 = landms[:, 4, :].copy()
180
+ landms[:, 4, :] = landms[:, 3, :]
181
+ landms[:, 3, :] = tmp1
182
+ landms = landms.reshape([-1, 10])
183
+
184
+ return image, boxes, landms
185
+
186
+
187
+ def _pad_to_square(image, rgb_mean, pad_image_flag):
188
+ if not pad_image_flag:
189
+ return image
190
+ height, width, _ = image.shape
191
+ long_side = max(width, height)
192
+ image_t = np.empty((long_side, long_side, 3), dtype=image.dtype)
193
+ image_t[:, :] = rgb_mean
194
+ image_t[0:0 + height, 0:0 + width] = image
195
+ return image_t
196
+
197
+
198
+ def _resize_subtract_mean(image, insize, rgb_mean):
199
+ interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
200
+ interp_method = interp_methods[random.randrange(5)]
201
+ image = cv2.resize(image, (insize, insize), interpolation=interp_method)
202
+ image = image.astype(np.float32)
203
+ image -= rgb_mean
204
+ return image.transpose(2, 0, 1)
205
+
206
+
207
+ class preproc(object):
208
+
209
+ def __init__(self, img_dim, rgb_means):
210
+ self.img_dim = img_dim
211
+ self.rgb_means = rgb_means
212
+
213
+ def __call__(self, image, targets):
214
+ assert targets.shape[0] > 0, "this image does not have gt"
215
+
216
+ boxes = targets[:, :4].copy()
217
+ labels = targets[:, -1].copy()
218
+ landm = targets[:, 4:-1].copy()
219
+
220
+ image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim)
221
+ image_t = _distort(image_t)
222
+ image_t = _pad_to_square(image_t,self.rgb_means, pad_image_flag)
223
+ image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t)
224
+ height, width, _ = image_t.shape
225
+ image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means)
226
+ boxes_t[:, 0::2] /= width
227
+ boxes_t[:, 1::2] /= height
228
+
229
+ landm_t[:, 0::2] /= width
230
+ landm_t[:, 1::2] /= height
231
+
232
+ labels_t = np.expand_dims(labels_t, 1)
233
+ targets_t = np.hstack((boxes_t, landm_t, labels_t))
234
+
235
+ return image_t, targets_t
retinaface/data/wider_face.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path
3
+ import sys
4
+ import torch
5
+ import torch.utils.data as data
6
+ import cv2
7
+ import numpy as np
8
+
9
+ class WiderFaceDetection(data.Dataset):
10
+ def __init__(self, txt_path, preproc=None):
11
+ self.preproc = preproc
12
+ self.imgs_path = []
13
+ self.words = []
14
+ f = open(txt_path,'r')
15
+ lines = f.readlines()
16
+ isFirst = True
17
+ labels = []
18
+ for line in lines:
19
+ line = line.rstrip()
20
+ if line.startswith('#'):
21
+ if isFirst is True:
22
+ isFirst = False
23
+ else:
24
+ labels_copy = labels.copy()
25
+ self.words.append(labels_copy)
26
+ labels.clear()
27
+ path = line[2:]
28
+ path = txt_path.replace('label.txt','images/') + path
29
+ self.imgs_path.append(path)
30
+ else:
31
+ line = line.split(' ')
32
+ label = [float(x) for x in line]
33
+ labels.append(label)
34
+
35
+ self.words.append(labels)
36
+
37
+ def __len__(self):
38
+ return len(self.imgs_path)
39
+
40
+ def __getitem__(self, index):
41
+ img = cv2.imread(self.imgs_path[index])
42
+ height, width, _ = img.shape
43
+
44
+ labels = self.words[index]
45
+ annotations = np.zeros((0, 15))
46
+ if len(labels) == 0:
47
+ return annotations
48
+ for idx, label in enumerate(labels):
49
+ annotation = np.zeros((1, 15))
50
+ # bbox
51
+ annotation[0, 0] = label[0] # x1
52
+ annotation[0, 1] = label[1] # y1
53
+ annotation[0, 2] = label[0] + label[2] # x2
54
+ annotation[0, 3] = label[1] + label[3] # y2
55
+
56
+ # landmarks
57
+ annotation[0, 4] = label[4] # l0_x
58
+ annotation[0, 5] = label[5] # l0_y
59
+ annotation[0, 6] = label[7] # l1_x
60
+ annotation[0, 7] = label[8] # l1_y
61
+ annotation[0, 8] = label[10] # l2_x
62
+ annotation[0, 9] = label[11] # l2_y
63
+ annotation[0, 10] = label[13] # l3_x
64
+ annotation[0, 11] = label[14] # l3_y
65
+ annotation[0, 12] = label[16] # l4_x
66
+ annotation[0, 13] = label[17] # l4_y
67
+ if (annotation[0, 4]<0):
68
+ annotation[0, 14] = -1
69
+ else:
70
+ annotation[0, 14] = 1
71
+
72
+ annotations = np.append(annotations, annotation, axis=0)
73
+ target = np.array(annotations)
74
+ if self.preproc is not None:
75
+ img, target = self.preproc(img, target)
76
+
77
+ return torch.from_numpy(img), target
78
+
79
+ def detection_collate(batch):
80
+ """Custom collate fn for dealing with batches of images that have a different
81
+ number of associated object annotations (bounding boxes).
82
+
83
+ Arguments:
84
+ batch: (tuple) A tuple of tensor images and lists of annotations
85
+
86
+ Return:
87
+ A tuple containing:
88
+ 1) (tensor) batch of images stacked on their 0 dim
89
+ 2) (list of tensors) annotations for a given image are stacked on 0 dim
90
+ """
91
+ targets = []
92
+ imgs = []
93
+ for _, sample in enumerate(batch):
94
+ for _, tup in enumerate(sample):
95
+ if torch.is_tensor(tup):
96
+ imgs.append(tup)
97
+ elif isinstance(tup, type(np.empty(0))):
98
+ annos = torch.from_numpy(tup).float()
99
+ targets.append(annos)
100
+
101
+ return (torch.stack(imgs, 0), targets)
retinaface/detect.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import cv2
4
+ import numpy as np
5
+ import timeit
6
+
7
+ import imutils
8
+ from utils.infer_utils import load_model
9
+ from data import cfg_mnet as cfg
10
+ from models.retinaface import RetinaFace
11
+ from layers.functions.prior_box import PriorBox
12
+ from utils.box_utils import decode, decode_landm
13
+ from utils.nms.py_cpu_nms import py_cpu_nms
14
+ from utils.infer_utils import align_face
15
+ torch.set_grad_enabled(False)
16
+
17
+
18
+ class Detect:
19
+ def __init__(self, weight_path, net_inshape=(180, 320)):
20
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+ self.net_inshape = net_inshape
22
+ im_height, im_width = net_inshape
23
+ self.box_scale = np.array([im_width, im_height] * 2)
24
+ self.lmk_scale = np.array([im_width, im_height] * 5)
25
+
26
+ priorbox = PriorBox(cfg, image_size=net_inshape)
27
+ priors = priorbox.forward()
28
+ self.prior_data = priors.to(self.device)
29
+ self.net = RetinaFace(cfg=cfg, phase='test')
30
+ self.net = load_model(self.net, weight_path, False)
31
+ self.net.eval()
32
+ self.net = self.net.to(self.device)
33
+
34
+
35
+ def _preprocess(self, image):
36
+ rgb_mean = (104, 117, 123) # bgr order
37
+ h, w = image.shape[:2]
38
+ dx = int(self.net_inshape[1] * h / self.net_inshape[0] - w)
39
+ dy = 0
40
+ if dx < 0:
41
+ dx = 0
42
+ dy = int(self.net_inshape[0] * w / self.net_inshape[1] - h)
43
+ img = cv2.copyMakeBorder(image, 0, dy, 0, dx, borderType=cv2.BORDER_CONSTANT, value=rgb_mean)
44
+ img = cv2.copyMakeBorder(img, 0, img.shape[0], 0, img.shape[1], borderType=cv2.BORDER_CONSTANT, value=rgb_mean)
45
+
46
+ h, w = img.shape[:2]
47
+ resize = float(self.net_inshape[1]) / float(w)
48
+ img = cv2.resize(img, self.net_inshape[::-1])
49
+ img = np.float32(img)
50
+ img -= rgb_mean
51
+
52
+ return img, resize
53
+
54
+
55
+ def __call__(self, img, verbose=False):
56
+ '''
57
+ bgr image
58
+ '''
59
+ t0 = timeit.default_timer()
60
+ img, resize = self._preprocess(img)
61
+ img = img.transpose(2, 0, 1)
62
+ img = torch.from_numpy(img).unsqueeze(0)
63
+ img = img.to(self.device)
64
+
65
+ t1 = timeit.default_timer()
66
+ loc, conf, landms = self.net(img)
67
+ loc = [i.permute(0,2,3,1).contiguous().view(i.shape[0], -1, 4) for i in loc]
68
+ loc = torch.cat(loc, dim=1)
69
+ conf = [i.permute(0,2,3,1).contiguous().view(i.shape[0], -1, 3) for i in conf]
70
+ conf = torch.cat(conf, dim=1)
71
+ landms = [i.permute(0,2,3,1).contiguous().view(i.shape[0], -1, 10) for i in landms]
72
+ landms = torch.cat(landms, dim=1)
73
+ conf = F.softmax(conf, dim=-1)
74
+
75
+ t2 = timeit.default_timer()
76
+ conf = conf[0]
77
+ scores = conf.squeeze(0).detach().cpu().numpy()[:, 1:]
78
+ scores = np.amax(scores, axis=1)
79
+
80
+ boxes = decode(loc[0], self.prior_data, cfg['variance']) # loc[0]
81
+ boxes = boxes.detach().cpu().numpy()
82
+ boxes = boxes * self.box_scale / resize
83
+
84
+ landms = decode_landm(landms[0], self.prior_data, cfg['variance'])
85
+ landms = landms.detach().cpu().numpy()
86
+ landms = landms * self.lmk_scale / resize
87
+
88
+ # ignore low scores
89
+ inds = np.where(scores > 0.02)[0]
90
+ boxes = boxes[inds]
91
+ landms = landms[inds]
92
+ scores = scores[inds]
93
+
94
+ # keep top-K before NMS
95
+ order = scores.argsort()[::-1][:5000]
96
+ boxes = boxes[order]
97
+ landms = landms[order]
98
+ scores = scores[order]
99
+
100
+ # do NMS
101
+ dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
102
+ keep = py_cpu_nms(dets, 0.4)
103
+ # keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)
104
+ dets = dets[keep, :]
105
+ landms = landms[keep]
106
+
107
+ # keep top-K faster NMS
108
+ dets = dets[:750, :]
109
+ landms = landms[:750, :]
110
+
111
+ dets = np.concatenate((dets, landms), axis=1)
112
+ dets = dets[dets[:, 4] > 0.5]
113
+ dets = dets[np.argsort(dets, axis=0)[:, 0]]
114
+
115
+ t3 = timeit.default_timer()
116
+ if verbose:
117
+ print(t1 - t0, t2 - t1, t3 - t2)
118
+
119
+ return dets # (n, 15), box=0-3, cls=4, lmk=5-10
120
+
121
+
122
+ if __name__ == "__main__":
123
+ net_inshape = (486, 864) # h, w
124
+ model = Detect("/mnt/nvme0n1p2/ExternalHardrive/research/object_detection/face/Face-Detector-1MB-with-landmark-clear/weights/mobilenet0.25_epoch_842.pth", net_inshape=net_inshape)
125
+ image_path = "/mnt/nvme0n1p2/datasets/face/dyno/mytelpay230626/mytelpay230626_raw/data_2nd/၁၀ကတန(နိုင်)၀၀၁၀၀၁/mytel_ekyc_1m2_65160cc1802b6183d87fca091cab4c2faa93a9b1614106b5911ca778_front_image.jpg"
126
+ img = cv2.imread(image_path)
127
+
128
+ dets = model(img)
129
+ for i, b in enumerate(dets):
130
+ text = "{:.4f}".format(b[4])
131
+ b = b.astype(np.int32)
132
+ landm = b[5:15]
133
+ landm = landm.reshape((5, 2))
134
+
135
+ alighed_face = align_face(img, landm.copy())
136
+ # cv2.imshow(str(i), alighed_face)
137
+
138
+ # landms
139
+ landm = landm.astype(np.int32)
140
+ cv2.circle(img, tuple(landm[0]), 1, (0, 0, 255), 2)
141
+ cv2.circle(img, tuple(landm[1]), 1, (0, 255, 255), 2)
142
+ cv2.circle(img, tuple(landm[2]), 1, (255, 0, 255), 2)
143
+ cv2.circle(img, tuple(landm[3]), 1, (0, 255, 0), 2)
144
+ cv2.circle(img, tuple(landm[4]), 1, (255, 0, 0), 2)
145
+
146
+ cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
147
+ cx = b[0]
148
+ cy = b[1] + 20
149
+ cv2.putText(img, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 1.3, (255, 255, 255))
150
+
151
+ cv2.imwrite("./output.jpg", img)
152
+
retinaface/detect_video_raw.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import imutils
4
+
5
+ from utils.fps import FPS
6
+ from utils.infer_utils import LoadStream, align_face
7
+ from detect import Detect
8
+
9
+
10
+ net_inshape = (486, 864) # h, w
11
+ model = Detect("/mnt/nvme0n1p2/ExternalHardrive/research/object_detection/face/Face-Detector-1MB-with-landmark-clear/weights/mobilenet0.25_epoch_842.pth", net_inshape=net_inshape)
12
+ # dataloader = LoadStream("rtsp://admin:meditech123@192.168.100.90:555/")
13
+ dataloader = LoadStream("../30Shine_1.mp4")
14
+ fps = FPS().start()
15
+
16
+ for frame in dataloader:
17
+ # frame = imutils.resize(frame, width=640)
18
+ frame = frame.copy()
19
+ frame_raw = frame.copy()
20
+
21
+
22
+ dets = model(frame)
23
+ for i, b in enumerate(dets):
24
+ text = "{:.4f}".format(b[4])
25
+ b = b.astype(np.int32)
26
+ landm = b[5:15]
27
+ landm = landm.reshape((5, 2))
28
+
29
+ alighed_face = align_face(frame, landm.copy())
30
+ # cv2.imshow(str(i), alighed_face)
31
+
32
+ # landms
33
+ landm = landm.astype(np.int32)
34
+ cv2.circle(frame, tuple(landm[0]), 1, (0, 0, 255), 2)
35
+ cv2.circle(frame, tuple(landm[1]), 1, (0, 255, 255), 2)
36
+ cv2.circle(frame, tuple(landm[2]), 1, (255, 0, 255), 2)
37
+ cv2.circle(frame, tuple(landm[3]), 1, (0, 255, 0), 2)
38
+ cv2.circle(frame, tuple(landm[4]), 1, (255, 0, 0), 2)
39
+
40
+ cv2.rectangle(frame, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
41
+ cx = b[0]
42
+ cy = b[1] + 20
43
+ cv2.putText(frame, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 1.3, (255, 255, 255))
44
+
45
+ fps.update()
46
+ text_fps = "FPS: {:.3f}".format(fps.get_fps_n())
47
+ cv2.putText(frame, text_fps, (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
48
+ cv2.imshow("frame", imutils.resize(frame, width=1700))
49
+ key = cv2.waitKey(1) & 0xff
50
+ if key == ord("q"):
51
+ break
52
+ elif key == ord("c"):
53
+ while True:
54
+ cv2.imshow("frame", imutils.resize(frame, width=1700))
55
+ key = cv2.waitKey(1) & 0xff
56
+ if key == ord("q"):
57
+ break
58
+ # cv2.imwrite(f"{i}.jpg", alighed_face)
59
+ # i += 1
60
+ # # break
61
+
62
+ print(text_fps)
63
+ cv2.destroyAllWindows()
64
+ fps.stop()
65
+ print("Total FPS: {}".format(fps.fps()))
66
+ dataloader.close()
retinaface/layers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .functions import *
2
+ from .modules import *
retinaface/layers/functions/prior_box.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from itertools import product as product
3
+ import numpy as np
4
+ from math import ceil
5
+
6
+
7
+ class PriorBox(object):
8
+ def __init__(self, cfg, image_size=None, phase='train'):
9
+ super(PriorBox, self).__init__()
10
+ self.min_sizes = cfg['min_sizes']
11
+ self.steps = cfg['steps']
12
+ self.clip = cfg['clip']
13
+ self.image_size = image_size
14
+ self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
15
+
16
+ def forward(self):
17
+ anchors = []
18
+ for k, f in enumerate(self.feature_maps):
19
+ min_sizes = self.min_sizes[k]
20
+ for i, j in product(range(f[0]), range(f[1])):
21
+ for min_size in min_sizes:
22
+ s_kx = min_size / self.image_size[1]
23
+ s_ky = min_size / self.image_size[0]
24
+ dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
25
+ dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
26
+ for cy, cx in product(dense_cy, dense_cx):
27
+ anchors += [cx, cy, s_kx, s_ky]
28
+
29
+ # back to torch land
30
+ output = torch.Tensor(anchors).view(-1, 4)
31
+ if self.clip:
32
+ output.clamp_(max=1, min=0)
33
+ return output
retinaface/layers/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .multibox_loss import MultiBoxLoss
2
+
3
+ __all__ = ['MultiBoxLoss']
retinaface/layers/modules/multibox_loss.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+ from utils.box_utils import match, log_sum_exp
6
+ from data import cfg_mnet
7
+ GPU = cfg_mnet['gpu_train']
8
+
9
+ class MultiBoxLoss(nn.Module):
10
+ """SSD Weighted Loss Function
11
+ Compute Targets:
12
+ 1) Produce Confidence Target Indices by matching ground truth boxes
13
+ with (default) 'priorboxes' that have jaccard index > threshold parameter
14
+ (default threshold: 0.5).
15
+ 2) Produce localization target by 'encoding' variance into offsets of ground
16
+ truth boxes and their matched 'priorboxes'.
17
+ 3) Hard negative mining to filter the excessive number of negative examples
18
+ that comes with using a large number of default bounding boxes.
19
+ (default negative:positive ratio 3:1)
20
+ Objective Loss:
21
+ L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
22
+ Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
23
+ weighted by α which is set to 1 by cross val.
24
+ Args:
25
+ c: class confidences,
26
+ l: predicted boxes,
27
+ g: ground truth boxes
28
+ N: number of matched default boxes
29
+ See: https://arxiv.org/pdf/1512.02325.pdf for more details.
30
+ """
31
+
32
+ def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
33
+ super(MultiBoxLoss, self).__init__()
34
+ self.num_classes = num_classes
35
+ self.threshold = overlap_thresh
36
+ self.background_label = bkg_label
37
+ self.encode_target = encode_target
38
+ self.use_prior_for_matching = prior_for_matching
39
+ self.do_neg_mining = neg_mining
40
+ self.negpos_ratio = neg_pos
41
+ self.neg_overlap = neg_overlap
42
+ self.variance = [0.1, 0.2]
43
+
44
+ def forward(self, predictions, priors, targets):
45
+ """Multibox Loss
46
+ Args:
47
+ predictions (tuple): A tuple containing loc preds, conf preds,
48
+ and prior boxes from SSD net.
49
+ conf shape: torch.size(batch_size,num_priors,num_classes)
50
+ loc shape: torch.size(batch_size,num_priors,4)
51
+ priors shape: torch.size(num_priors,4)
52
+
53
+ ground_truth (tensor): Ground truth boxes and labels for a batch,
54
+ shape: [batch_size,num_objs,5] (last idx is the label).
55
+ """
56
+
57
+ loc_data, conf_data, landm_data = predictions
58
+ priors = priors
59
+ num = loc_data.size(0)
60
+ num_priors = (priors.size(0))
61
+
62
+ # match priors (default boxes) and ground truth boxes
63
+ loc_t = torch.Tensor(num, num_priors, 4)
64
+ landm_t = torch.Tensor(num, num_priors, 10)
65
+ conf_t = torch.LongTensor(num, num_priors)
66
+ for idx in range(num):
67
+ truths = targets[idx][:, :4].data
68
+ labels = targets[idx][:, -1].data
69
+ landms = targets[idx][:, 4:14].data
70
+ defaults = priors.data
71
+ match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
72
+ if GPU:
73
+ loc_t = loc_t.cuda()
74
+ conf_t = conf_t.cuda()
75
+ landm_t = landm_t.cuda()
76
+
77
+ zeros = torch.tensor(0).cuda()
78
+ # landm Loss (Smooth L1)
79
+ # Shape: [batch,num_priors,10]
80
+ pos1 = conf_t > zeros
81
+ num_pos_landm = pos1.long().sum(1, keepdim=True)
82
+ N1 = max(num_pos_landm.data.sum().float(), 1)
83
+ pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
84
+ landm_p = landm_data[pos_idx1].view(-1, 10)
85
+ landm_t = landm_t[pos_idx1].view(-1, 10)
86
+ loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')
87
+
88
+
89
+ pos = conf_t != zeros
90
+ conf_t[pos] = 1
91
+
92
+ # Localization Loss (Smooth L1)
93
+ # Shape: [batch,num_priors,4]
94
+ pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
95
+ loc_p = loc_data[pos_idx].view(-1, 4)
96
+ loc_t = loc_t[pos_idx].view(-1, 4)
97
+ loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
98
+
99
+ # Compute max conf across batch for hard negative mining
100
+ batch_conf = conf_data.view(-1, self.num_classes)
101
+ loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
102
+
103
+ # Hard Negative Mining
104
+ loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now
105
+ loss_c = loss_c.view(num, -1)
106
+ _, loss_idx = loss_c.sort(1, descending=True)
107
+ _, idx_rank = loss_idx.sort(1)
108
+ num_pos = pos.long().sum(1, keepdim=True)
109
+ num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
110
+ neg = idx_rank < num_neg.expand_as(idx_rank)
111
+
112
+ # Confidence Loss Including Positive and Negative Examples
113
+ pos_idx = pos.unsqueeze(2).expand_as(conf_data)
114
+ neg_idx = neg.unsqueeze(2).expand_as(conf_data)
115
+ conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes)
116
+ targets_weighted = conf_t[(pos+neg).gt(0)]
117
+ loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
118
+
119
+ # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
120
+ N = max(num_pos.data.sum().float(), 1)
121
+ loss_l /= N
122
+ loss_c /= N
123
+ loss_landm /= N1
124
+
125
+ return loss_l, loss_c, loss_landm