Spaces:
Runtime error
Runtime error
Commit ·
f3261a0
1
Parent(s): f5f446c
init
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- TPSMM/LICENSE +21 -0
- TPSMM/README.md +98 -0
- TPSMM/assets/source.png +0 -0
- TPSMM/assets/source1.png +0 -0
- TPSMM/assets/source2.jpg +0 -0
- TPSMM/augmentation.py +344 -0
- TPSMM/cog.yaml +40 -0
- TPSMM/config/mgif-256.yaml +75 -0
- TPSMM/config/taichi-256.yaml +134 -0
- TPSMM/config/ted-384.yaml +73 -0
- TPSMM/config/vox-256.yaml +74 -0
- TPSMM/demo.ipynb +0 -0
- TPSMM/demo.py +180 -0
- TPSMM/frames_dataset.py +173 -0
- TPSMM/logger.py +212 -0
- TPSMM/modules/avd_network.py +65 -0
- TPSMM/modules/bg_motion_predictor.py +24 -0
- TPSMM/modules/dense_motion.py +164 -0
- TPSMM/modules/inpainting_network.py +127 -0
- TPSMM/modules/keypoint_detector.py +27 -0
- TPSMM/modules/model.py +182 -0
- TPSMM/modules/util.py +349 -0
- TPSMM/pkgs/tpsmm.py +80 -0
- TPSMM/predict.py +125 -0
- TPSMM/pretrained/vox.pth.tar +3 -0
- TPSMM/reconstruction.py +69 -0
- TPSMM/requirements.txt +25 -0
- TPSMM/run.py +89 -0
- TPSMM/tmp.jpg +0 -0
- TPSMM/tmp.py +14 -0
- TPSMM/train.py +94 -0
- TPSMM/train_avd.py +91 -0
- app.py +122 -0
- assets/0.jpg +0 -0
- assets/1.jpg +0 -0
- assets/2.jpg +0 -0
- assets/3.jpg +0 -0
- requirements.txt +12 -0
- retinaface/change_batch_onnx.py +43 -0
- retinaface/convert_to_onnx.py +135 -0
- retinaface/data/__init__.py +3 -0
- retinaface/data/config.py +55 -0
- retinaface/data/data_augment.py +235 -0
- retinaface/data/wider_face.py +101 -0
- retinaface/detect.py +152 -0
- retinaface/detect_video_raw.py +66 -0
- retinaface/layers/__init__.py +2 -0
- retinaface/layers/functions/prior_box.py +33 -0
- retinaface/layers/modules/__init__.py +3 -0
- retinaface/layers/modules/multibox_loss.py +125 -0
TPSMM/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2021 yoyo-nb
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
TPSMM/README.md
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [CVPR2022] Thin-Plate Spline Motion Model for Image Animation
|
| 2 |
+
|
| 3 |
+
[](LICENSE)
|
| 4 |
+

|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
Source code of the CVPR'2022 paper "Thin-Plate Spline Motion Model for Image Animation"
|
| 8 |
+
|
| 9 |
+
[**Paper**](https://arxiv.org/abs/2203.14367) **|** [**Supp**](https://cloud.tsinghua.edu.cn/f/f7b8573bb5b04583949f/?dl=1)
|
| 10 |
+
|
| 11 |
+
### Example animation
|
| 12 |
+
|
| 13 |
+

|
| 14 |
+

|
| 15 |
+
|
| 16 |
+
**PS**: The paper trains the model for 100 epochs for a fair comparison. You can use more data and train for more epochs to get better performance.
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
### Web demo for animation
|
| 20 |
+
- Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [](https://huggingface.co/spaces/CVPR/Image-Animation-using-Thin-Plate-Spline-Motion-Model)
|
| 21 |
+
- Try the web demo for animation here: [](https://replicate.com/yoyo-nb/thin-plate-spline-motion-model)
|
| 22 |
+
- Google Colab: [](https://colab.research.google.com/drive/1DREfdpnaBhqISg0fuQlAAIwyGVn1loH_?usp=sharing)
|
| 23 |
+
|
| 24 |
+
### Pre-trained models
|
| 25 |
+
- [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/30ab8765da364fefa101/)
|
| 26 |
+
- [Google Drive](https://drive.google.com/drive/folders/1pNDo1ODQIb5HVObRtCmubqJikmR7VVLT?usp=sharing)
|
| 27 |
+
|
| 28 |
+
### Installation
|
| 29 |
+
|
| 30 |
+
We support ```python3```.(Recommended version is Python 3.9).
|
| 31 |
+
To install the dependencies run:
|
| 32 |
+
```bash
|
| 33 |
+
pip install -r requirements.txt
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
### YAML configs
|
| 38 |
+
|
| 39 |
+
There are several configuration files one for each `dataset` in the `config` folder named as ```config/dataset_name.yaml```.
|
| 40 |
+
|
| 41 |
+
See description of the parameters in the ```config/taichi-256.yaml```.
|
| 42 |
+
|
| 43 |
+
### Datasets
|
| 44 |
+
|
| 45 |
+
1) **MGif**. Follow [Monkey-Net](https://github.com/AliaksandrSiarohin/monkey-net).
|
| 46 |
+
|
| 47 |
+
2) **TaiChiHD** and **VoxCeleb**. Follow instructions from [video-preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing).
|
| 48 |
+
|
| 49 |
+
3) **TED-talks**. Follow instructions from [MRAA](https://github.com/snap-research/articulated-animation).
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
### Training
|
| 53 |
+
To train a model on specific dataset run:
|
| 54 |
+
```
|
| 55 |
+
CUDA_VISIBLE_DEVICES=0,1 python run.py --config config/dataset_name.yaml --device_ids 0,1
|
| 56 |
+
```
|
| 57 |
+
A log folder named after the timestamp will be created. Checkpoints, loss values, reconstruction results will be saved to this folder.
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
#### Training AVD network
|
| 61 |
+
To train a model on specific dataset run:
|
| 62 |
+
```
|
| 63 |
+
CUDA_VISIBLE_DEVICES=0 python run.py --mode train_avd --checkpoint '{checkpoint_folder}/checkpoint.pth.tar' --config config/dataset_name.yaml
|
| 64 |
+
```
|
| 65 |
+
Checkpoints, loss values, reconstruction results will be saved to `{checkpoint_folder}`.
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
### Evaluation on video reconstruction
|
| 70 |
+
|
| 71 |
+
To evaluate the reconstruction performance run:
|
| 72 |
+
```
|
| 73 |
+
CUDA_VISIBLE_DEVICES=0 python run.py --mode reconstruction --config config/dataset_name.yaml --checkpoint '{checkpoint_folder}/checkpoint.pth.tar'
|
| 74 |
+
```
|
| 75 |
+
The `reconstruction` subfolder will be created in `{checkpoint_folder}`.
|
| 76 |
+
The generated video will be stored to this folder, also generated videos will be stored in ```png``` subfolder in loss-less '.png' format for evaluation.
|
| 77 |
+
To compute metrics, follow instructions from [pose-evaluation](https://github.com/AliaksandrSiarohin/pose-evaluation).
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
### Image animation demo
|
| 81 |
+
- notebook: `demo.ipynb`, edit the config cell and run for image animation.
|
| 82 |
+
- python:
|
| 83 |
+
```bash
|
| 84 |
+
CUDA_VISIBLE_DEVICES=0 python demo.py --config config/vox-256.yaml --checkpoint checkpoints/vox.pth.tar --source_image ./source.jpg --driving_video ./driving.mp4
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
# Acknowledgments
|
| 88 |
+
The main code is based upon [FOMM](https://github.com/AliaksandrSiarohin/first-order-model) and [MRAA](https://github.com/snap-research/articulated-animation)
|
| 89 |
+
|
| 90 |
+
Thanks for the excellent works!
|
| 91 |
+
|
| 92 |
+
And Thanks to:
|
| 93 |
+
|
| 94 |
+
- [@chenxwh](https://github.com/chenxwh): Add Web Demo & Docker environment [](https://replicate.com/yoyo-nb/thin-plate-spline-motion-model)
|
| 95 |
+
|
| 96 |
+
- [@TalkUHulk](https://github.com/TalkUHulk): The C++/Python demo is provided in [Image-Animation-Turbo-Boost](https://github.com/TalkUHulk/Image-Animation-Turbo-Boost)
|
| 97 |
+
|
| 98 |
+
- [@AK391](https://github.com/AK391): Add huggingface web demo [](https://huggingface.co/spaces/CVPR/Image-Animation-using-Thin-Plate-Spline-Motion-Model)
|
TPSMM/assets/source.png
ADDED
|
TPSMM/assets/source1.png
ADDED
|
TPSMM/assets/source2.jpg
ADDED
|
TPSMM/augmentation.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Code from https://github.com/hassony2/torch_videovision
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numbers
|
| 6 |
+
|
| 7 |
+
import random
|
| 8 |
+
import numpy as np
|
| 9 |
+
import PIL
|
| 10 |
+
|
| 11 |
+
from skimage.transform import resize, rotate
|
| 12 |
+
import torchvision
|
| 13 |
+
|
| 14 |
+
import warnings
|
| 15 |
+
|
| 16 |
+
from skimage import img_as_ubyte, img_as_float
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def crop_clip(clip, min_h, min_w, h, w):
|
| 20 |
+
if isinstance(clip[0], np.ndarray):
|
| 21 |
+
cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
|
| 22 |
+
|
| 23 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
| 24 |
+
cropped = [
|
| 25 |
+
img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
|
| 26 |
+
]
|
| 27 |
+
else:
|
| 28 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
| 29 |
+
'but got list of {0}'.format(type(clip[0])))
|
| 30 |
+
return cropped
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def pad_clip(clip, h, w):
|
| 34 |
+
im_h, im_w = clip[0].shape[:2]
|
| 35 |
+
pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
|
| 36 |
+
pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
|
| 37 |
+
|
| 38 |
+
return np.pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def resize_clip(clip, size, interpolation='bilinear'):
|
| 42 |
+
if isinstance(clip[0], np.ndarray):
|
| 43 |
+
if isinstance(size, numbers.Number):
|
| 44 |
+
im_h, im_w, im_c = clip[0].shape
|
| 45 |
+
# Min spatial dim already matches minimal size
|
| 46 |
+
if (im_w <= im_h and im_w == size) or (im_h <= im_w
|
| 47 |
+
and im_h == size):
|
| 48 |
+
return clip
|
| 49 |
+
new_h, new_w = get_resize_sizes(im_h, im_w, size)
|
| 50 |
+
size = (new_w, new_h)
|
| 51 |
+
else:
|
| 52 |
+
size = size[1], size[0]
|
| 53 |
+
|
| 54 |
+
scaled = [
|
| 55 |
+
resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
|
| 56 |
+
mode='constant', anti_aliasing=True) for img in clip
|
| 57 |
+
]
|
| 58 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
| 59 |
+
if isinstance(size, numbers.Number):
|
| 60 |
+
im_w, im_h = clip[0].size
|
| 61 |
+
# Min spatial dim already matches minimal size
|
| 62 |
+
if (im_w <= im_h and im_w == size) or (im_h <= im_w
|
| 63 |
+
and im_h == size):
|
| 64 |
+
return clip
|
| 65 |
+
new_h, new_w = get_resize_sizes(im_h, im_w, size)
|
| 66 |
+
size = (new_w, new_h)
|
| 67 |
+
else:
|
| 68 |
+
size = size[1], size[0]
|
| 69 |
+
if interpolation == 'bilinear':
|
| 70 |
+
pil_inter = PIL.Image.NEAREST
|
| 71 |
+
else:
|
| 72 |
+
pil_inter = PIL.Image.BILINEAR
|
| 73 |
+
scaled = [img.resize(size, pil_inter) for img in clip]
|
| 74 |
+
else:
|
| 75 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
| 76 |
+
'but got list of {0}'.format(type(clip[0])))
|
| 77 |
+
return scaled
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_resize_sizes(im_h, im_w, size):
|
| 81 |
+
if im_w < im_h:
|
| 82 |
+
ow = size
|
| 83 |
+
oh = int(size * im_h / im_w)
|
| 84 |
+
else:
|
| 85 |
+
oh = size
|
| 86 |
+
ow = int(size * im_w / im_h)
|
| 87 |
+
return oh, ow
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class RandomFlip(object):
|
| 91 |
+
def __init__(self, time_flip=False, horizontal_flip=False):
|
| 92 |
+
self.time_flip = time_flip
|
| 93 |
+
self.horizontal_flip = horizontal_flip
|
| 94 |
+
|
| 95 |
+
def __call__(self, clip):
|
| 96 |
+
if random.random() < 0.5 and self.time_flip:
|
| 97 |
+
return clip[::-1]
|
| 98 |
+
if random.random() < 0.5 and self.horizontal_flip:
|
| 99 |
+
return [np.fliplr(img) for img in clip]
|
| 100 |
+
|
| 101 |
+
return clip
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class RandomResize(object):
|
| 105 |
+
"""Resizes a list of (H x W x C) numpy.ndarray to the final size
|
| 106 |
+
The larger the original image is, the more times it takes to
|
| 107 |
+
interpolate
|
| 108 |
+
Args:
|
| 109 |
+
interpolation (str): Can be one of 'nearest', 'bilinear'
|
| 110 |
+
defaults to nearest
|
| 111 |
+
size (tuple): (widht, height)
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
|
| 115 |
+
self.ratio = ratio
|
| 116 |
+
self.interpolation = interpolation
|
| 117 |
+
|
| 118 |
+
def __call__(self, clip):
|
| 119 |
+
scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
|
| 120 |
+
|
| 121 |
+
if isinstance(clip[0], np.ndarray):
|
| 122 |
+
im_h, im_w, im_c = clip[0].shape
|
| 123 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
| 124 |
+
im_w, im_h = clip[0].size
|
| 125 |
+
|
| 126 |
+
new_w = int(im_w * scaling_factor)
|
| 127 |
+
new_h = int(im_h * scaling_factor)
|
| 128 |
+
new_size = (new_w, new_h)
|
| 129 |
+
resized = resize_clip(
|
| 130 |
+
clip, new_size, interpolation=self.interpolation)
|
| 131 |
+
|
| 132 |
+
return resized
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class RandomCrop(object):
|
| 136 |
+
"""Extract random crop at the same location for a list of videos
|
| 137 |
+
Args:
|
| 138 |
+
size (sequence or int): Desired output size for the
|
| 139 |
+
crop in format (h, w)
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
def __init__(self, size):
|
| 143 |
+
if isinstance(size, numbers.Number):
|
| 144 |
+
size = (size, size)
|
| 145 |
+
|
| 146 |
+
self.size = size
|
| 147 |
+
|
| 148 |
+
def __call__(self, clip):
|
| 149 |
+
"""
|
| 150 |
+
Args:
|
| 151 |
+
img (PIL.Image or numpy.ndarray): List of videos to be cropped
|
| 152 |
+
in format (h, w, c) in numpy.ndarray
|
| 153 |
+
Returns:
|
| 154 |
+
PIL.Image or numpy.ndarray: Cropped list of videos
|
| 155 |
+
"""
|
| 156 |
+
h, w = self.size
|
| 157 |
+
if isinstance(clip[0], np.ndarray):
|
| 158 |
+
im_h, im_w, im_c = clip[0].shape
|
| 159 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
| 160 |
+
im_w, im_h = clip[0].size
|
| 161 |
+
else:
|
| 162 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
| 163 |
+
'but got list of {0}'.format(type(clip[0])))
|
| 164 |
+
|
| 165 |
+
clip = pad_clip(clip, h, w)
|
| 166 |
+
im_h, im_w = clip.shape[1:3]
|
| 167 |
+
x1 = 0 if h == im_h else random.randint(0, im_w - w)
|
| 168 |
+
y1 = 0 if w == im_w else random.randint(0, im_h - h)
|
| 169 |
+
cropped = crop_clip(clip, y1, x1, h, w)
|
| 170 |
+
|
| 171 |
+
return cropped
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class RandomRotation(object):
|
| 175 |
+
"""Rotate entire clip randomly by a random angle within
|
| 176 |
+
given bounds
|
| 177 |
+
Args:
|
| 178 |
+
degrees (sequence or int): Range of degrees to select from
|
| 179 |
+
If degrees is a number instead of sequence like (min, max),
|
| 180 |
+
the range of degrees, will be (-degrees, +degrees).
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
def __init__(self, degrees):
|
| 184 |
+
if isinstance(degrees, numbers.Number):
|
| 185 |
+
if degrees < 0:
|
| 186 |
+
raise ValueError('If degrees is a single number,'
|
| 187 |
+
'must be positive')
|
| 188 |
+
degrees = (-degrees, degrees)
|
| 189 |
+
else:
|
| 190 |
+
if len(degrees) != 2:
|
| 191 |
+
raise ValueError('If degrees is a sequence,'
|
| 192 |
+
'it must be of len 2.')
|
| 193 |
+
|
| 194 |
+
self.degrees = degrees
|
| 195 |
+
|
| 196 |
+
def __call__(self, clip):
|
| 197 |
+
"""
|
| 198 |
+
Args:
|
| 199 |
+
img (PIL.Image or numpy.ndarray): List of videos to be cropped
|
| 200 |
+
in format (h, w, c) in numpy.ndarray
|
| 201 |
+
Returns:
|
| 202 |
+
PIL.Image or numpy.ndarray: Cropped list of videos
|
| 203 |
+
"""
|
| 204 |
+
angle = random.uniform(self.degrees[0], self.degrees[1])
|
| 205 |
+
if isinstance(clip[0], np.ndarray):
|
| 206 |
+
rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
|
| 207 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
| 208 |
+
rotated = [img.rotate(angle) for img in clip]
|
| 209 |
+
else:
|
| 210 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
| 211 |
+
'but got list of {0}'.format(type(clip[0])))
|
| 212 |
+
|
| 213 |
+
return rotated
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class ColorJitter(object):
|
| 217 |
+
"""Randomly change the brightness, contrast and saturation and hue of the clip
|
| 218 |
+
Args:
|
| 219 |
+
brightness (float): How much to jitter brightness. brightness_factor
|
| 220 |
+
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
|
| 221 |
+
contrast (float): How much to jitter contrast. contrast_factor
|
| 222 |
+
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
|
| 223 |
+
saturation (float): How much to jitter saturation. saturation_factor
|
| 224 |
+
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
|
| 225 |
+
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
|
| 226 |
+
[-hue, hue]. Should be >=0 and <= 0.5.
|
| 227 |
+
"""
|
| 228 |
+
|
| 229 |
+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
|
| 230 |
+
self.brightness = brightness
|
| 231 |
+
self.contrast = contrast
|
| 232 |
+
self.saturation = saturation
|
| 233 |
+
self.hue = hue
|
| 234 |
+
|
| 235 |
+
def get_params(self, brightness, contrast, saturation, hue):
|
| 236 |
+
if brightness > 0:
|
| 237 |
+
brightness_factor = random.uniform(
|
| 238 |
+
max(0, 1 - brightness), 1 + brightness)
|
| 239 |
+
else:
|
| 240 |
+
brightness_factor = None
|
| 241 |
+
|
| 242 |
+
if contrast > 0:
|
| 243 |
+
contrast_factor = random.uniform(
|
| 244 |
+
max(0, 1 - contrast), 1 + contrast)
|
| 245 |
+
else:
|
| 246 |
+
contrast_factor = None
|
| 247 |
+
|
| 248 |
+
if saturation > 0:
|
| 249 |
+
saturation_factor = random.uniform(
|
| 250 |
+
max(0, 1 - saturation), 1 + saturation)
|
| 251 |
+
else:
|
| 252 |
+
saturation_factor = None
|
| 253 |
+
|
| 254 |
+
if hue > 0:
|
| 255 |
+
hue_factor = random.uniform(-hue, hue)
|
| 256 |
+
else:
|
| 257 |
+
hue_factor = None
|
| 258 |
+
return brightness_factor, contrast_factor, saturation_factor, hue_factor
|
| 259 |
+
|
| 260 |
+
def __call__(self, clip):
|
| 261 |
+
"""
|
| 262 |
+
Args:
|
| 263 |
+
clip (list): list of PIL.Image
|
| 264 |
+
Returns:
|
| 265 |
+
list PIL.Image : list of transformed PIL.Image
|
| 266 |
+
"""
|
| 267 |
+
if isinstance(clip[0], np.ndarray):
|
| 268 |
+
brightness, contrast, saturation, hue = self.get_params(
|
| 269 |
+
self.brightness, self.contrast, self.saturation, self.hue)
|
| 270 |
+
|
| 271 |
+
# Create img transform function sequence
|
| 272 |
+
img_transforms = []
|
| 273 |
+
if brightness is not None:
|
| 274 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
|
| 275 |
+
if saturation is not None:
|
| 276 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
|
| 277 |
+
if hue is not None:
|
| 278 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
|
| 279 |
+
if contrast is not None:
|
| 280 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
|
| 281 |
+
random.shuffle(img_transforms)
|
| 282 |
+
img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
|
| 283 |
+
img_as_float]
|
| 284 |
+
|
| 285 |
+
with warnings.catch_warnings():
|
| 286 |
+
warnings.simplefilter("ignore")
|
| 287 |
+
jittered_clip = []
|
| 288 |
+
for img in clip:
|
| 289 |
+
jittered_img = img
|
| 290 |
+
for func in img_transforms:
|
| 291 |
+
jittered_img = func(jittered_img)
|
| 292 |
+
jittered_clip.append(jittered_img.astype('float32'))
|
| 293 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
| 294 |
+
brightness, contrast, saturation, hue = self.get_params(
|
| 295 |
+
self.brightness, self.contrast, self.saturation, self.hue)
|
| 296 |
+
|
| 297 |
+
# Create img transform function sequence
|
| 298 |
+
img_transforms = []
|
| 299 |
+
if brightness is not None:
|
| 300 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
|
| 301 |
+
if saturation is not None:
|
| 302 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
|
| 303 |
+
if hue is not None:
|
| 304 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
|
| 305 |
+
if contrast is not None:
|
| 306 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
|
| 307 |
+
random.shuffle(img_transforms)
|
| 308 |
+
|
| 309 |
+
# Apply to all videos
|
| 310 |
+
jittered_clip = []
|
| 311 |
+
for img in clip:
|
| 312 |
+
for func in img_transforms:
|
| 313 |
+
jittered_img = func(img)
|
| 314 |
+
jittered_clip.append(jittered_img)
|
| 315 |
+
|
| 316 |
+
else:
|
| 317 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
| 318 |
+
'but got list of {0}'.format(type(clip[0])))
|
| 319 |
+
return jittered_clip
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class AllAugmentationTransform:
|
| 323 |
+
def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
|
| 324 |
+
self.transforms = []
|
| 325 |
+
|
| 326 |
+
if flip_param is not None:
|
| 327 |
+
self.transforms.append(RandomFlip(**flip_param))
|
| 328 |
+
|
| 329 |
+
if rotation_param is not None:
|
| 330 |
+
self.transforms.append(RandomRotation(**rotation_param))
|
| 331 |
+
|
| 332 |
+
if resize_param is not None:
|
| 333 |
+
self.transforms.append(RandomResize(**resize_param))
|
| 334 |
+
|
| 335 |
+
if crop_param is not None:
|
| 336 |
+
self.transforms.append(RandomCrop(**crop_param))
|
| 337 |
+
|
| 338 |
+
if jitter_param is not None:
|
| 339 |
+
self.transforms.append(ColorJitter(**jitter_param))
|
| 340 |
+
|
| 341 |
+
def __call__(self, clip):
|
| 342 |
+
for t in self.transforms:
|
| 343 |
+
clip = t(clip)
|
| 344 |
+
return clip
|
TPSMM/cog.yaml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
build:
|
| 2 |
+
cuda: "11.0"
|
| 3 |
+
gpu: true
|
| 4 |
+
python_version: "3.8"
|
| 5 |
+
system_packages:
|
| 6 |
+
- "libgl1-mesa-glx"
|
| 7 |
+
- "libglib2.0-0"
|
| 8 |
+
- "ninja-build"
|
| 9 |
+
python_packages:
|
| 10 |
+
- "ipython==7.21.0"
|
| 11 |
+
- "torch==1.10.1"
|
| 12 |
+
- "torchvision==0.11.2"
|
| 13 |
+
- "cffi==1.14.6"
|
| 14 |
+
- "cycler==0.10.0"
|
| 15 |
+
- "decorator==5.1.0"
|
| 16 |
+
- "face-alignment==1.3.5"
|
| 17 |
+
- "imageio==2.9.0"
|
| 18 |
+
- "imageio-ffmpeg==0.4.5"
|
| 19 |
+
- "kiwisolver==1.3.2"
|
| 20 |
+
- "matplotlib==3.4.3"
|
| 21 |
+
- "networkx==2.6.3"
|
| 22 |
+
- "numpy==1.20.3"
|
| 23 |
+
- "pandas==1.3.3"
|
| 24 |
+
- "Pillow==8.3.2"
|
| 25 |
+
- "pycparser==2.20"
|
| 26 |
+
- "pyparsing==2.4.7"
|
| 27 |
+
- "python-dateutil==2.8.2"
|
| 28 |
+
- "pytz==2021.1"
|
| 29 |
+
- "PyWavelets==1.1.1"
|
| 30 |
+
- "PyYAML==5.4.1"
|
| 31 |
+
- "scikit-image==0.18.3"
|
| 32 |
+
- "scikit-learn==1.0"
|
| 33 |
+
- "scipy==1.7.1"
|
| 34 |
+
- "six==1.16.0"
|
| 35 |
+
- "tqdm==4.62.3"
|
| 36 |
+
- "cmake==3.21.3"
|
| 37 |
+
run:
|
| 38 |
+
- pip install dlib
|
| 39 |
+
|
| 40 |
+
predict: "predict.py:Predictor"
|
TPSMM/config/mgif-256.yaml
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset_params:
|
| 2 |
+
root_dir: ../moving-gif
|
| 3 |
+
frame_shape: null
|
| 4 |
+
id_sampling: False
|
| 5 |
+
augmentation_params:
|
| 6 |
+
flip_param:
|
| 7 |
+
horizontal_flip: True
|
| 8 |
+
time_flip: True
|
| 9 |
+
crop_param:
|
| 10 |
+
size: [256, 256]
|
| 11 |
+
resize_param:
|
| 12 |
+
ratio: [0.9, 1.1]
|
| 13 |
+
jitter_param:
|
| 14 |
+
hue: 0.5
|
| 15 |
+
|
| 16 |
+
model_params:
|
| 17 |
+
common_params:
|
| 18 |
+
num_tps: 10
|
| 19 |
+
num_channels: 3
|
| 20 |
+
bg: False
|
| 21 |
+
multi_mask: True
|
| 22 |
+
generator_params:
|
| 23 |
+
block_expansion: 64
|
| 24 |
+
max_features: 512
|
| 25 |
+
num_down_blocks: 3
|
| 26 |
+
dense_motion_params:
|
| 27 |
+
block_expansion: 64
|
| 28 |
+
max_features: 1024
|
| 29 |
+
num_blocks: 5
|
| 30 |
+
scale_factor: 0.25
|
| 31 |
+
avd_network_params:
|
| 32 |
+
id_bottle_size: 128
|
| 33 |
+
pose_bottle_size: 128
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
train_params:
|
| 37 |
+
num_epochs: 100
|
| 38 |
+
num_repeats: 50
|
| 39 |
+
epoch_milestones: [70, 90]
|
| 40 |
+
lr_generator: 2.0e-4
|
| 41 |
+
batch_size: 28
|
| 42 |
+
scales: [1, 0.5, 0.25, 0.125]
|
| 43 |
+
dataloader_workers: 12
|
| 44 |
+
checkpoint_freq: 50
|
| 45 |
+
dropout_epoch: 35
|
| 46 |
+
dropout_maxp: 0.5
|
| 47 |
+
dropout_startp: 0.2
|
| 48 |
+
dropout_inc_epoch: 10
|
| 49 |
+
bg_start: 0
|
| 50 |
+
transform_params:
|
| 51 |
+
sigma_affine: 0.05
|
| 52 |
+
sigma_tps: 0.005
|
| 53 |
+
points_tps: 5
|
| 54 |
+
loss_weights:
|
| 55 |
+
perceptual: [10, 10, 10, 10, 10]
|
| 56 |
+
equivariance_value: 10
|
| 57 |
+
warp_loss: 10
|
| 58 |
+
bg: 10
|
| 59 |
+
|
| 60 |
+
train_avd_params:
|
| 61 |
+
num_epochs: 100
|
| 62 |
+
num_repeats: 50
|
| 63 |
+
batch_size: 256
|
| 64 |
+
dataloader_workers: 24
|
| 65 |
+
checkpoint_freq: 10
|
| 66 |
+
epoch_milestones: [70, 90]
|
| 67 |
+
lr: 1.0e-3
|
| 68 |
+
lambda_shift: 1
|
| 69 |
+
lambda_affine: 1
|
| 70 |
+
random_scale: 0.25
|
| 71 |
+
|
| 72 |
+
visualizer_params:
|
| 73 |
+
kp_size: 5
|
| 74 |
+
draw_border: True
|
| 75 |
+
colormap: 'gist_rainbow'
|
TPSMM/config/taichi-256.yaml
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dataset parameters
|
| 2 |
+
# Each dataset should contain 2 folders train and test
|
| 3 |
+
# Each video can be represented as:
|
| 4 |
+
# - an image of concatenated frames
|
| 5 |
+
# - '.mp4' or '.gif'
|
| 6 |
+
# - folder with all frames from a specific video
|
| 7 |
+
# In case of Taichi. Same (youtube) video can be splitted in many parts (chunks). Each part has a following
|
| 8 |
+
# format (id)#other#info.mp4. For example '12335#adsbf.mp4' has an id 12335. In case of TaiChi id stands for youtube
|
| 9 |
+
# video id.
|
| 10 |
+
dataset_params:
|
| 11 |
+
# Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
|
| 12 |
+
root_dir: ../taichi
|
| 13 |
+
# Image shape, needed for staked .png format.
|
| 14 |
+
frame_shape: null
|
| 15 |
+
# In case of TaiChi single video can be splitted in many chunks, or the maybe several videos for single person.
|
| 16 |
+
# In this case epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False)
|
| 17 |
+
# If the name of the video '12335#adsbf.mp4' the id is assumed to be 12335
|
| 18 |
+
id_sampling: True
|
| 19 |
+
# Augmentation parameters see augmentation.py for all posible augmentations
|
| 20 |
+
augmentation_params:
|
| 21 |
+
flip_param:
|
| 22 |
+
horizontal_flip: True
|
| 23 |
+
time_flip: True
|
| 24 |
+
jitter_param:
|
| 25 |
+
brightness: 0.1
|
| 26 |
+
contrast: 0.1
|
| 27 |
+
saturation: 0.1
|
| 28 |
+
hue: 0.1
|
| 29 |
+
|
| 30 |
+
# Defines model architecture
|
| 31 |
+
model_params:
|
| 32 |
+
common_params:
|
| 33 |
+
# Number of TPS transformation
|
| 34 |
+
num_tps: 10
|
| 35 |
+
# Number of channels per image
|
| 36 |
+
num_channels: 3
|
| 37 |
+
# Whether to estimate affine background transformation
|
| 38 |
+
bg: True
|
| 39 |
+
# Whether to estimate the multi-resolution occlusion masks
|
| 40 |
+
multi_mask: True
|
| 41 |
+
generator_params:
|
| 42 |
+
# Number of features mutliplier
|
| 43 |
+
block_expansion: 64
|
| 44 |
+
# Maximum allowed number of features
|
| 45 |
+
max_features: 512
|
| 46 |
+
# Number of downsampling blocks and Upsampling blocks.
|
| 47 |
+
num_down_blocks: 3
|
| 48 |
+
dense_motion_params:
|
| 49 |
+
# Number of features mutliplier
|
| 50 |
+
block_expansion: 64
|
| 51 |
+
# Maximum allowed number of features
|
| 52 |
+
max_features: 1024
|
| 53 |
+
# Number of block in Unet.
|
| 54 |
+
num_blocks: 5
|
| 55 |
+
# Optical flow is predicted on smaller images for better performance,
|
| 56 |
+
# scale_factor=0.25 means that 256x256 image will be resized to 64x64
|
| 57 |
+
scale_factor: 0.25
|
| 58 |
+
avd_network_params:
|
| 59 |
+
# Bottleneck for identity branch
|
| 60 |
+
id_bottle_size: 128
|
| 61 |
+
# Bottleneck for pose branch
|
| 62 |
+
pose_bottle_size: 128
|
| 63 |
+
|
| 64 |
+
# Parameters of training
|
| 65 |
+
train_params:
|
| 66 |
+
# Number of training epochs
|
| 67 |
+
num_epochs: 100
|
| 68 |
+
# For better i/o performance when number of videos is small number of epochs can be multiplied by this number.
|
| 69 |
+
# Thus effectivlly with num_repeats=100 each epoch is 100 times larger.
|
| 70 |
+
num_repeats: 150
|
| 71 |
+
# Drop learning rate by 10 times after this epochs
|
| 72 |
+
epoch_milestones: [70, 90]
|
| 73 |
+
# Initial learing rate for all modules
|
| 74 |
+
lr_generator: 2.0e-4
|
| 75 |
+
batch_size: 28
|
| 76 |
+
# Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,
|
| 77 |
+
# than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32.
|
| 78 |
+
scales: [1, 0.5, 0.25, 0.125]
|
| 79 |
+
# Dataset preprocessing cpu workers
|
| 80 |
+
dataloader_workers: 12
|
| 81 |
+
# Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs.
|
| 82 |
+
checkpoint_freq: 50
|
| 83 |
+
# Parameters of dropout
|
| 84 |
+
# The first dropout_epoch training uses dropout operation
|
| 85 |
+
dropout_epoch: 35
|
| 86 |
+
# The probability P will linearly increase from dropout_startp to dropout_maxp in dropout_inc_epoch epochs
|
| 87 |
+
dropout_maxp: 0.7
|
| 88 |
+
dropout_startp: 0.0
|
| 89 |
+
dropout_inc_epoch: 10
|
| 90 |
+
# Estimate affine background transformation from the bg_start epoch.
|
| 91 |
+
bg_start: 0
|
| 92 |
+
# Parameters of random TPS transformation for equivariance loss
|
| 93 |
+
transform_params:
|
| 94 |
+
# Sigma for affine part
|
| 95 |
+
sigma_affine: 0.05
|
| 96 |
+
# Sigma for deformation part
|
| 97 |
+
sigma_tps: 0.005
|
| 98 |
+
# Number of point in the deformation grid
|
| 99 |
+
points_tps: 5
|
| 100 |
+
loss_weights:
|
| 101 |
+
# Weights for perceptual loss.
|
| 102 |
+
perceptual: [10, 10, 10, 10, 10]
|
| 103 |
+
# Weights for value equivariance.
|
| 104 |
+
equivariance_value: 10
|
| 105 |
+
# Weights for warp loss.
|
| 106 |
+
warp_loss: 10
|
| 107 |
+
# Weights for bg loss.
|
| 108 |
+
bg: 10
|
| 109 |
+
|
| 110 |
+
# Parameters of training (animation-via-disentanglement)
|
| 111 |
+
train_avd_params:
|
| 112 |
+
# Number of training epochs, visualization is produced after each epoch.
|
| 113 |
+
num_epochs: 100
|
| 114 |
+
# For better i/o performance when number of videos is small number of epochs can be multiplied by this number.
|
| 115 |
+
# Thus effectively with num_repeats=100 each epoch is 100 times larger.
|
| 116 |
+
num_repeats: 150
|
| 117 |
+
# Batch size.
|
| 118 |
+
batch_size: 256
|
| 119 |
+
# Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs.
|
| 120 |
+
checkpoint_freq: 10
|
| 121 |
+
# Dataset preprocessing cpu workers
|
| 122 |
+
dataloader_workers: 24
|
| 123 |
+
# Drop learning rate 10 times after this epochs
|
| 124 |
+
epoch_milestones: [70, 90]
|
| 125 |
+
# Initial learning rate
|
| 126 |
+
lr: 1.0e-3
|
| 127 |
+
# Weights for equivariance loss.
|
| 128 |
+
lambda_shift: 1
|
| 129 |
+
random_scale: 0.25
|
| 130 |
+
|
| 131 |
+
visualizer_params:
|
| 132 |
+
kp_size: 5
|
| 133 |
+
draw_border: True
|
| 134 |
+
colormap: 'gist_rainbow'
|
TPSMM/config/ted-384.yaml
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset_params:
|
| 2 |
+
root_dir: ../TED384-v2
|
| 3 |
+
frame_shape: null
|
| 4 |
+
id_sampling: True
|
| 5 |
+
augmentation_params:
|
| 6 |
+
flip_param:
|
| 7 |
+
horizontal_flip: True
|
| 8 |
+
time_flip: True
|
| 9 |
+
jitter_param:
|
| 10 |
+
brightness: 0.1
|
| 11 |
+
contrast: 0.1
|
| 12 |
+
saturation: 0.1
|
| 13 |
+
hue: 0.1
|
| 14 |
+
|
| 15 |
+
model_params:
|
| 16 |
+
common_params:
|
| 17 |
+
num_tps: 10
|
| 18 |
+
num_channels: 3
|
| 19 |
+
bg: True
|
| 20 |
+
multi_mask: True
|
| 21 |
+
generator_params:
|
| 22 |
+
block_expansion: 64
|
| 23 |
+
max_features: 512
|
| 24 |
+
num_down_blocks: 3
|
| 25 |
+
dense_motion_params:
|
| 26 |
+
block_expansion: 64
|
| 27 |
+
max_features: 1024
|
| 28 |
+
num_blocks: 5
|
| 29 |
+
scale_factor: 0.25
|
| 30 |
+
avd_network_params:
|
| 31 |
+
id_bottle_size: 128
|
| 32 |
+
pose_bottle_size: 128
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
train_params:
|
| 36 |
+
num_epochs: 100
|
| 37 |
+
num_repeats: 150
|
| 38 |
+
epoch_milestones: [70, 90]
|
| 39 |
+
lr_generator: 2.0e-4
|
| 40 |
+
batch_size: 12
|
| 41 |
+
scales: [1, 0.5, 0.25, 0.125]
|
| 42 |
+
dataloader_workers: 6
|
| 43 |
+
checkpoint_freq: 50
|
| 44 |
+
dropout_epoch: 35
|
| 45 |
+
dropout_maxp: 0.5
|
| 46 |
+
dropout_startp: 0.0
|
| 47 |
+
dropout_inc_epoch: 10
|
| 48 |
+
bg_start: 0
|
| 49 |
+
transform_params:
|
| 50 |
+
sigma_affine: 0.05
|
| 51 |
+
sigma_tps: 0.005
|
| 52 |
+
points_tps: 5
|
| 53 |
+
loss_weights:
|
| 54 |
+
perceptual: [10, 10, 10, 10, 10]
|
| 55 |
+
equivariance_value: 10
|
| 56 |
+
warp_loss: 10
|
| 57 |
+
bg: 10
|
| 58 |
+
|
| 59 |
+
train_avd_params:
|
| 60 |
+
num_epochs: 30
|
| 61 |
+
num_repeats: 500
|
| 62 |
+
batch_size: 256
|
| 63 |
+
dataloader_workers: 24
|
| 64 |
+
checkpoint_freq: 10
|
| 65 |
+
epoch_milestones: [20, 25]
|
| 66 |
+
lr: 1.0e-3
|
| 67 |
+
lambda_shift: 1
|
| 68 |
+
random_scale: 0.25
|
| 69 |
+
|
| 70 |
+
visualizer_params:
|
| 71 |
+
kp_size: 5
|
| 72 |
+
draw_border: True
|
| 73 |
+
colormap: 'gist_rainbow'
|
TPSMM/config/vox-256.yaml
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset_params:
|
| 2 |
+
root_dir: ../vox
|
| 3 |
+
frame_shape: null
|
| 4 |
+
id_sampling: True
|
| 5 |
+
augmentation_params:
|
| 6 |
+
flip_param:
|
| 7 |
+
horizontal_flip: True
|
| 8 |
+
time_flip: True
|
| 9 |
+
jitter_param:
|
| 10 |
+
brightness: 0.1
|
| 11 |
+
contrast: 0.1
|
| 12 |
+
saturation: 0.1
|
| 13 |
+
hue: 0.1
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
model_params:
|
| 17 |
+
common_params:
|
| 18 |
+
num_tps: 10
|
| 19 |
+
num_channels: 3
|
| 20 |
+
bg: True
|
| 21 |
+
multi_mask: True
|
| 22 |
+
generator_params:
|
| 23 |
+
block_expansion: 64
|
| 24 |
+
max_features: 512
|
| 25 |
+
num_down_blocks: 3
|
| 26 |
+
dense_motion_params:
|
| 27 |
+
block_expansion: 64
|
| 28 |
+
max_features: 1024
|
| 29 |
+
num_blocks: 5
|
| 30 |
+
scale_factor: 0.25
|
| 31 |
+
avd_network_params:
|
| 32 |
+
id_bottle_size: 128
|
| 33 |
+
pose_bottle_size: 128
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
train_params:
|
| 37 |
+
num_epochs: 100
|
| 38 |
+
num_repeats: 75
|
| 39 |
+
epoch_milestones: [70, 90]
|
| 40 |
+
lr_generator: 2.0e-4
|
| 41 |
+
batch_size: 28
|
| 42 |
+
scales: [1, 0.5, 0.25, 0.125]
|
| 43 |
+
dataloader_workers: 12
|
| 44 |
+
checkpoint_freq: 50
|
| 45 |
+
dropout_epoch: 35
|
| 46 |
+
dropout_maxp: 0.3
|
| 47 |
+
dropout_startp: 0.1
|
| 48 |
+
dropout_inc_epoch: 10
|
| 49 |
+
bg_start: 10
|
| 50 |
+
transform_params:
|
| 51 |
+
sigma_affine: 0.05
|
| 52 |
+
sigma_tps: 0.005
|
| 53 |
+
points_tps: 5
|
| 54 |
+
loss_weights:
|
| 55 |
+
perceptual: [10, 10, 10, 10, 10]
|
| 56 |
+
equivariance_value: 10
|
| 57 |
+
warp_loss: 10
|
| 58 |
+
bg: 10
|
| 59 |
+
|
| 60 |
+
train_avd_params:
|
| 61 |
+
num_epochs: 200
|
| 62 |
+
num_repeats: 300
|
| 63 |
+
batch_size: 256
|
| 64 |
+
dataloader_workers: 24
|
| 65 |
+
checkpoint_freq: 50
|
| 66 |
+
epoch_milestones: [140, 180]
|
| 67 |
+
lr: 1.0e-3
|
| 68 |
+
lambda_shift: 1
|
| 69 |
+
random_scale: 0.25
|
| 70 |
+
|
| 71 |
+
visualizer_params:
|
| 72 |
+
kp_size: 5
|
| 73 |
+
draw_border: True
|
| 74 |
+
colormap: 'gist_rainbow'
|
TPSMM/demo.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
TPSMM/demo.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib
|
| 2 |
+
matplotlib.use('Agg')
|
| 3 |
+
import sys
|
| 4 |
+
import yaml
|
| 5 |
+
from argparse import ArgumentParser
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from scipy.spatial import ConvexHull
|
| 8 |
+
import numpy as np
|
| 9 |
+
import imageio
|
| 10 |
+
from skimage.transform import resize
|
| 11 |
+
from skimage import img_as_ubyte
|
| 12 |
+
import torch
|
| 13 |
+
from modules.inpainting_network import InpaintingNetwork
|
| 14 |
+
from modules.keypoint_detector import KPDetector
|
| 15 |
+
from modules.dense_motion import DenseMotionNetwork
|
| 16 |
+
from modules.avd_network import AVDNetwork
|
| 17 |
+
|
| 18 |
+
if sys.version_info[0] < 3:
|
| 19 |
+
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")
|
| 20 |
+
|
| 21 |
+
def relative_kp(kp_source, kp_driving, kp_driving_initial):
|
| 22 |
+
|
| 23 |
+
source_area = ConvexHull(kp_source['fg_kp'][0].data.cpu().numpy()).volume
|
| 24 |
+
driving_area = ConvexHull(kp_driving_initial['fg_kp'][0].data.cpu().numpy()).volume
|
| 25 |
+
adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
|
| 26 |
+
|
| 27 |
+
kp_new = {k: v for k, v in kp_driving.items()}
|
| 28 |
+
|
| 29 |
+
kp_value_diff = (kp_driving['fg_kp'] - kp_driving_initial['fg_kp'])
|
| 30 |
+
kp_value_diff *= adapt_movement_scale
|
| 31 |
+
kp_new['fg_kp'] = kp_value_diff + kp_source['fg_kp']
|
| 32 |
+
|
| 33 |
+
return kp_new
|
| 34 |
+
|
| 35 |
+
def load_checkpoints(config_path, checkpoint_path, device):
|
| 36 |
+
with open(config_path) as f:
|
| 37 |
+
config = yaml.full_load(f)
|
| 38 |
+
|
| 39 |
+
inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
|
| 40 |
+
**config['model_params']['common_params'])
|
| 41 |
+
kp_detector = KPDetector(**config['model_params']['common_params'])
|
| 42 |
+
dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
|
| 43 |
+
**config['model_params']['dense_motion_params'])
|
| 44 |
+
avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
|
| 45 |
+
**config['model_params']['avd_network_params'])
|
| 46 |
+
kp_detector.to(device)
|
| 47 |
+
dense_motion_network.to(device)
|
| 48 |
+
inpainting.to(device)
|
| 49 |
+
avd_network.to(device)
|
| 50 |
+
|
| 51 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 52 |
+
|
| 53 |
+
inpainting.load_state_dict(checkpoint['inpainting_network'])
|
| 54 |
+
kp_detector.load_state_dict(checkpoint['kp_detector'])
|
| 55 |
+
dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
|
| 56 |
+
if 'avd_network' in checkpoint:
|
| 57 |
+
avd_network.load_state_dict(checkpoint['avd_network'])
|
| 58 |
+
|
| 59 |
+
inpainting.eval()
|
| 60 |
+
kp_detector.eval()
|
| 61 |
+
dense_motion_network.eval()
|
| 62 |
+
avd_network.eval()
|
| 63 |
+
|
| 64 |
+
return inpainting, kp_detector, dense_motion_network, avd_network
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def make_animation(source_image, driving_video, inpainting_network, kp_detector, dense_motion_network, avd_network, device, mode = 'relative'):
|
| 68 |
+
assert mode in ['standard', 'relative', 'avd']
|
| 69 |
+
with torch.no_grad():
|
| 70 |
+
predictions = []
|
| 71 |
+
source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
|
| 72 |
+
source = source.to(device)
|
| 73 |
+
driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device)
|
| 74 |
+
kp_source = kp_detector(source)
|
| 75 |
+
kp_driving_initial = kp_detector(driving[:, :, 0])
|
| 76 |
+
|
| 77 |
+
for frame_idx in tqdm(range(driving.shape[2])):
|
| 78 |
+
driving_frame = driving[:, :, frame_idx]
|
| 79 |
+
driving_frame = driving_frame.to(device)
|
| 80 |
+
kp_driving = kp_detector(driving_frame)
|
| 81 |
+
if mode == 'standard':
|
| 82 |
+
kp_norm = kp_driving
|
| 83 |
+
elif mode=='relative':
|
| 84 |
+
kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving,
|
| 85 |
+
kp_driving_initial=kp_driving_initial)
|
| 86 |
+
elif mode == 'avd':
|
| 87 |
+
kp_norm = avd_network(kp_source, kp_driving)
|
| 88 |
+
dense_motion = dense_motion_network(source_image=source, kp_driving=kp_norm,
|
| 89 |
+
kp_source=kp_source, bg_param = None,
|
| 90 |
+
dropout_flag = False)
|
| 91 |
+
out = inpainting_network(source, dense_motion)
|
| 92 |
+
|
| 93 |
+
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
|
| 94 |
+
return predictions
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def find_best_frame(source, driving, cpu):
|
| 98 |
+
import face_alignment
|
| 99 |
+
|
| 100 |
+
def normalize_kp(kp):
|
| 101 |
+
kp = kp - kp.mean(axis=0, keepdims=True)
|
| 102 |
+
area = ConvexHull(kp[:, :2]).volume
|
| 103 |
+
area = np.sqrt(area)
|
| 104 |
+
kp[:, :2] = kp[:, :2] / area
|
| 105 |
+
return kp
|
| 106 |
+
|
| 107 |
+
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
|
| 108 |
+
device= 'cpu' if cpu else 'cuda')
|
| 109 |
+
kp_source = fa.get_landmarks(255 * source)[0]
|
| 110 |
+
kp_source = normalize_kp(kp_source)
|
| 111 |
+
norm = float('inf')
|
| 112 |
+
frame_num = 0
|
| 113 |
+
for i, image in tqdm(enumerate(driving)):
|
| 114 |
+
try:
|
| 115 |
+
kp_driving = fa.get_landmarks(255 * image)[0]
|
| 116 |
+
kp_driving = normalize_kp(kp_driving)
|
| 117 |
+
new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
|
| 118 |
+
if new_norm < norm:
|
| 119 |
+
norm = new_norm
|
| 120 |
+
frame_num = i
|
| 121 |
+
except:
|
| 122 |
+
pass
|
| 123 |
+
return frame_num
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
if __name__ == "__main__":
|
| 127 |
+
parser = ArgumentParser()
|
| 128 |
+
parser.add_argument("--config", required=True, help="path to config")
|
| 129 |
+
parser.add_argument("--checkpoint", default='checkpoints/vox.pth.tar', help="path to checkpoint to restore")
|
| 130 |
+
|
| 131 |
+
parser.add_argument("--source_image", default='./assets/source.png', help="path to source image")
|
| 132 |
+
parser.add_argument("--driving_video", default='./assets/driving.mp4', help="path to driving video")
|
| 133 |
+
parser.add_argument("--result_video", default='./result.mp4', help="path to output")
|
| 134 |
+
|
| 135 |
+
parser.add_argument("--img_shape", default="256,256", type=lambda x: list(map(int, x.split(','))),
|
| 136 |
+
help='Shape of image, that the model was trained on.')
|
| 137 |
+
|
| 138 |
+
parser.add_argument("--mode", default='relative', choices=['standard', 'relative', 'avd'], help="Animate mode: ['standard', 'relative', 'avd'], when use the relative mode to animate a face, use '--find_best_frame' can get better quality result")
|
| 139 |
+
|
| 140 |
+
parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
|
| 141 |
+
help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)")
|
| 142 |
+
|
| 143 |
+
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
|
| 144 |
+
|
| 145 |
+
opt = parser.parse_args()
|
| 146 |
+
|
| 147 |
+
source_image = imageio.imread(opt.source_image)
|
| 148 |
+
reader = imageio.get_reader(opt.driving_video)
|
| 149 |
+
fps = reader.get_meta_data()['fps']
|
| 150 |
+
driving_video = []
|
| 151 |
+
try:
|
| 152 |
+
for im in reader:
|
| 153 |
+
driving_video.append(im)
|
| 154 |
+
except RuntimeError:
|
| 155 |
+
pass
|
| 156 |
+
reader.close()
|
| 157 |
+
|
| 158 |
+
if opt.cpu:
|
| 159 |
+
device = torch.device('cpu')
|
| 160 |
+
else:
|
| 161 |
+
device = torch.device('cuda')
|
| 162 |
+
|
| 163 |
+
source_image = resize(source_image, opt.img_shape)[..., :3]
|
| 164 |
+
driving_video = [resize(frame, opt.img_shape)[..., :3] for frame in driving_video]
|
| 165 |
+
inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = opt.config, checkpoint_path = opt.checkpoint, device = device)
|
| 166 |
+
|
| 167 |
+
if opt.find_best_frame:
|
| 168 |
+
i = find_best_frame(source_image, driving_video, opt.cpu)
|
| 169 |
+
print ("Best frame: " + str(i))
|
| 170 |
+
driving_forward = driving_video[i:]
|
| 171 |
+
driving_backward = driving_video[:(i+1)][::-1]
|
| 172 |
+
predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = opt.mode)
|
| 173 |
+
predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = opt.mode)
|
| 174 |
+
predictions = predictions_backward[::-1] + predictions_forward[1:]
|
| 175 |
+
else:
|
| 176 |
+
predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = opt.mode)
|
| 177 |
+
|
| 178 |
+
predictions =[img_as_ubyte(frame) for frame in predictions]
|
| 179 |
+
imageio.mimsave(opt.result_video, predictions, fps=fps)
|
| 180 |
+
|
TPSMM/frames_dataset.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from skimage import io, img_as_float32
|
| 3 |
+
from skimage.color import gray2rgb
|
| 4 |
+
from sklearn.model_selection import train_test_split
|
| 5 |
+
from imageio import mimread
|
| 6 |
+
from skimage.transform import resize
|
| 7 |
+
import numpy as np
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
from augmentation import AllAugmentationTransform
|
| 10 |
+
import glob
|
| 11 |
+
from functools import partial
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def read_video(name, frame_shape):
|
| 15 |
+
"""
|
| 16 |
+
Read video which can be:
|
| 17 |
+
- an image of concatenated frames
|
| 18 |
+
- '.mp4' and'.gif'
|
| 19 |
+
- folder with videos
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
if os.path.isdir(name):
|
| 23 |
+
frames = sorted(os.listdir(name))
|
| 24 |
+
num_frames = len(frames)
|
| 25 |
+
video_array = np.array(
|
| 26 |
+
[img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
|
| 27 |
+
elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
|
| 28 |
+
image = io.imread(name)
|
| 29 |
+
|
| 30 |
+
if len(image.shape) == 2 or image.shape[2] == 1:
|
| 31 |
+
image = gray2rgb(image)
|
| 32 |
+
|
| 33 |
+
if image.shape[2] == 4:
|
| 34 |
+
image = image[..., :3]
|
| 35 |
+
|
| 36 |
+
image = img_as_float32(image)
|
| 37 |
+
|
| 38 |
+
video_array = np.moveaxis(image, 1, 0)
|
| 39 |
+
|
| 40 |
+
video_array = video_array.reshape((-1,) + frame_shape)
|
| 41 |
+
video_array = np.moveaxis(video_array, 1, 2)
|
| 42 |
+
elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
|
| 43 |
+
video = mimread(name)
|
| 44 |
+
if len(video[0].shape) == 2:
|
| 45 |
+
video = [gray2rgb(frame) for frame in video]
|
| 46 |
+
if frame_shape is not None:
|
| 47 |
+
video = np.array([resize(frame, frame_shape) for frame in video])
|
| 48 |
+
video = np.array(video)
|
| 49 |
+
if video.shape[-1] == 4:
|
| 50 |
+
video = video[..., :3]
|
| 51 |
+
video_array = img_as_float32(video)
|
| 52 |
+
else:
|
| 53 |
+
raise Exception("Unknown file extensions %s" % name)
|
| 54 |
+
|
| 55 |
+
return video_array
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class FramesDataset(Dataset):
|
| 59 |
+
"""
|
| 60 |
+
Dataset of videos, each video can be represented as:
|
| 61 |
+
- an image of concatenated frames
|
| 62 |
+
- '.mp4' or '.gif'
|
| 63 |
+
- folder with all frames
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
|
| 67 |
+
random_seed=0, pairs_list=None, augmentation_params=None):
|
| 68 |
+
self.root_dir = root_dir
|
| 69 |
+
self.videos = os.listdir(root_dir)
|
| 70 |
+
self.frame_shape = frame_shape
|
| 71 |
+
print(self.frame_shape)
|
| 72 |
+
self.pairs_list = pairs_list
|
| 73 |
+
self.id_sampling = id_sampling
|
| 74 |
+
|
| 75 |
+
if os.path.exists(os.path.join(root_dir, 'train')):
|
| 76 |
+
assert os.path.exists(os.path.join(root_dir, 'test'))
|
| 77 |
+
print("Use predefined train-test split.")
|
| 78 |
+
if id_sampling:
|
| 79 |
+
train_videos = {os.path.basename(video).split('#')[0] for video in
|
| 80 |
+
os.listdir(os.path.join(root_dir, 'train'))}
|
| 81 |
+
train_videos = list(train_videos)
|
| 82 |
+
else:
|
| 83 |
+
train_videos = os.listdir(os.path.join(root_dir, 'train'))
|
| 84 |
+
test_videos = os.listdir(os.path.join(root_dir, 'test'))
|
| 85 |
+
self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
|
| 86 |
+
else:
|
| 87 |
+
print("Use random train-test split.")
|
| 88 |
+
train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
|
| 89 |
+
|
| 90 |
+
if is_train:
|
| 91 |
+
self.videos = train_videos
|
| 92 |
+
else:
|
| 93 |
+
self.videos = test_videos
|
| 94 |
+
|
| 95 |
+
self.is_train = is_train
|
| 96 |
+
|
| 97 |
+
if self.is_train:
|
| 98 |
+
self.transform = AllAugmentationTransform(**augmentation_params)
|
| 99 |
+
else:
|
| 100 |
+
self.transform = None
|
| 101 |
+
|
| 102 |
+
def __len__(self):
|
| 103 |
+
return len(self.videos)
|
| 104 |
+
|
| 105 |
+
def __getitem__(self, idx):
|
| 106 |
+
|
| 107 |
+
if self.is_train and self.id_sampling:
|
| 108 |
+
name = self.videos[idx]
|
| 109 |
+
path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
|
| 110 |
+
else:
|
| 111 |
+
name = self.videos[idx]
|
| 112 |
+
path = os.path.join(self.root_dir, name)
|
| 113 |
+
|
| 114 |
+
video_name = os.path.basename(path)
|
| 115 |
+
if self.is_train and os.path.isdir(path):
|
| 116 |
+
|
| 117 |
+
frames = os.listdir(path)
|
| 118 |
+
num_frames = len(frames)
|
| 119 |
+
frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
|
| 120 |
+
|
| 121 |
+
if self.frame_shape is not None:
|
| 122 |
+
resize_fn = partial(resize, output_shape=self.frame_shape)
|
| 123 |
+
else:
|
| 124 |
+
resize_fn = img_as_float32
|
| 125 |
+
|
| 126 |
+
if type(frames[0]) is bytes:
|
| 127 |
+
video_array = [resize_fn(io.imread(os.path.join(path, frames[idx].decode('utf-8')))) for idx in
|
| 128 |
+
frame_idx]
|
| 129 |
+
else:
|
| 130 |
+
video_array = [resize_fn(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
|
| 131 |
+
else:
|
| 132 |
+
|
| 133 |
+
video_array = read_video(path, frame_shape=self.frame_shape)
|
| 134 |
+
|
| 135 |
+
num_frames = len(video_array)
|
| 136 |
+
frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
|
| 137 |
+
num_frames)
|
| 138 |
+
video_array = video_array[frame_idx]
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if self.transform is not None:
|
| 142 |
+
video_array = self.transform(video_array)
|
| 143 |
+
|
| 144 |
+
out = {}
|
| 145 |
+
if self.is_train:
|
| 146 |
+
source = np.array(video_array[0], dtype='float32')
|
| 147 |
+
driving = np.array(video_array[1], dtype='float32')
|
| 148 |
+
|
| 149 |
+
out['driving'] = driving.transpose((2, 0, 1))
|
| 150 |
+
out['source'] = source.transpose((2, 0, 1))
|
| 151 |
+
else:
|
| 152 |
+
video = np.array(video_array, dtype='float32')
|
| 153 |
+
out['video'] = video.transpose((3, 0, 1, 2))
|
| 154 |
+
|
| 155 |
+
out['name'] = video_name
|
| 156 |
+
return out
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class DatasetRepeater(Dataset):
|
| 160 |
+
"""
|
| 161 |
+
Pass several times over the same dataset for better i/o performance
|
| 162 |
+
"""
|
| 163 |
+
|
| 164 |
+
def __init__(self, dataset, num_repeats=100):
|
| 165 |
+
self.dataset = dataset
|
| 166 |
+
self.num_repeats = num_repeats
|
| 167 |
+
|
| 168 |
+
def __len__(self):
|
| 169 |
+
return self.num_repeats * self.dataset.__len__()
|
| 170 |
+
|
| 171 |
+
def __getitem__(self, idx):
|
| 172 |
+
return self.dataset[idx % self.dataset.__len__()]
|
| 173 |
+
|
TPSMM/logger.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import imageio
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from skimage.draw import circle
|
| 8 |
+
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import collections
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Logger:
|
| 14 |
+
def __init__(self, log_dir, checkpoint_freq=50, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
|
| 15 |
+
|
| 16 |
+
self.loss_list = []
|
| 17 |
+
self.cpk_dir = log_dir
|
| 18 |
+
self.visualizations_dir = os.path.join(log_dir, 'train-vis')
|
| 19 |
+
if not os.path.exists(self.visualizations_dir):
|
| 20 |
+
os.makedirs(self.visualizations_dir)
|
| 21 |
+
self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
|
| 22 |
+
self.zfill_num = zfill_num
|
| 23 |
+
self.visualizer = Visualizer(**visualizer_params)
|
| 24 |
+
self.checkpoint_freq = checkpoint_freq
|
| 25 |
+
self.epoch = 0
|
| 26 |
+
self.best_loss = float('inf')
|
| 27 |
+
self.names = None
|
| 28 |
+
|
| 29 |
+
def log_scores(self, loss_names):
|
| 30 |
+
loss_mean = np.array(self.loss_list).mean(axis=0)
|
| 31 |
+
|
| 32 |
+
loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
|
| 33 |
+
loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
|
| 34 |
+
|
| 35 |
+
print(loss_string, file=self.log_file)
|
| 36 |
+
self.loss_list = []
|
| 37 |
+
self.log_file.flush()
|
| 38 |
+
|
| 39 |
+
def visualize_rec(self, inp, out):
|
| 40 |
+
image = self.visualizer.visualize(inp['driving'], inp['source'], out)
|
| 41 |
+
imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
|
| 42 |
+
|
| 43 |
+
def save_cpk(self, emergent=False):
|
| 44 |
+
cpk = {k: v.state_dict() for k, v in self.models.items()}
|
| 45 |
+
cpk['epoch'] = self.epoch
|
| 46 |
+
cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
|
| 47 |
+
if not (os.path.exists(cpk_path) and emergent):
|
| 48 |
+
torch.save(cpk, cpk_path)
|
| 49 |
+
|
| 50 |
+
@staticmethod
|
| 51 |
+
def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network =None, kp_detector=None,
|
| 52 |
+
bg_predictor=None, avd_network=None, optimizer=None, optimizer_bg_predictor=None,
|
| 53 |
+
optimizer_avd=None):
|
| 54 |
+
checkpoint = torch.load(checkpoint_path)
|
| 55 |
+
if inpainting_network is not None:
|
| 56 |
+
inpainting_network.load_state_dict(checkpoint['inpainting_network'])
|
| 57 |
+
if kp_detector is not None:
|
| 58 |
+
kp_detector.load_state_dict(checkpoint['kp_detector'])
|
| 59 |
+
if bg_predictor is not None and 'bg_predictor' in checkpoint:
|
| 60 |
+
bg_predictor.load_state_dict(checkpoint['bg_predictor'])
|
| 61 |
+
if dense_motion_network is not None:
|
| 62 |
+
dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
|
| 63 |
+
if avd_network is not None:
|
| 64 |
+
if 'avd_network' in checkpoint:
|
| 65 |
+
avd_network.load_state_dict(checkpoint['avd_network'])
|
| 66 |
+
if optimizer_bg_predictor is not None and 'optimizer_bg_predictor' in checkpoint:
|
| 67 |
+
optimizer_bg_predictor.load_state_dict(checkpoint['optimizer_bg_predictor'])
|
| 68 |
+
if optimizer is not None and 'optimizer' in checkpoint:
|
| 69 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 70 |
+
if optimizer_avd is not None:
|
| 71 |
+
if 'optimizer_avd' in checkpoint:
|
| 72 |
+
optimizer_avd.load_state_dict(checkpoint['optimizer_avd'])
|
| 73 |
+
epoch = -1
|
| 74 |
+
if 'epoch' in checkpoint:
|
| 75 |
+
epoch = checkpoint['epoch']
|
| 76 |
+
return epoch
|
| 77 |
+
|
| 78 |
+
def __enter__(self):
|
| 79 |
+
return self
|
| 80 |
+
|
| 81 |
+
def __exit__(self, exc_type, exc_value, tb):
|
| 82 |
+
if 'models' in self.__dict__:
|
| 83 |
+
self.save_cpk()
|
| 84 |
+
self.log_file.close()
|
| 85 |
+
|
| 86 |
+
def log_iter(self, losses):
|
| 87 |
+
losses = collections.OrderedDict(losses.items())
|
| 88 |
+
self.names = list(losses.keys())
|
| 89 |
+
self.loss_list.append(list(losses.values()))
|
| 90 |
+
|
| 91 |
+
def log_epoch(self, epoch, models, inp, out):
|
| 92 |
+
self.epoch = epoch
|
| 93 |
+
self.models = models
|
| 94 |
+
if (self.epoch + 1) % self.checkpoint_freq == 0:
|
| 95 |
+
self.save_cpk()
|
| 96 |
+
self.log_scores(self.names)
|
| 97 |
+
self.visualize_rec(inp, out)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class Visualizer:
|
| 101 |
+
def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
|
| 102 |
+
self.kp_size = kp_size
|
| 103 |
+
self.draw_border = draw_border
|
| 104 |
+
self.colormap = plt.get_cmap(colormap)
|
| 105 |
+
|
| 106 |
+
def draw_image_with_kp(self, image, kp_array):
|
| 107 |
+
image = np.copy(image)
|
| 108 |
+
spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
|
| 109 |
+
kp_array = spatial_size * (kp_array + 1) / 2
|
| 110 |
+
num_kp = kp_array.shape[0]
|
| 111 |
+
for kp_ind, kp in enumerate(kp_array):
|
| 112 |
+
rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
|
| 113 |
+
image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
|
| 114 |
+
return image
|
| 115 |
+
|
| 116 |
+
def create_image_column_with_kp(self, images, kp):
|
| 117 |
+
image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
|
| 118 |
+
return self.create_image_column(image_array)
|
| 119 |
+
|
| 120 |
+
def create_image_column(self, images):
|
| 121 |
+
if self.draw_border:
|
| 122 |
+
images = np.copy(images)
|
| 123 |
+
images[:, :, [0, -1]] = (1, 1, 1)
|
| 124 |
+
images[:, :, [0, -1]] = (1, 1, 1)
|
| 125 |
+
return np.concatenate(list(images), axis=0)
|
| 126 |
+
|
| 127 |
+
def create_image_grid(self, *args):
|
| 128 |
+
out = []
|
| 129 |
+
for arg in args:
|
| 130 |
+
if type(arg) == tuple:
|
| 131 |
+
out.append(self.create_image_column_with_kp(arg[0], arg[1]))
|
| 132 |
+
else:
|
| 133 |
+
out.append(self.create_image_column(arg))
|
| 134 |
+
return np.concatenate(out, axis=1)
|
| 135 |
+
|
| 136 |
+
def visualize(self, driving, source, out):
|
| 137 |
+
images = []
|
| 138 |
+
|
| 139 |
+
# Source image with keypoints
|
| 140 |
+
source = source.data.cpu()
|
| 141 |
+
kp_source = out['kp_source']['fg_kp'].data.cpu().numpy()
|
| 142 |
+
source = np.transpose(source, [0, 2, 3, 1])
|
| 143 |
+
images.append((source, kp_source))
|
| 144 |
+
|
| 145 |
+
# Equivariance visualization
|
| 146 |
+
if 'transformed_frame' in out:
|
| 147 |
+
transformed = out['transformed_frame'].data.cpu().numpy()
|
| 148 |
+
transformed = np.transpose(transformed, [0, 2, 3, 1])
|
| 149 |
+
transformed_kp = out['transformed_kp']['fg_kp'].data.cpu().numpy()
|
| 150 |
+
images.append((transformed, transformed_kp))
|
| 151 |
+
|
| 152 |
+
# Driving image with keypoints
|
| 153 |
+
kp_driving = out['kp_driving']['fg_kp'].data.cpu().numpy()
|
| 154 |
+
driving = driving.data.cpu().numpy()
|
| 155 |
+
driving = np.transpose(driving, [0, 2, 3, 1])
|
| 156 |
+
images.append((driving, kp_driving))
|
| 157 |
+
|
| 158 |
+
# Deformed image
|
| 159 |
+
if 'deformed' in out:
|
| 160 |
+
deformed = out['deformed'].data.cpu().numpy()
|
| 161 |
+
deformed = np.transpose(deformed, [0, 2, 3, 1])
|
| 162 |
+
images.append(deformed)
|
| 163 |
+
|
| 164 |
+
# Result with and without keypoints
|
| 165 |
+
prediction = out['prediction'].data.cpu().numpy()
|
| 166 |
+
prediction = np.transpose(prediction, [0, 2, 3, 1])
|
| 167 |
+
if 'kp_norm' in out:
|
| 168 |
+
kp_norm = out['kp_norm']['fg_kp'].data.cpu().numpy()
|
| 169 |
+
images.append((prediction, kp_norm))
|
| 170 |
+
images.append(prediction)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
## Occlusion map
|
| 174 |
+
if 'occlusion_map' in out:
|
| 175 |
+
for i in range(len(out['occlusion_map'])):
|
| 176 |
+
occlusion_map = out['occlusion_map'][i].data.cpu().repeat(1, 3, 1, 1)
|
| 177 |
+
occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
|
| 178 |
+
occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
|
| 179 |
+
images.append(occlusion_map)
|
| 180 |
+
|
| 181 |
+
# Deformed images according to each individual transform
|
| 182 |
+
if 'deformed_source' in out:
|
| 183 |
+
full_mask = []
|
| 184 |
+
for i in range(out['deformed_source'].shape[1]):
|
| 185 |
+
image = out['deformed_source'][:, i].data.cpu()
|
| 186 |
+
# import ipdb;ipdb.set_trace()
|
| 187 |
+
image = F.interpolate(image, size=source.shape[1:3])
|
| 188 |
+
mask = out['contribution_maps'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
|
| 189 |
+
mask = F.interpolate(mask, size=source.shape[1:3])
|
| 190 |
+
image = np.transpose(image.numpy(), (0, 2, 3, 1))
|
| 191 |
+
mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
|
| 192 |
+
|
| 193 |
+
if i != 0:
|
| 194 |
+
color = np.array(self.colormap((i - 1) / (out['deformed_source'].shape[1] - 1)))[:3]
|
| 195 |
+
else:
|
| 196 |
+
color = np.array((0, 0, 0))
|
| 197 |
+
|
| 198 |
+
color = color.reshape((1, 1, 1, 3))
|
| 199 |
+
|
| 200 |
+
images.append(image)
|
| 201 |
+
if i != 0:
|
| 202 |
+
images.append(mask * color)
|
| 203 |
+
else:
|
| 204 |
+
images.append(mask)
|
| 205 |
+
|
| 206 |
+
full_mask.append(mask * color)
|
| 207 |
+
|
| 208 |
+
images.append(sum(full_mask))
|
| 209 |
+
|
| 210 |
+
image = self.create_image_grid(*images)
|
| 211 |
+
image = (255 * image).astype(np.uint8)
|
| 212 |
+
return image
|
TPSMM/modules/avd_network.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class AVDNetwork(nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
Animation via Disentanglement network
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, num_tps, id_bottle_size=64, pose_bottle_size=64):
|
| 12 |
+
super(AVDNetwork, self).__init__()
|
| 13 |
+
input_size = 5*2 * num_tps
|
| 14 |
+
self.num_tps = num_tps
|
| 15 |
+
|
| 16 |
+
self.id_encoder = nn.Sequential(
|
| 17 |
+
nn.Linear(input_size, 256),
|
| 18 |
+
nn.BatchNorm1d(256),
|
| 19 |
+
nn.ReLU(inplace=True),
|
| 20 |
+
nn.Linear(256, 512),
|
| 21 |
+
nn.BatchNorm1d(512),
|
| 22 |
+
nn.ReLU(inplace=True),
|
| 23 |
+
nn.Linear(512, 1024),
|
| 24 |
+
nn.BatchNorm1d(1024),
|
| 25 |
+
nn.ReLU(inplace=True),
|
| 26 |
+
nn.Linear(1024, id_bottle_size)
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
self.pose_encoder = nn.Sequential(
|
| 30 |
+
nn.Linear(input_size, 256),
|
| 31 |
+
nn.BatchNorm1d(256),
|
| 32 |
+
nn.ReLU(inplace=True),
|
| 33 |
+
nn.Linear(256, 512),
|
| 34 |
+
nn.BatchNorm1d(512),
|
| 35 |
+
nn.ReLU(inplace=True),
|
| 36 |
+
nn.Linear(512, 1024),
|
| 37 |
+
nn.BatchNorm1d(1024),
|
| 38 |
+
nn.ReLU(inplace=True),
|
| 39 |
+
nn.Linear(1024, pose_bottle_size)
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
self.decoder = nn.Sequential(
|
| 43 |
+
nn.Linear(pose_bottle_size + id_bottle_size, 1024),
|
| 44 |
+
nn.BatchNorm1d(1024),
|
| 45 |
+
nn.ReLU(),
|
| 46 |
+
nn.Linear(1024, 512),
|
| 47 |
+
nn.BatchNorm1d(512),
|
| 48 |
+
nn.ReLU(),
|
| 49 |
+
nn.Linear(512, 256),
|
| 50 |
+
nn.BatchNorm1d(256),
|
| 51 |
+
nn.ReLU(),
|
| 52 |
+
nn.Linear(256, input_size)
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def forward(self, kp_source, kp_random):
|
| 56 |
+
|
| 57 |
+
bs = kp_source['fg_kp'].shape[0]
|
| 58 |
+
|
| 59 |
+
pose_emb = self.pose_encoder(kp_random['fg_kp'].view(bs, -1))
|
| 60 |
+
id_emb = self.id_encoder(kp_source['fg_kp'].view(bs, -1))
|
| 61 |
+
|
| 62 |
+
rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1))
|
| 63 |
+
|
| 64 |
+
rec = {'fg_kp': rec.view(bs, self.num_tps*5, -1)}
|
| 65 |
+
return rec
|
TPSMM/modules/bg_motion_predictor.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision import models
|
| 4 |
+
|
| 5 |
+
class BGMotionPredictor(nn.Module):
|
| 6 |
+
"""
|
| 7 |
+
Module for background estimation, return single transformation, parametrized as 3x3 matrix. The third row is [0 0 1]
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
def __init__(self):
|
| 11 |
+
super(BGMotionPredictor, self).__init__()
|
| 12 |
+
self.bg_encoder = models.resnet18(pretrained=False)
|
| 13 |
+
self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
| 14 |
+
num_features = self.bg_encoder.fc.in_features
|
| 15 |
+
self.bg_encoder.fc = nn.Linear(num_features, 6)
|
| 16 |
+
self.bg_encoder.fc.weight.data.zero_()
|
| 17 |
+
self.bg_encoder.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
|
| 18 |
+
|
| 19 |
+
def forward(self, source_image, driving_image):
|
| 20 |
+
bs = source_image.shape[0]
|
| 21 |
+
out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type())
|
| 22 |
+
prediction = self.bg_encoder(torch.cat([source_image, driving_image], dim=1))
|
| 23 |
+
out[:, :2, :] = prediction.view(bs, 2, 3)
|
| 24 |
+
return out
|
TPSMM/modules/dense_motion.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch
|
| 4 |
+
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
|
| 5 |
+
from modules.util import to_homogeneous, from_homogeneous, UpBlock2d, TPS
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
class DenseMotionNetwork(nn.Module):
|
| 9 |
+
"""
|
| 10 |
+
Module that estimating an optical flow and multi-resolution occlusion masks
|
| 11 |
+
from K TPS transformations and an affine transformation.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_channels,
|
| 15 |
+
scale_factor=0.25, bg = False, multi_mask = True, kp_variance=0.01):
|
| 16 |
+
super(DenseMotionNetwork, self).__init__()
|
| 17 |
+
|
| 18 |
+
if scale_factor != 1:
|
| 19 |
+
self.down = AntiAliasInterpolation2d(num_channels, scale_factor)
|
| 20 |
+
self.scale_factor = scale_factor
|
| 21 |
+
self.multi_mask = multi_mask
|
| 22 |
+
|
| 23 |
+
self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_channels * (num_tps+1) + num_tps*5+1),
|
| 24 |
+
max_features=max_features, num_blocks=num_blocks)
|
| 25 |
+
|
| 26 |
+
hourglass_output_size = self.hourglass.out_channels
|
| 27 |
+
self.maps = nn.Conv2d(hourglass_output_size[-1], num_tps + 1, kernel_size=(7, 7), padding=(3, 3))
|
| 28 |
+
|
| 29 |
+
if multi_mask:
|
| 30 |
+
up = []
|
| 31 |
+
self.up_nums = int(math.log(1/scale_factor, 2))
|
| 32 |
+
self.occlusion_num = 4
|
| 33 |
+
|
| 34 |
+
channel = [hourglass_output_size[-1]//(2**i) for i in range(self.up_nums)]
|
| 35 |
+
for i in range(self.up_nums):
|
| 36 |
+
up.append(UpBlock2d(channel[i], channel[i]//2, kernel_size=3, padding=1))
|
| 37 |
+
self.up = nn.ModuleList(up)
|
| 38 |
+
|
| 39 |
+
channel = [hourglass_output_size[-i-1] for i in range(self.occlusion_num-self.up_nums)[::-1]]
|
| 40 |
+
for i in range(self.up_nums):
|
| 41 |
+
channel.append(hourglass_output_size[-1]//(2**(i+1)))
|
| 42 |
+
occlusion = []
|
| 43 |
+
|
| 44 |
+
for i in range(self.occlusion_num):
|
| 45 |
+
occlusion.append(nn.Conv2d(channel[i], 1, kernel_size=(7, 7), padding=(3, 3)))
|
| 46 |
+
self.occlusion = nn.ModuleList(occlusion)
|
| 47 |
+
else:
|
| 48 |
+
occlusion = [nn.Conv2d(hourglass_output_size[-1], 1, kernel_size=(7, 7), padding=(3, 3))]
|
| 49 |
+
self.occlusion = nn.ModuleList(occlusion)
|
| 50 |
+
|
| 51 |
+
self.num_tps = num_tps
|
| 52 |
+
self.bg = bg
|
| 53 |
+
self.kp_variance = kp_variance
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def create_heatmap_representations(self, source_image, kp_driving, kp_source):
|
| 57 |
+
|
| 58 |
+
spatial_size = source_image.shape[2:]
|
| 59 |
+
gaussian_driving = kp2gaussian(kp_driving['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
|
| 60 |
+
gaussian_source = kp2gaussian(kp_source['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
|
| 61 |
+
heatmap = gaussian_driving - gaussian_source
|
| 62 |
+
|
| 63 |
+
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()).to(heatmap.device)
|
| 64 |
+
heatmap = torch.cat([zeros, heatmap], dim=1)
|
| 65 |
+
|
| 66 |
+
return heatmap
|
| 67 |
+
|
| 68 |
+
def create_transformations(self, source_image, kp_driving, kp_source, bg_param):
|
| 69 |
+
# K TPS transformaions
|
| 70 |
+
bs, _, h, w = source_image.shape
|
| 71 |
+
kp_1 = kp_driving['fg_kp']
|
| 72 |
+
kp_2 = kp_source['fg_kp']
|
| 73 |
+
kp_1 = kp_1.view(bs, -1, 5, 2)
|
| 74 |
+
kp_2 = kp_2.view(bs, -1, 5, 2)
|
| 75 |
+
trans = TPS(mode = 'kp', bs = bs, kp_1 = kp_1, kp_2 = kp_2)
|
| 76 |
+
driving_to_source = trans.transform_frame(source_image)
|
| 77 |
+
|
| 78 |
+
identity_grid = make_coordinate_grid((h, w), type=kp_1.type()).to(kp_1.device)
|
| 79 |
+
identity_grid = identity_grid.view(1, 1, h, w, 2)
|
| 80 |
+
identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
|
| 81 |
+
|
| 82 |
+
# affine background transformation
|
| 83 |
+
if not (bg_param is None):
|
| 84 |
+
identity_grid = to_homogeneous(identity_grid)
|
| 85 |
+
identity_grid = torch.matmul(bg_param.view(bs, 1, 1, 1, 3, 3), identity_grid.unsqueeze(-1)).squeeze(-1)
|
| 86 |
+
identity_grid = from_homogeneous(identity_grid)
|
| 87 |
+
|
| 88 |
+
transformations = torch.cat([identity_grid, driving_to_source], dim=1)
|
| 89 |
+
return transformations
|
| 90 |
+
|
| 91 |
+
def create_deformed_source_image(self, source_image, transformations):
|
| 92 |
+
|
| 93 |
+
bs, _, h, w = source_image.shape
|
| 94 |
+
source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_tps + 1, 1, 1, 1, 1)
|
| 95 |
+
source_repeat = source_repeat.view(bs * (self.num_tps + 1), -1, h, w)
|
| 96 |
+
transformations = transformations.view((bs * (self.num_tps + 1), h, w, -1))
|
| 97 |
+
deformed = F.grid_sample(source_repeat, transformations, align_corners=True)
|
| 98 |
+
deformed = deformed.view((bs, self.num_tps+1, -1, h, w))
|
| 99 |
+
return deformed
|
| 100 |
+
|
| 101 |
+
def dropout_softmax(self, X, P):
|
| 102 |
+
'''
|
| 103 |
+
Dropout for TPS transformations. Eq(7) and Eq(8) in the paper.
|
| 104 |
+
'''
|
| 105 |
+
drop = (torch.rand(X.shape[0],X.shape[1]) < (1-P)).type(X.type()).to(X.device)
|
| 106 |
+
drop[..., 0] = 1
|
| 107 |
+
drop = drop.repeat(X.shape[2],X.shape[3],1,1).permute(2,3,0,1)
|
| 108 |
+
|
| 109 |
+
maxx = X.max(1).values.unsqueeze_(1)
|
| 110 |
+
X = X - maxx
|
| 111 |
+
X_exp = X.exp()
|
| 112 |
+
X[:,1:,...] /= (1-P)
|
| 113 |
+
mask_bool =(drop == 0)
|
| 114 |
+
X_exp = X_exp.masked_fill(mask_bool, 0)
|
| 115 |
+
partition = X_exp.sum(dim=1, keepdim=True) + 1e-6
|
| 116 |
+
return X_exp / partition
|
| 117 |
+
|
| 118 |
+
def forward(self, source_image, kp_driving, kp_source, bg_param = None, dropout_flag=False, dropout_p = 0):
|
| 119 |
+
if self.scale_factor != 1:
|
| 120 |
+
source_image = self.down(source_image)
|
| 121 |
+
|
| 122 |
+
bs, _, h, w = source_image.shape
|
| 123 |
+
|
| 124 |
+
out_dict = dict()
|
| 125 |
+
heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
|
| 126 |
+
transformations = self.create_transformations(source_image, kp_driving, kp_source, bg_param)
|
| 127 |
+
deformed_source = self.create_deformed_source_image(source_image, transformations)
|
| 128 |
+
out_dict['deformed_source'] = deformed_source
|
| 129 |
+
# out_dict['transformations'] = transformations
|
| 130 |
+
deformed_source = deformed_source.view(bs,-1,h,w)
|
| 131 |
+
input = torch.cat([heatmap_representation, deformed_source], dim=1)
|
| 132 |
+
input = input.view(bs, -1, h, w)
|
| 133 |
+
|
| 134 |
+
prediction = self.hourglass(input, mode = 1)
|
| 135 |
+
|
| 136 |
+
contribution_maps = self.maps(prediction[-1])
|
| 137 |
+
if(dropout_flag):
|
| 138 |
+
contribution_maps = self.dropout_softmax(contribution_maps, dropout_p)
|
| 139 |
+
else:
|
| 140 |
+
contribution_maps = F.softmax(contribution_maps, dim=1)
|
| 141 |
+
out_dict['contribution_maps'] = contribution_maps
|
| 142 |
+
|
| 143 |
+
# Combine the K+1 transformations
|
| 144 |
+
# Eq(6) in the paper
|
| 145 |
+
contribution_maps = contribution_maps.unsqueeze(2)
|
| 146 |
+
transformations = transformations.permute(0, 1, 4, 2, 3)
|
| 147 |
+
deformation = (transformations * contribution_maps).sum(dim=1)
|
| 148 |
+
deformation = deformation.permute(0, 2, 3, 1)
|
| 149 |
+
|
| 150 |
+
out_dict['deformation'] = deformation # Optical Flow
|
| 151 |
+
|
| 152 |
+
occlusion_map = []
|
| 153 |
+
if self.multi_mask:
|
| 154 |
+
for i in range(self.occlusion_num-self.up_nums):
|
| 155 |
+
occlusion_map.append(torch.sigmoid(self.occlusion[i](prediction[self.up_nums-self.occlusion_num+i])))
|
| 156 |
+
prediction = prediction[-1]
|
| 157 |
+
for i in range(self.up_nums):
|
| 158 |
+
prediction = self.up[i](prediction)
|
| 159 |
+
occlusion_map.append(torch.sigmoid(self.occlusion[i+self.occlusion_num-self.up_nums](prediction)))
|
| 160 |
+
else:
|
| 161 |
+
occlusion_map.append(torch.sigmoid(self.occlusion[0](prediction[-1])))
|
| 162 |
+
|
| 163 |
+
out_dict['occlusion_map'] = occlusion_map # Multi-resolution Occlusion Masks
|
| 164 |
+
return out_dict
|
TPSMM/modules/inpainting_network.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
|
| 5 |
+
from modules.dense_motion import DenseMotionNetwork
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class InpaintingNetwork(nn.Module):
|
| 9 |
+
"""
|
| 10 |
+
Inpaint the missing regions and reconstruct the Driving image.
|
| 11 |
+
"""
|
| 12 |
+
def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, multi_mask = True, **kwargs):
|
| 13 |
+
super(InpaintingNetwork, self).__init__()
|
| 14 |
+
|
| 15 |
+
self.num_down_blocks = num_down_blocks
|
| 16 |
+
self.multi_mask = multi_mask
|
| 17 |
+
self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
|
| 18 |
+
|
| 19 |
+
down_blocks = []
|
| 20 |
+
up_blocks = []
|
| 21 |
+
resblock = []
|
| 22 |
+
for i in range(num_down_blocks):
|
| 23 |
+
in_features = min(max_features, block_expansion * (2 ** i))
|
| 24 |
+
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
| 25 |
+
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
| 26 |
+
decoder_in_feature = out_features * 2
|
| 27 |
+
if i==num_down_blocks-1:
|
| 28 |
+
decoder_in_feature = out_features
|
| 29 |
+
up_blocks.append(UpBlock2d(decoder_in_feature, in_features, kernel_size=(3, 3), padding=(1, 1)))
|
| 30 |
+
resblock.append(ResBlock2d(decoder_in_feature, kernel_size=(3, 3), padding=(1, 1)))
|
| 31 |
+
resblock.append(ResBlock2d(decoder_in_feature, kernel_size=(3, 3), padding=(1, 1)))
|
| 32 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
| 33 |
+
self.up_blocks = nn.ModuleList(up_blocks[::-1])
|
| 34 |
+
self.resblock = nn.ModuleList(resblock[::-1])
|
| 35 |
+
|
| 36 |
+
self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
|
| 37 |
+
self.num_channels = num_channels
|
| 38 |
+
|
| 39 |
+
def deform_input(self, inp, deformation):
|
| 40 |
+
_, h_old, w_old, _ = deformation.shape
|
| 41 |
+
_, _, h, w = inp.shape
|
| 42 |
+
if h_old != h or w_old != w:
|
| 43 |
+
deformation = deformation.permute(0, 3, 1, 2)
|
| 44 |
+
deformation = F.interpolate(deformation, size=(h, w), mode='bilinear', align_corners=True)
|
| 45 |
+
deformation = deformation.permute(0, 2, 3, 1)
|
| 46 |
+
return F.grid_sample(inp, deformation,align_corners=True)
|
| 47 |
+
|
| 48 |
+
def occlude_input(self, inp, occlusion_map):
|
| 49 |
+
if not self.multi_mask:
|
| 50 |
+
if inp.shape[2] != occlusion_map.shape[2] or inp.shape[3] != occlusion_map.shape[3]:
|
| 51 |
+
occlusion_map = F.interpolate(occlusion_map, size=inp.shape[2:], mode='bilinear',align_corners=True)
|
| 52 |
+
out = inp * occlusion_map
|
| 53 |
+
return out
|
| 54 |
+
|
| 55 |
+
def forward(self, source_image, dense_motion):
|
| 56 |
+
out = self.first(source_image)
|
| 57 |
+
encoder_map = [out]
|
| 58 |
+
for i in range(len(self.down_blocks)):
|
| 59 |
+
out = self.down_blocks[i](out)
|
| 60 |
+
encoder_map.append(out)
|
| 61 |
+
|
| 62 |
+
output_dict = {}
|
| 63 |
+
output_dict['contribution_maps'] = dense_motion['contribution_maps']
|
| 64 |
+
output_dict['deformed_source'] = dense_motion['deformed_source']
|
| 65 |
+
|
| 66 |
+
occlusion_map = dense_motion['occlusion_map']
|
| 67 |
+
output_dict['occlusion_map'] = occlusion_map
|
| 68 |
+
|
| 69 |
+
deformation = dense_motion['deformation']
|
| 70 |
+
out_ij = self.deform_input(out.detach(), deformation)
|
| 71 |
+
out = self.deform_input(out, deformation)
|
| 72 |
+
|
| 73 |
+
out_ij = self.occlude_input(out_ij, occlusion_map[0].detach())
|
| 74 |
+
out = self.occlude_input(out, occlusion_map[0])
|
| 75 |
+
|
| 76 |
+
warped_encoder_maps = []
|
| 77 |
+
warped_encoder_maps.append(out_ij)
|
| 78 |
+
|
| 79 |
+
for i in range(self.num_down_blocks):
|
| 80 |
+
|
| 81 |
+
out = self.resblock[2*i](out)
|
| 82 |
+
out = self.resblock[2*i+1](out)
|
| 83 |
+
out = self.up_blocks[i](out)
|
| 84 |
+
|
| 85 |
+
encode_i = encoder_map[-(i+2)]
|
| 86 |
+
encode_ij = self.deform_input(encode_i.detach(), deformation)
|
| 87 |
+
encode_i = self.deform_input(encode_i, deformation)
|
| 88 |
+
|
| 89 |
+
occlusion_ind = 0
|
| 90 |
+
if self.multi_mask:
|
| 91 |
+
occlusion_ind = i+1
|
| 92 |
+
encode_ij = self.occlude_input(encode_ij, occlusion_map[occlusion_ind].detach())
|
| 93 |
+
encode_i = self.occlude_input(encode_i, occlusion_map[occlusion_ind])
|
| 94 |
+
warped_encoder_maps.append(encode_ij)
|
| 95 |
+
|
| 96 |
+
if(i==self.num_down_blocks-1):
|
| 97 |
+
break
|
| 98 |
+
|
| 99 |
+
out = torch.cat([out, encode_i], 1)
|
| 100 |
+
|
| 101 |
+
deformed_source = self.deform_input(source_image, deformation)
|
| 102 |
+
output_dict["deformed"] = deformed_source
|
| 103 |
+
output_dict["warped_encoder_maps"] = warped_encoder_maps
|
| 104 |
+
|
| 105 |
+
occlusion_last = occlusion_map[-1]
|
| 106 |
+
if not self.multi_mask:
|
| 107 |
+
occlusion_last = F.interpolate(occlusion_last, size=out.shape[2:], mode='bilinear',align_corners=True)
|
| 108 |
+
|
| 109 |
+
out = out * (1 - occlusion_last) + encode_i
|
| 110 |
+
out = self.final(out)
|
| 111 |
+
out = torch.sigmoid(out)
|
| 112 |
+
out = out * (1 - occlusion_last) + deformed_source * occlusion_last
|
| 113 |
+
output_dict["prediction"] = out
|
| 114 |
+
|
| 115 |
+
return output_dict
|
| 116 |
+
|
| 117 |
+
def get_encode(self, driver_image, occlusion_map):
|
| 118 |
+
out = self.first(driver_image)
|
| 119 |
+
encoder_map = []
|
| 120 |
+
encoder_map.append(self.occlude_input(out.detach(), occlusion_map[-1].detach()))
|
| 121 |
+
for i in range(len(self.down_blocks)):
|
| 122 |
+
out = self.down_blocks[i](out.detach())
|
| 123 |
+
out_mask = self.occlude_input(out.detach(), occlusion_map[2-i].detach())
|
| 124 |
+
encoder_map.append(out_mask.detach())
|
| 125 |
+
|
| 126 |
+
return encoder_map
|
| 127 |
+
|
TPSMM/modules/keypoint_detector.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision import models
|
| 4 |
+
|
| 5 |
+
class KPDetector(nn.Module):
|
| 6 |
+
"""
|
| 7 |
+
Predict K*5 keypoints.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
def __init__(self, num_tps, **kwargs):
|
| 11 |
+
super(KPDetector, self).__init__()
|
| 12 |
+
self.num_tps = num_tps
|
| 13 |
+
|
| 14 |
+
self.fg_encoder = models.resnet18(pretrained=False)
|
| 15 |
+
num_features = self.fg_encoder.fc.in_features
|
| 16 |
+
self.fg_encoder.fc = nn.Linear(num_features, num_tps*5*2)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def forward(self, image):
|
| 20 |
+
|
| 21 |
+
fg_kp = self.fg_encoder(image)
|
| 22 |
+
bs, _, = fg_kp.shape
|
| 23 |
+
fg_kp = torch.sigmoid(fg_kp)
|
| 24 |
+
fg_kp = fg_kp * 2 - 1
|
| 25 |
+
out = {'fg_kp': fg_kp.view(bs, self.num_tps*5, -1)}
|
| 26 |
+
|
| 27 |
+
return out
|
TPSMM/modules/model.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from modules.util import AntiAliasInterpolation2d, TPS
|
| 5 |
+
from torchvision import models
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Vgg19(torch.nn.Module):
|
| 10 |
+
"""
|
| 11 |
+
Vgg19 network for perceptual loss. See Sec 3.3.
|
| 12 |
+
"""
|
| 13 |
+
def __init__(self, requires_grad=False):
|
| 14 |
+
super(Vgg19, self).__init__()
|
| 15 |
+
vgg_pretrained_features = models.vgg19(pretrained=True).features
|
| 16 |
+
self.slice1 = torch.nn.Sequential()
|
| 17 |
+
self.slice2 = torch.nn.Sequential()
|
| 18 |
+
self.slice3 = torch.nn.Sequential()
|
| 19 |
+
self.slice4 = torch.nn.Sequential()
|
| 20 |
+
self.slice5 = torch.nn.Sequential()
|
| 21 |
+
for x in range(2):
|
| 22 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
| 23 |
+
for x in range(2, 7):
|
| 24 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
| 25 |
+
for x in range(7, 12):
|
| 26 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
| 27 |
+
for x in range(12, 21):
|
| 28 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
| 29 |
+
for x in range(21, 30):
|
| 30 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
| 31 |
+
|
| 32 |
+
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
|
| 33 |
+
requires_grad=False)
|
| 34 |
+
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
|
| 35 |
+
requires_grad=False)
|
| 36 |
+
|
| 37 |
+
if not requires_grad:
|
| 38 |
+
for param in self.parameters():
|
| 39 |
+
param.requires_grad = False
|
| 40 |
+
|
| 41 |
+
def forward(self, X):
|
| 42 |
+
X = (X - self.mean) / self.std
|
| 43 |
+
h_relu1 = self.slice1(X)
|
| 44 |
+
h_relu2 = self.slice2(h_relu1)
|
| 45 |
+
h_relu3 = self.slice3(h_relu2)
|
| 46 |
+
h_relu4 = self.slice4(h_relu3)
|
| 47 |
+
h_relu5 = self.slice5(h_relu4)
|
| 48 |
+
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
| 49 |
+
return out
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ImagePyramide(torch.nn.Module):
|
| 53 |
+
"""
|
| 54 |
+
Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
|
| 55 |
+
"""
|
| 56 |
+
def __init__(self, scales, num_channels):
|
| 57 |
+
super(ImagePyramide, self).__init__()
|
| 58 |
+
downs = {}
|
| 59 |
+
for scale in scales:
|
| 60 |
+
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
|
| 61 |
+
self.downs = nn.ModuleDict(downs)
|
| 62 |
+
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
out_dict = {}
|
| 65 |
+
for scale, down_module in self.downs.items():
|
| 66 |
+
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
|
| 67 |
+
return out_dict
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def detach_kp(kp):
|
| 71 |
+
return {key: value.detach() for key, value in kp.items()}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class GeneratorFullModel(torch.nn.Module):
|
| 75 |
+
"""
|
| 76 |
+
Merge all generator related updates into single model for better multi-gpu usage
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, *kwargs):
|
| 80 |
+
super(GeneratorFullModel, self).__init__()
|
| 81 |
+
self.kp_extractor = kp_extractor
|
| 82 |
+
self.inpainting_network = inpainting_network
|
| 83 |
+
self.dense_motion_network = dense_motion_network
|
| 84 |
+
|
| 85 |
+
self.bg_predictor = None
|
| 86 |
+
if bg_predictor:
|
| 87 |
+
self.bg_predictor = bg_predictor
|
| 88 |
+
self.bg_start = train_params['bg_start']
|
| 89 |
+
|
| 90 |
+
self.train_params = train_params
|
| 91 |
+
self.scales = train_params['scales']
|
| 92 |
+
|
| 93 |
+
self.pyramid = ImagePyramide(self.scales, inpainting_network.num_channels)
|
| 94 |
+
if torch.cuda.is_available():
|
| 95 |
+
self.pyramid = self.pyramid.cuda()
|
| 96 |
+
|
| 97 |
+
self.loss_weights = train_params['loss_weights']
|
| 98 |
+
self.dropout_epoch = train_params['dropout_epoch']
|
| 99 |
+
self.dropout_maxp = train_params['dropout_maxp']
|
| 100 |
+
self.dropout_inc_epoch = train_params['dropout_inc_epoch']
|
| 101 |
+
self.dropout_startp =train_params['dropout_startp']
|
| 102 |
+
|
| 103 |
+
if sum(self.loss_weights['perceptual']) != 0:
|
| 104 |
+
self.vgg = Vgg19()
|
| 105 |
+
if torch.cuda.is_available():
|
| 106 |
+
self.vgg = self.vgg.cuda()
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def forward(self, x, epoch):
|
| 110 |
+
kp_source = self.kp_extractor(x['source'])
|
| 111 |
+
kp_driving = self.kp_extractor(x['driving'])
|
| 112 |
+
bg_param = None
|
| 113 |
+
if self.bg_predictor:
|
| 114 |
+
if(epoch>=self.bg_start):
|
| 115 |
+
bg_param = self.bg_predictor(x['source'], x['driving'])
|
| 116 |
+
|
| 117 |
+
if(epoch>=self.dropout_epoch):
|
| 118 |
+
dropout_flag = False
|
| 119 |
+
dropout_p = 0
|
| 120 |
+
else:
|
| 121 |
+
# dropout_p will linearly increase from dropout_startp to dropout_maxp
|
| 122 |
+
dropout_flag = True
|
| 123 |
+
dropout_p = min(epoch/self.dropout_inc_epoch * self.dropout_maxp + self.dropout_startp, self.dropout_maxp)
|
| 124 |
+
|
| 125 |
+
dense_motion = self.dense_motion_network(source_image=x['source'], kp_driving=kp_driving,
|
| 126 |
+
kp_source=kp_source, bg_param = bg_param,
|
| 127 |
+
dropout_flag = dropout_flag, dropout_p = dropout_p)
|
| 128 |
+
generated = self.inpainting_network(x['source'], dense_motion)
|
| 129 |
+
generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
|
| 130 |
+
|
| 131 |
+
loss_values = {}
|
| 132 |
+
|
| 133 |
+
pyramide_real = self.pyramid(x['driving'])
|
| 134 |
+
pyramide_generated = self.pyramid(generated['prediction'])
|
| 135 |
+
|
| 136 |
+
# reconstruction loss
|
| 137 |
+
if sum(self.loss_weights['perceptual']) != 0:
|
| 138 |
+
value_total = 0
|
| 139 |
+
for scale in self.scales:
|
| 140 |
+
x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
|
| 141 |
+
y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
|
| 142 |
+
|
| 143 |
+
for i, weight in enumerate(self.loss_weights['perceptual']):
|
| 144 |
+
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
|
| 145 |
+
value_total += self.loss_weights['perceptual'][i] * value
|
| 146 |
+
loss_values['perceptual'] = value_total
|
| 147 |
+
|
| 148 |
+
# equivariance loss
|
| 149 |
+
if self.loss_weights['equivariance_value'] != 0:
|
| 150 |
+
transform_random = TPS(mode = 'random', bs = x['driving'].shape[0], **self.train_params['transform_params'])
|
| 151 |
+
transform_grid = transform_random.transform_frame(x['driving'])
|
| 152 |
+
transformed_frame = F.grid_sample(x['driving'], transform_grid, padding_mode="reflection",align_corners=True)
|
| 153 |
+
transformed_kp = self.kp_extractor(transformed_frame)
|
| 154 |
+
|
| 155 |
+
generated['transformed_frame'] = transformed_frame
|
| 156 |
+
generated['transformed_kp'] = transformed_kp
|
| 157 |
+
|
| 158 |
+
warped = transform_random.warp_coordinates(transformed_kp['fg_kp'])
|
| 159 |
+
kp_d = kp_driving['fg_kp']
|
| 160 |
+
value = torch.abs(kp_d - warped).mean()
|
| 161 |
+
loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
|
| 162 |
+
|
| 163 |
+
# warp loss
|
| 164 |
+
if self.loss_weights['warp_loss'] != 0:
|
| 165 |
+
occlusion_map = generated['occlusion_map']
|
| 166 |
+
encode_map = self.inpainting_network.get_encode(x['driving'], occlusion_map)
|
| 167 |
+
decode_map = generated['warped_encoder_maps']
|
| 168 |
+
value = 0
|
| 169 |
+
for i in range(len(encode_map)):
|
| 170 |
+
value += torch.abs(encode_map[i]-decode_map[-i-1]).mean()
|
| 171 |
+
|
| 172 |
+
loss_values['warp_loss'] = self.loss_weights['warp_loss'] * value
|
| 173 |
+
|
| 174 |
+
# bg loss
|
| 175 |
+
if self.bg_predictor and epoch >= self.bg_start and self.loss_weights['bg'] != 0:
|
| 176 |
+
bg_param_reverse = self.bg_predictor(x['driving'], x['source'])
|
| 177 |
+
value = torch.matmul(bg_param, bg_param_reverse)
|
| 178 |
+
eye = torch.eye(3).view(1, 1, 3, 3).type(value.type())
|
| 179 |
+
value = torch.abs(eye - value).mean()
|
| 180 |
+
loss_values['bg'] = self.loss_weights['bg'] * value
|
| 181 |
+
|
| 182 |
+
return loss_values, generated
|
TPSMM/modules/util.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TPS:
|
| 7 |
+
'''
|
| 8 |
+
TPS transformation, mode 'kp' for Eq(2) in the paper, mode 'random' for equivariance loss.
|
| 9 |
+
'''
|
| 10 |
+
def __init__(self, mode, bs, **kwargs):
|
| 11 |
+
self.bs = bs
|
| 12 |
+
self.mode = mode
|
| 13 |
+
if mode == 'random':
|
| 14 |
+
noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
|
| 15 |
+
self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
|
| 16 |
+
self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
|
| 17 |
+
self.control_points = self.control_points.unsqueeze(0)
|
| 18 |
+
self.control_params = torch.normal(mean=0,
|
| 19 |
+
std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
|
| 20 |
+
elif mode == 'kp':
|
| 21 |
+
kp_1 = kwargs["kp_1"]
|
| 22 |
+
kp_2 = kwargs["kp_2"]
|
| 23 |
+
device = kp_1.device
|
| 24 |
+
kp_type = kp_1.type()
|
| 25 |
+
self.gs = kp_1.shape[1]
|
| 26 |
+
n = kp_1.shape[2]
|
| 27 |
+
K = torch.norm(kp_1[:,:,:, None]-kp_1[:,:, None, :], dim=4, p=2)
|
| 28 |
+
K = K**2
|
| 29 |
+
K = K * torch.log(K+1e-9)
|
| 30 |
+
|
| 31 |
+
one1 = torch.ones(self.bs, kp_1.shape[1], kp_1.shape[2], 1).to(device).type(kp_type)
|
| 32 |
+
kp_1p = torch.cat([kp_1,one1], 3)
|
| 33 |
+
|
| 34 |
+
zero = torch.zeros(self.bs, kp_1.shape[1], 3, 3).to(device).type(kp_type)
|
| 35 |
+
P = torch.cat([kp_1p, zero],2)
|
| 36 |
+
L = torch.cat([K,kp_1p.permute(0,1,3,2)],2)
|
| 37 |
+
L = torch.cat([L,P],3)
|
| 38 |
+
|
| 39 |
+
zero = torch.zeros(self.bs, kp_1.shape[1], 3, 2).to(device).type(kp_type)
|
| 40 |
+
Y = torch.cat([kp_2, zero], 2)
|
| 41 |
+
one = torch.eye(L.shape[2]).expand(L.shape).to(device).type(kp_type)*0.01
|
| 42 |
+
L = L + one
|
| 43 |
+
|
| 44 |
+
param = torch.matmul(torch.inverse(L),Y)
|
| 45 |
+
self.theta = param[:,:,n:,:].permute(0,1,3,2)
|
| 46 |
+
|
| 47 |
+
self.control_points = kp_1
|
| 48 |
+
self.control_params = param[:,:,:n,:]
|
| 49 |
+
else:
|
| 50 |
+
raise Exception("Error TPS mode")
|
| 51 |
+
|
| 52 |
+
def transform_frame(self, frame):
|
| 53 |
+
grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0).to(frame.device)
|
| 54 |
+
grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
|
| 55 |
+
shape = [self.bs, frame.shape[2], frame.shape[3], 2]
|
| 56 |
+
if self.mode == 'kp':
|
| 57 |
+
shape.insert(1, self.gs)
|
| 58 |
+
grid = self.warp_coordinates(grid).view(*shape)
|
| 59 |
+
return grid
|
| 60 |
+
|
| 61 |
+
def warp_coordinates(self, coordinates):
|
| 62 |
+
theta = self.theta.type(coordinates.type()).to(coordinates.device)
|
| 63 |
+
control_points = self.control_points.type(coordinates.type()).to(coordinates.device)
|
| 64 |
+
control_params = self.control_params.type(coordinates.type()).to(coordinates.device)
|
| 65 |
+
|
| 66 |
+
if self.mode == 'kp':
|
| 67 |
+
transformed = torch.matmul(theta[:, :, :, :2], coordinates.permute(0, 2, 1)) + theta[:, :, :, 2:]
|
| 68 |
+
|
| 69 |
+
distances = coordinates.view(coordinates.shape[0], 1, 1, -1, 2) - control_points.view(self.bs, control_points.shape[1], -1, 1, 2)
|
| 70 |
+
|
| 71 |
+
distances = distances ** 2
|
| 72 |
+
result = distances.sum(-1)
|
| 73 |
+
result = result * torch.log(result + 1e-9)
|
| 74 |
+
result = torch.matmul(result.permute(0, 1, 3, 2), control_params)
|
| 75 |
+
transformed = transformed.permute(0, 1, 3, 2) + result
|
| 76 |
+
|
| 77 |
+
elif self.mode == 'random':
|
| 78 |
+
theta = theta.unsqueeze(1)
|
| 79 |
+
transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
|
| 80 |
+
transformed = transformed.squeeze(-1)
|
| 81 |
+
ances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
|
| 82 |
+
distances = ances ** 2
|
| 83 |
+
|
| 84 |
+
result = distances.sum(-1)
|
| 85 |
+
result = result * torch.log(result + 1e-9)
|
| 86 |
+
result = result * control_params
|
| 87 |
+
result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
|
| 88 |
+
transformed = transformed + result
|
| 89 |
+
else:
|
| 90 |
+
raise Exception("Error TPS mode")
|
| 91 |
+
|
| 92 |
+
return transformed
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def kp2gaussian(kp, spatial_size, kp_variance):
|
| 96 |
+
"""
|
| 97 |
+
Transform a keypoint into gaussian like representation
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
coordinate_grid = make_coordinate_grid(spatial_size, kp.type()).to(kp.device)
|
| 101 |
+
number_of_leading_dimensions = len(kp.shape) - 1
|
| 102 |
+
shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
|
| 103 |
+
coordinate_grid = coordinate_grid.view(*shape)
|
| 104 |
+
repeats = kp.shape[:number_of_leading_dimensions] + (1, 1, 1)
|
| 105 |
+
coordinate_grid = coordinate_grid.repeat(*repeats)
|
| 106 |
+
|
| 107 |
+
# Preprocess kp shape
|
| 108 |
+
shape = kp.shape[:number_of_leading_dimensions] + (1, 1, 2)
|
| 109 |
+
kp = kp.view(*shape)
|
| 110 |
+
|
| 111 |
+
mean_sub = (coordinate_grid - kp)
|
| 112 |
+
|
| 113 |
+
out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
|
| 114 |
+
|
| 115 |
+
return out
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def make_coordinate_grid(spatial_size, type):
|
| 119 |
+
"""
|
| 120 |
+
Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
|
| 121 |
+
"""
|
| 122 |
+
h, w = spatial_size
|
| 123 |
+
x = torch.arange(w).type(type)
|
| 124 |
+
y = torch.arange(h).type(type)
|
| 125 |
+
|
| 126 |
+
x = (2 * (x / (w - 1)) - 1)
|
| 127 |
+
y = (2 * (y / (h - 1)) - 1)
|
| 128 |
+
|
| 129 |
+
yy = y.view(-1, 1).repeat(1, w)
|
| 130 |
+
xx = x.view(1, -1).repeat(h, 1)
|
| 131 |
+
|
| 132 |
+
meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
|
| 133 |
+
|
| 134 |
+
return meshed
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class ResBlock2d(nn.Module):
|
| 138 |
+
"""
|
| 139 |
+
Res block, preserve spatial resolution.
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
def __init__(self, in_features, kernel_size, padding):
|
| 143 |
+
super(ResBlock2d, self).__init__()
|
| 144 |
+
self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
|
| 145 |
+
padding=padding)
|
| 146 |
+
self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
|
| 147 |
+
padding=padding)
|
| 148 |
+
self.norm1 = nn.InstanceNorm2d(in_features, affine=True)
|
| 149 |
+
self.norm2 = nn.InstanceNorm2d(in_features, affine=True)
|
| 150 |
+
|
| 151 |
+
def forward(self, x):
|
| 152 |
+
out = self.norm1(x)
|
| 153 |
+
out = F.relu(out)
|
| 154 |
+
out = self.conv1(out)
|
| 155 |
+
out = self.norm2(out)
|
| 156 |
+
out = F.relu(out)
|
| 157 |
+
out = self.conv2(out)
|
| 158 |
+
out += x
|
| 159 |
+
return out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class UpBlock2d(nn.Module):
|
| 163 |
+
"""
|
| 164 |
+
Upsampling block for use in decoder.
|
| 165 |
+
"""
|
| 166 |
+
|
| 167 |
+
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
| 168 |
+
super(UpBlock2d, self).__init__()
|
| 169 |
+
|
| 170 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
| 171 |
+
padding=padding, groups=groups)
|
| 172 |
+
self.norm = nn.InstanceNorm2d(out_features, affine=True)
|
| 173 |
+
|
| 174 |
+
def forward(self, x):
|
| 175 |
+
out = F.interpolate(x, scale_factor=2)
|
| 176 |
+
out = self.conv(out)
|
| 177 |
+
out = self.norm(out)
|
| 178 |
+
out = F.relu(out)
|
| 179 |
+
return out
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class DownBlock2d(nn.Module):
|
| 183 |
+
"""
|
| 184 |
+
Downsampling block for use in encoder.
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
| 188 |
+
super(DownBlock2d, self).__init__()
|
| 189 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
| 190 |
+
padding=padding, groups=groups)
|
| 191 |
+
self.norm = nn.InstanceNorm2d(out_features, affine=True)
|
| 192 |
+
self.pool = nn.AvgPool2d(kernel_size=(2, 2))
|
| 193 |
+
|
| 194 |
+
def forward(self, x):
|
| 195 |
+
out = self.conv(x)
|
| 196 |
+
out = self.norm(out)
|
| 197 |
+
out = F.relu(out)
|
| 198 |
+
out = self.pool(out)
|
| 199 |
+
return out
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class SameBlock2d(nn.Module):
|
| 203 |
+
"""
|
| 204 |
+
Simple block, preserve spatial resolution.
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
|
| 208 |
+
super(SameBlock2d, self).__init__()
|
| 209 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
|
| 210 |
+
kernel_size=kernel_size, padding=padding, groups=groups)
|
| 211 |
+
self.norm = nn.InstanceNorm2d(out_features, affine=True)
|
| 212 |
+
|
| 213 |
+
def forward(self, x):
|
| 214 |
+
out = self.conv(x)
|
| 215 |
+
out = self.norm(out)
|
| 216 |
+
out = F.relu(out)
|
| 217 |
+
return out
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class Encoder(nn.Module):
|
| 221 |
+
"""
|
| 222 |
+
Hourglass Encoder
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
| 226 |
+
super(Encoder, self).__init__()
|
| 227 |
+
|
| 228 |
+
down_blocks = []
|
| 229 |
+
for i in range(num_blocks):
|
| 230 |
+
down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
|
| 231 |
+
min(max_features, block_expansion * (2 ** (i + 1))),
|
| 232 |
+
kernel_size=3, padding=1))
|
| 233 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
| 234 |
+
|
| 235 |
+
def forward(self, x):
|
| 236 |
+
outs = [x]
|
| 237 |
+
#print('encoder:' ,outs[-1].shape)
|
| 238 |
+
for down_block in self.down_blocks:
|
| 239 |
+
outs.append(down_block(outs[-1]))
|
| 240 |
+
#print('encoder:' ,outs[-1].shape)
|
| 241 |
+
return outs
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class Decoder(nn.Module):
|
| 245 |
+
"""
|
| 246 |
+
Hourglass Decoder
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
| 250 |
+
super(Decoder, self).__init__()
|
| 251 |
+
|
| 252 |
+
up_blocks = []
|
| 253 |
+
self.out_channels = []
|
| 254 |
+
for i in range(num_blocks)[::-1]:
|
| 255 |
+
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
|
| 256 |
+
self.out_channels.append(in_filters)
|
| 257 |
+
out_filters = min(max_features, block_expansion * (2 ** i))
|
| 258 |
+
up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
|
| 259 |
+
|
| 260 |
+
self.up_blocks = nn.ModuleList(up_blocks)
|
| 261 |
+
self.out_channels.append(block_expansion + in_features)
|
| 262 |
+
# self.out_filters = block_expansion + in_features
|
| 263 |
+
|
| 264 |
+
def forward(self, x, mode = 0):
|
| 265 |
+
out = x.pop()
|
| 266 |
+
outs = []
|
| 267 |
+
for up_block in self.up_blocks:
|
| 268 |
+
out = up_block(out)
|
| 269 |
+
skip = x.pop()
|
| 270 |
+
out = torch.cat([out, skip], dim=1)
|
| 271 |
+
outs.append(out)
|
| 272 |
+
if(mode == 0):
|
| 273 |
+
return out
|
| 274 |
+
else:
|
| 275 |
+
return outs
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class Hourglass(nn.Module):
|
| 279 |
+
"""
|
| 280 |
+
Hourglass architecture.
|
| 281 |
+
"""
|
| 282 |
+
|
| 283 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
| 284 |
+
super(Hourglass, self).__init__()
|
| 285 |
+
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
|
| 286 |
+
self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
|
| 287 |
+
self.out_channels = self.decoder.out_channels
|
| 288 |
+
# self.out_filters = self.decoder.out_filters
|
| 289 |
+
|
| 290 |
+
def forward(self, x, mode = 0):
|
| 291 |
+
return self.decoder(self.encoder(x), mode)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class AntiAliasInterpolation2d(nn.Module):
|
| 295 |
+
"""
|
| 296 |
+
Band-limited downsampling, for better preservation of the input signal.
|
| 297 |
+
"""
|
| 298 |
+
def __init__(self, channels, scale):
|
| 299 |
+
super(AntiAliasInterpolation2d, self).__init__()
|
| 300 |
+
sigma = (1 / scale - 1) / 2
|
| 301 |
+
kernel_size = 2 * round(sigma * 4) + 1
|
| 302 |
+
self.ka = kernel_size // 2
|
| 303 |
+
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
|
| 304 |
+
|
| 305 |
+
kernel_size = [kernel_size, kernel_size]
|
| 306 |
+
sigma = [sigma, sigma]
|
| 307 |
+
# The gaussian kernel is the product of the
|
| 308 |
+
# gaussian function of each dimension.
|
| 309 |
+
kernel = 1
|
| 310 |
+
meshgrids = torch.meshgrid(
|
| 311 |
+
[
|
| 312 |
+
torch.arange(size, dtype=torch.float32)
|
| 313 |
+
for size in kernel_size
|
| 314 |
+
]
|
| 315 |
+
)
|
| 316 |
+
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
| 317 |
+
mean = (size - 1) / 2
|
| 318 |
+
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
|
| 319 |
+
|
| 320 |
+
# Make sure sum of values in gaussian kernel equals 1.
|
| 321 |
+
kernel = kernel / torch.sum(kernel)
|
| 322 |
+
# Reshape to depthwise convolutional weight
|
| 323 |
+
kernel = kernel.view(1, 1, *kernel.size())
|
| 324 |
+
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
| 325 |
+
|
| 326 |
+
self.register_buffer('weight', kernel)
|
| 327 |
+
self.groups = channels
|
| 328 |
+
self.scale = scale
|
| 329 |
+
|
| 330 |
+
def forward(self, input):
|
| 331 |
+
if self.scale == 1.0:
|
| 332 |
+
return input
|
| 333 |
+
|
| 334 |
+
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
|
| 335 |
+
out = F.conv2d(out, weight=self.weight, groups=self.groups)
|
| 336 |
+
out = F.interpolate(out, scale_factor=(self.scale, self.scale))
|
| 337 |
+
|
| 338 |
+
return out
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def to_homogeneous(coordinates):
|
| 342 |
+
ones_shape = list(coordinates.shape)
|
| 343 |
+
ones_shape[-1] = 1
|
| 344 |
+
ones = torch.ones(ones_shape).type(coordinates.type())
|
| 345 |
+
|
| 346 |
+
return torch.cat([coordinates, ones], dim=-1)
|
| 347 |
+
|
| 348 |
+
def from_homogeneous(coordinates):
|
| 349 |
+
return coordinates[..., :2] / coordinates[..., 2:3]
|
TPSMM/pkgs/tpsmm.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from skimage import img_as_ubyte
|
| 7 |
+
from skimage.transform import resize
|
| 8 |
+
pwd = os.path.dirname(os.path.realpath(__file__))
|
| 9 |
+
sys.path.insert(1, os.path.join(pwd, ".."))
|
| 10 |
+
|
| 11 |
+
from demo import relative_kp, load_checkpoints
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TPSMM:
|
| 15 |
+
def __init__(self):
|
| 16 |
+
self.device = torch.device("cuda")
|
| 17 |
+
self.inpainting, self.kp_detector, self.dense_motion_network, self.avd_network = load_checkpoints(
|
| 18 |
+
config_path=os.path.join(pwd, "../config/vox-256.yaml"),
|
| 19 |
+
checkpoint_path=os.path.join(pwd, "../pretrained/vox.pth.tar"),
|
| 20 |
+
device=self.device
|
| 21 |
+
)
|
| 22 |
+
self.kp_driving_initial = None
|
| 23 |
+
|
| 24 |
+
def process_source(self, src_img):
|
| 25 |
+
with torch.no_grad():
|
| 26 |
+
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
|
| 27 |
+
src_img = resize(src_img, (256, 256))
|
| 28 |
+
source_tensor = torch.tensor(src_img[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(self.device)
|
| 29 |
+
kp_source = self.kp_detector(source_tensor)
|
| 30 |
+
|
| 31 |
+
return source_tensor, kp_source
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def gen_image(self, driving_img, source_tensor, kp_source):
|
| 35 |
+
with torch.no_grad():
|
| 36 |
+
driving_img = cv2.cvtColor(driving_img, cv2.COLOR_BGR2RGB)
|
| 37 |
+
driving_img = resize(driving_img, (256, 256))
|
| 38 |
+
driving_frame = torch.tensor(driving_img[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(self.device)
|
| 39 |
+
|
| 40 |
+
kp_driving = self.kp_detector(driving_frame)
|
| 41 |
+
if self.kp_driving_initial is None:
|
| 42 |
+
self.kp_driving_initial = kp_driving
|
| 43 |
+
kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving,
|
| 44 |
+
kp_driving_initial=self.kp_driving_initial)
|
| 45 |
+
dense_motion = self.dense_motion_network(source_image=source_tensor,
|
| 46 |
+
kp_driving=kp_norm,
|
| 47 |
+
kp_source=kp_source, bg_param=None,
|
| 48 |
+
dropout_flag=False)
|
| 49 |
+
out = self.inpainting(source_tensor, dense_motion)
|
| 50 |
+
out = np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]
|
| 51 |
+
out = img_as_ubyte(out)
|
| 52 |
+
out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
|
| 53 |
+
|
| 54 |
+
return out
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
tpsmm = TPSMM()
|
| 59 |
+
source_image = cv2.imread(os.path.join(pwd, "../assets/source1.png"))
|
| 60 |
+
cap = cv2.VideoCapture("/research/GAN/git/CVPR2022-DaGAN/assets/video1.mp4")
|
| 61 |
+
|
| 62 |
+
source_tensor, kp_source = tpsmm.process_source(source_image)
|
| 63 |
+
|
| 64 |
+
while True:
|
| 65 |
+
ret, frame = cap.read()
|
| 66 |
+
if frame is None:
|
| 67 |
+
break
|
| 68 |
+
|
| 69 |
+
output = tpsmm.gen_image(frame, source_tensor, kp_source)
|
| 70 |
+
cv2.imshow("output", output)
|
| 71 |
+
key = cv2.waitKey(1) & 0xFF
|
| 72 |
+
if key == ord("q"):
|
| 73 |
+
break
|
| 74 |
+
# cv2.imwrite("./tmp.jpg", output)
|
| 75 |
+
|
| 76 |
+
cv2.destroyAllWindows()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
TPSMM/predict.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
sys.path.insert(0, "stylegan-encoder")
|
| 4 |
+
import tempfile
|
| 5 |
+
import warnings
|
| 6 |
+
import imageio
|
| 7 |
+
import numpy as np
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import matplotlib.animation as animation
|
| 10 |
+
from skimage.transform import resize
|
| 11 |
+
from skimage import img_as_ubyte
|
| 12 |
+
import torch
|
| 13 |
+
import torchvision.transforms as transforms
|
| 14 |
+
import dlib
|
| 15 |
+
from cog import BasePredictor, Path, Input
|
| 16 |
+
|
| 17 |
+
from demo import load_checkpoints
|
| 18 |
+
from demo import make_animation
|
| 19 |
+
from ffhq_dataset.face_alignment import image_align
|
| 20 |
+
from ffhq_dataset.landmarks_detector import LandmarksDetector
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
warnings.filterwarnings("ignore")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
PREDICTOR = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
|
| 27 |
+
LANDMARKS_DETECTOR = LandmarksDetector("shape_predictor_68_face_landmarks.dat")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Predictor(BasePredictor):
|
| 31 |
+
def setup(self):
|
| 32 |
+
|
| 33 |
+
self.device = torch.device("cuda:0")
|
| 34 |
+
datasets = ["vox", "taichi", "ted", "mgif"]
|
| 35 |
+
(
|
| 36 |
+
self.inpainting,
|
| 37 |
+
self.kp_detector,
|
| 38 |
+
self.dense_motion_network,
|
| 39 |
+
self.avd_network,
|
| 40 |
+
) = ({}, {}, {}, {})
|
| 41 |
+
for d in datasets:
|
| 42 |
+
(
|
| 43 |
+
self.inpainting[d],
|
| 44 |
+
self.kp_detector[d],
|
| 45 |
+
self.dense_motion_network[d],
|
| 46 |
+
self.avd_network[d],
|
| 47 |
+
) = load_checkpoints(
|
| 48 |
+
config_path=f"config/{d}-384.yaml"
|
| 49 |
+
if d == "ted"
|
| 50 |
+
else f"config/{d}-256.yaml",
|
| 51 |
+
checkpoint_path=f"checkpoints/{d}.pth.tar",
|
| 52 |
+
device=self.device,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def predict(
|
| 56 |
+
self,
|
| 57 |
+
source_image: Path = Input(
|
| 58 |
+
description="Input source image.",
|
| 59 |
+
),
|
| 60 |
+
driving_video: Path = Input(
|
| 61 |
+
description="Choose a micromotion.",
|
| 62 |
+
),
|
| 63 |
+
dataset_name: str = Input(
|
| 64 |
+
choices=["vox", "taichi", "ted", "mgif"],
|
| 65 |
+
default="vox",
|
| 66 |
+
description="Choose a dataset.",
|
| 67 |
+
),
|
| 68 |
+
) -> Path:
|
| 69 |
+
|
| 70 |
+
predict_mode = "relative" # ['standard', 'relative', 'avd']
|
| 71 |
+
# find_best_frame = False
|
| 72 |
+
|
| 73 |
+
pixel = 384 if dataset_name == "ted" else 256
|
| 74 |
+
|
| 75 |
+
if dataset_name == "vox":
|
| 76 |
+
# first run face alignment
|
| 77 |
+
align_image(str(source_image), 'aligned.png')
|
| 78 |
+
source_image = imageio.imread('aligned.png')
|
| 79 |
+
else:
|
| 80 |
+
source_image = imageio.imread(str(source_image))
|
| 81 |
+
reader = imageio.get_reader(str(driving_video))
|
| 82 |
+
fps = reader.get_meta_data()["fps"]
|
| 83 |
+
source_image = resize(source_image, (pixel, pixel))[..., :3]
|
| 84 |
+
|
| 85 |
+
driving_video = []
|
| 86 |
+
try:
|
| 87 |
+
for im in reader:
|
| 88 |
+
driving_video.append(im)
|
| 89 |
+
except RuntimeError:
|
| 90 |
+
pass
|
| 91 |
+
reader.close()
|
| 92 |
+
|
| 93 |
+
driving_video = [
|
| 94 |
+
resize(frame, (pixel, pixel))[..., :3] for frame in driving_video
|
| 95 |
+
]
|
| 96 |
+
|
| 97 |
+
inpainting, kp_detector, dense_motion_network, avd_network = (
|
| 98 |
+
self.inpainting[dataset_name],
|
| 99 |
+
self.kp_detector[dataset_name],
|
| 100 |
+
self.dense_motion_network[dataset_name],
|
| 101 |
+
self.avd_network[dataset_name],
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
predictions = make_animation(
|
| 105 |
+
source_image,
|
| 106 |
+
driving_video,
|
| 107 |
+
inpainting,
|
| 108 |
+
kp_detector,
|
| 109 |
+
dense_motion_network,
|
| 110 |
+
avd_network,
|
| 111 |
+
device="cuda:0",
|
| 112 |
+
mode=predict_mode,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# save resulting video
|
| 116 |
+
out_path = Path(tempfile.mkdtemp()) / "output.mp4"
|
| 117 |
+
imageio.mimsave(
|
| 118 |
+
str(out_path), [img_as_ubyte(frame) for frame in predictions], fps=fps
|
| 119 |
+
)
|
| 120 |
+
return out_path
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def align_image(raw_img_path, aligned_face_path):
|
| 124 |
+
for i, face_landmarks in enumerate(LANDMARKS_DETECTOR.get_landmarks(raw_img_path), start=1):
|
| 125 |
+
image_align(raw_img_path, aligned_face_path, face_landmarks)
|
TPSMM/pretrained/vox.pth.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:52ad8c848e2a1d91b621de96fea83faf57ce3b8c1c06424e317f4df1d3998204
|
| 3 |
+
size 350993469
|
TPSMM/reconstruction.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from logger import Logger, Visualizer
|
| 6 |
+
import numpy as np
|
| 7 |
+
import imageio
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def reconstruction(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
|
| 11 |
+
png_dir = os.path.join(log_dir, 'reconstruction/png')
|
| 12 |
+
log_dir = os.path.join(log_dir, 'reconstruction')
|
| 13 |
+
|
| 14 |
+
if checkpoint is not None:
|
| 15 |
+
Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
|
| 16 |
+
bg_predictor=bg_predictor, dense_motion_network=dense_motion_network)
|
| 17 |
+
else:
|
| 18 |
+
raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
|
| 19 |
+
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
|
| 20 |
+
|
| 21 |
+
if not os.path.exists(log_dir):
|
| 22 |
+
os.makedirs(log_dir)
|
| 23 |
+
|
| 24 |
+
if not os.path.exists(png_dir):
|
| 25 |
+
os.makedirs(png_dir)
|
| 26 |
+
|
| 27 |
+
loss_list = []
|
| 28 |
+
|
| 29 |
+
inpainting_network.eval()
|
| 30 |
+
kp_detector.eval()
|
| 31 |
+
dense_motion_network.eval()
|
| 32 |
+
if bg_predictor:
|
| 33 |
+
bg_predictor.eval()
|
| 34 |
+
|
| 35 |
+
for it, x in tqdm(enumerate(dataloader)):
|
| 36 |
+
with torch.no_grad():
|
| 37 |
+
predictions = []
|
| 38 |
+
visualizations = []
|
| 39 |
+
if torch.cuda.is_available():
|
| 40 |
+
x['video'] = x['video'].cuda()
|
| 41 |
+
kp_source = kp_detector(x['video'][:, :, 0])
|
| 42 |
+
for frame_idx in range(x['video'].shape[2]):
|
| 43 |
+
source = x['video'][:, :, 0]
|
| 44 |
+
driving = x['video'][:, :, frame_idx]
|
| 45 |
+
kp_driving = kp_detector(driving)
|
| 46 |
+
bg_params = None
|
| 47 |
+
if bg_predictor:
|
| 48 |
+
bg_params = bg_predictor(source, driving)
|
| 49 |
+
|
| 50 |
+
dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
|
| 51 |
+
kp_source=kp_source, bg_param = bg_params,
|
| 52 |
+
dropout_flag = False)
|
| 53 |
+
out = inpainting_network(source, dense_motion)
|
| 54 |
+
out['kp_source'] = kp_source
|
| 55 |
+
out['kp_driving'] = kp_driving
|
| 56 |
+
|
| 57 |
+
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
|
| 58 |
+
|
| 59 |
+
visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
|
| 60 |
+
driving=driving, out=out)
|
| 61 |
+
visualizations.append(visualization)
|
| 62 |
+
loss = torch.abs(out['prediction'] - driving).mean().cpu().numpy()
|
| 63 |
+
|
| 64 |
+
loss_list.append(loss)
|
| 65 |
+
# print(np.mean(loss_list))
|
| 66 |
+
predictions = np.concatenate(predictions, axis=1)
|
| 67 |
+
imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
|
| 68 |
+
|
| 69 |
+
print("Reconstruction loss: %s" % np.mean(loss_list))
|
TPSMM/requirements.txt
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cffi==1.14.6
|
| 2 |
+
cycler==0.10.0
|
| 3 |
+
decorator==5.1.0
|
| 4 |
+
face-alignment==1.3.5
|
| 5 |
+
imageio==2.9.0
|
| 6 |
+
imageio-ffmpeg==0.4.5
|
| 7 |
+
kiwisolver==1.3.2
|
| 8 |
+
matplotlib==3.4.3
|
| 9 |
+
networkx==2.6.3
|
| 10 |
+
numpy==1.20.3
|
| 11 |
+
pandas==1.3.3
|
| 12 |
+
Pillow==8.3.2
|
| 13 |
+
pycparser==2.20
|
| 14 |
+
pyparsing==2.4.7
|
| 15 |
+
python-dateutil==2.8.2
|
| 16 |
+
pytz==2021.1
|
| 17 |
+
PyWavelets==1.1.1
|
| 18 |
+
PyYAML==5.4.1
|
| 19 |
+
scikit-image==0.18.3
|
| 20 |
+
scikit-learn==1.0
|
| 21 |
+
scipy==1.7.1
|
| 22 |
+
six==1.16.0
|
| 23 |
+
torch==1.10.0+cu113
|
| 24 |
+
torchvision==0.11.0+cu113
|
| 25 |
+
tqdm==4.62.3
|
TPSMM/run.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib
|
| 2 |
+
matplotlib.use('Agg')
|
| 3 |
+
|
| 4 |
+
import os, sys
|
| 5 |
+
import yaml
|
| 6 |
+
from argparse import ArgumentParser
|
| 7 |
+
from time import gmtime, strftime
|
| 8 |
+
from shutil import copy
|
| 9 |
+
from frames_dataset import FramesDataset
|
| 10 |
+
|
| 11 |
+
from modules.inpainting_network import InpaintingNetwork
|
| 12 |
+
from modules.keypoint_detector import KPDetector
|
| 13 |
+
from modules.bg_motion_predictor import BGMotionPredictor
|
| 14 |
+
from modules.dense_motion import DenseMotionNetwork
|
| 15 |
+
from modules.avd_network import AVDNetwork
|
| 16 |
+
import torch
|
| 17 |
+
from train import train
|
| 18 |
+
from train_avd import train_avd
|
| 19 |
+
from reconstruction import reconstruction
|
| 20 |
+
import os
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
if __name__ == "__main__":
|
| 24 |
+
|
| 25 |
+
if sys.version_info[0] < 3:
|
| 26 |
+
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")
|
| 27 |
+
|
| 28 |
+
parser = ArgumentParser()
|
| 29 |
+
parser.add_argument("--config", default="config/vox-256.yaml", help="path to config")
|
| 30 |
+
parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"])
|
| 31 |
+
parser.add_argument("--log_dir", default='log', help="path to log into")
|
| 32 |
+
parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
|
| 33 |
+
parser.add_argument("--device_ids", default="0,1", type=lambda x: list(map(int, x.split(','))),
|
| 34 |
+
help="Names of the devices comma separated.")
|
| 35 |
+
|
| 36 |
+
opt = parser.parse_args()
|
| 37 |
+
with open(opt.config) as f:
|
| 38 |
+
config = yaml.load(f)
|
| 39 |
+
|
| 40 |
+
if opt.checkpoint is not None:
|
| 41 |
+
log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
|
| 42 |
+
else:
|
| 43 |
+
log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
|
| 44 |
+
log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())
|
| 45 |
+
|
| 46 |
+
inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
|
| 47 |
+
**config['model_params']['common_params'])
|
| 48 |
+
|
| 49 |
+
if torch.cuda.is_available():
|
| 50 |
+
cuda_device = torch.device('cuda:'+str(opt.device_ids[0]))
|
| 51 |
+
inpainting.to(cuda_device)
|
| 52 |
+
|
| 53 |
+
kp_detector = KPDetector(**config['model_params']['common_params'])
|
| 54 |
+
dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
|
| 55 |
+
**config['model_params']['dense_motion_params'])
|
| 56 |
+
|
| 57 |
+
if torch.cuda.is_available():
|
| 58 |
+
kp_detector.to(opt.device_ids[0])
|
| 59 |
+
dense_motion_network.to(opt.device_ids[0])
|
| 60 |
+
|
| 61 |
+
bg_predictor = None
|
| 62 |
+
if (config['model_params']['common_params']['bg']):
|
| 63 |
+
bg_predictor = BGMotionPredictor()
|
| 64 |
+
if torch.cuda.is_available():
|
| 65 |
+
bg_predictor.to(opt.device_ids[0])
|
| 66 |
+
|
| 67 |
+
avd_network = None
|
| 68 |
+
if opt.mode == "train_avd":
|
| 69 |
+
avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
|
| 70 |
+
**config['model_params']['avd_network_params'])
|
| 71 |
+
if torch.cuda.is_available():
|
| 72 |
+
avd_network.to(opt.device_ids[0])
|
| 73 |
+
|
| 74 |
+
dataset = FramesDataset(is_train=(opt.mode.startswith('train')), **config['dataset_params'])
|
| 75 |
+
|
| 76 |
+
if not os.path.exists(log_dir):
|
| 77 |
+
os.makedirs(log_dir)
|
| 78 |
+
if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
|
| 79 |
+
copy(opt.config, log_dir)
|
| 80 |
+
|
| 81 |
+
if opt.mode == 'train':
|
| 82 |
+
print("Training...")
|
| 83 |
+
train(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset)
|
| 84 |
+
elif opt.mode == 'train_avd':
|
| 85 |
+
print("Training Animation via Disentaglement...")
|
| 86 |
+
train_avd(config, inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network, opt.checkpoint, log_dir, dataset)
|
| 87 |
+
elif opt.mode == 'reconstruction':
|
| 88 |
+
print("Reconstruction...")
|
| 89 |
+
reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset)
|
TPSMM/tmp.jpg
ADDED
|
TPSMM/tmp.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
cap = cv2.VideoCapture("/research/GAN/git/CVPR2022-DaGAN/assets/video1.mp4")
|
| 5 |
+
while True:
|
| 6 |
+
ret, frame = cap.read()
|
| 7 |
+
if frame is None:
|
| 8 |
+
break
|
| 9 |
+
cv2.imshow("output", frame)
|
| 10 |
+
key = cv2.waitKey(1) & 0xff
|
| 11 |
+
if key == ord("q"):
|
| 12 |
+
break
|
| 13 |
+
|
| 14 |
+
cv2.destroyAllWindows()
|
TPSMM/train.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import trange
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from logger import Logger
|
| 5 |
+
from modules.model import GeneratorFullModel
|
| 6 |
+
from torch.optim.lr_scheduler import MultiStepLR
|
| 7 |
+
from torch.nn.utils import clip_grad_norm_
|
| 8 |
+
from frames_dataset import DatasetRepeater
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
|
| 12 |
+
train_params = config['train_params']
|
| 13 |
+
optimizer = torch.optim.Adam(
|
| 14 |
+
[{'params': list(inpainting_network.parameters()) +
|
| 15 |
+
list(dense_motion_network.parameters()) +
|
| 16 |
+
list(kp_detector.parameters()), 'initial_lr': train_params['lr_generator']}],lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4)
|
| 17 |
+
|
| 18 |
+
optimizer_bg_predictor = None
|
| 19 |
+
if bg_predictor:
|
| 20 |
+
optimizer_bg_predictor = torch.optim.Adam(
|
| 21 |
+
[{'params':bg_predictor.parameters(),'initial_lr': train_params['lr_generator']}],
|
| 22 |
+
lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4)
|
| 23 |
+
|
| 24 |
+
if checkpoint is not None:
|
| 25 |
+
start_epoch = Logger.load_cpk(
|
| 26 |
+
checkpoint, inpainting_network = inpainting_network, dense_motion_network = dense_motion_network,
|
| 27 |
+
kp_detector = kp_detector, bg_predictor = bg_predictor,
|
| 28 |
+
optimizer = optimizer, optimizer_bg_predictor = optimizer_bg_predictor)
|
| 29 |
+
print('load success:', start_epoch)
|
| 30 |
+
start_epoch += 1
|
| 31 |
+
else:
|
| 32 |
+
start_epoch = 0
|
| 33 |
+
|
| 34 |
+
scheduler_optimizer = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1,
|
| 35 |
+
last_epoch=start_epoch - 1)
|
| 36 |
+
if bg_predictor:
|
| 37 |
+
scheduler_bg_predictor = MultiStepLR(optimizer_bg_predictor, train_params['epoch_milestones'],
|
| 38 |
+
gamma=0.1, last_epoch=start_epoch - 1)
|
| 39 |
+
|
| 40 |
+
if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
|
| 41 |
+
dataset = DatasetRepeater(dataset, train_params['num_repeats'])
|
| 42 |
+
dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True,
|
| 43 |
+
num_workers=train_params['dataloader_workers'], drop_last=True)
|
| 44 |
+
|
| 45 |
+
generator_full = GeneratorFullModel(kp_detector, bg_predictor, dense_motion_network, inpainting_network, train_params)
|
| 46 |
+
|
| 47 |
+
if torch.cuda.is_available():
|
| 48 |
+
generator_full = torch.nn.DataParallel(generator_full).cuda()
|
| 49 |
+
|
| 50 |
+
bg_start = train_params['bg_start']
|
| 51 |
+
|
| 52 |
+
with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'],
|
| 53 |
+
checkpoint_freq=train_params['checkpoint_freq']) as logger:
|
| 54 |
+
for epoch in trange(start_epoch, train_params['num_epochs']):
|
| 55 |
+
for x in dataloader:
|
| 56 |
+
if(torch.cuda.is_available()):
|
| 57 |
+
x['driving'] = x['driving'].cuda()
|
| 58 |
+
x['source'] = x['source'].cuda()
|
| 59 |
+
|
| 60 |
+
losses_generator, generated = generator_full(x, epoch)
|
| 61 |
+
loss_values = [val.mean() for val in losses_generator.values()]
|
| 62 |
+
loss = sum(loss_values)
|
| 63 |
+
loss.backward()
|
| 64 |
+
|
| 65 |
+
clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type = math.inf)
|
| 66 |
+
clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type = math.inf)
|
| 67 |
+
if bg_predictor and epoch>=bg_start:
|
| 68 |
+
clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type = math.inf)
|
| 69 |
+
|
| 70 |
+
optimizer.step()
|
| 71 |
+
optimizer.zero_grad()
|
| 72 |
+
if bg_predictor and epoch>=bg_start:
|
| 73 |
+
optimizer_bg_predictor.step()
|
| 74 |
+
optimizer_bg_predictor.zero_grad()
|
| 75 |
+
|
| 76 |
+
losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
|
| 77 |
+
logger.log_iter(losses=losses)
|
| 78 |
+
|
| 79 |
+
scheduler_optimizer.step()
|
| 80 |
+
if bg_predictor:
|
| 81 |
+
scheduler_bg_predictor.step()
|
| 82 |
+
|
| 83 |
+
model_save = {
|
| 84 |
+
'inpainting_network': inpainting_network,
|
| 85 |
+
'dense_motion_network': dense_motion_network,
|
| 86 |
+
'kp_detector': kp_detector,
|
| 87 |
+
'optimizer': optimizer,
|
| 88 |
+
}
|
| 89 |
+
if bg_predictor and epoch>=bg_start:
|
| 90 |
+
model_save['bg_predictor'] = bg_predictor
|
| 91 |
+
model_save['optimizer_bg_predictor'] = optimizer_bg_predictor
|
| 92 |
+
|
| 93 |
+
logger.log_epoch(epoch, model_save, inp=x, out=generated)
|
| 94 |
+
|
TPSMM/train_avd.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import trange
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from logger import Logger
|
| 5 |
+
from torch.optim.lr_scheduler import MultiStepLR
|
| 6 |
+
from frames_dataset import DatasetRepeater
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def random_scale(kp_params, scale):
|
| 10 |
+
theta = torch.rand(kp_params['fg_kp'].shape[0], 2) * (2 * scale) + (1 - scale)
|
| 11 |
+
theta = torch.diag_embed(theta).unsqueeze(1).type(kp_params['fg_kp'].type())
|
| 12 |
+
new_kp_params = {'fg_kp': torch.matmul(theta, kp_params['fg_kp'].unsqueeze(-1)).squeeze(-1)}
|
| 13 |
+
return new_kp_params
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def train_avd(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network,
|
| 17 |
+
avd_network, checkpoint, log_dir, dataset):
|
| 18 |
+
train_params = config['train_avd_params']
|
| 19 |
+
|
| 20 |
+
optimizer = torch.optim.Adam(avd_network.parameters(), lr=train_params['lr'], betas=(0.5, 0.999))
|
| 21 |
+
|
| 22 |
+
if checkpoint is not None:
|
| 23 |
+
Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
|
| 24 |
+
bg_predictor=bg_predictor, avd_network=avd_network,
|
| 25 |
+
dense_motion_network= dense_motion_network,optimizer_avd=optimizer)
|
| 26 |
+
start_epoch = 0
|
| 27 |
+
else:
|
| 28 |
+
raise AttributeError("Checkpoint should be specified for mode='train_avd'.")
|
| 29 |
+
|
| 30 |
+
scheduler = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1)
|
| 31 |
+
|
| 32 |
+
if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
|
| 33 |
+
dataset = DatasetRepeater(dataset, train_params['num_repeats'])
|
| 34 |
+
|
| 35 |
+
dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True,
|
| 36 |
+
num_workers=train_params['dataloader_workers'], drop_last=True)
|
| 37 |
+
|
| 38 |
+
with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'],
|
| 39 |
+
checkpoint_freq=train_params['checkpoint_freq']) as logger:
|
| 40 |
+
for epoch in trange(start_epoch, train_params['num_epochs']):
|
| 41 |
+
avd_network.train()
|
| 42 |
+
for x in dataloader:
|
| 43 |
+
with torch.no_grad():
|
| 44 |
+
kp_source = kp_detector(x['source'].cuda())
|
| 45 |
+
kp_driving_gt = kp_detector(x['driving'].cuda())
|
| 46 |
+
kp_driving_random = random_scale(kp_driving_gt, scale=train_params['random_scale'])
|
| 47 |
+
rec = avd_network(kp_source, kp_driving_random)
|
| 48 |
+
|
| 49 |
+
reconstruction_kp = train_params['lambda_shift'] * \
|
| 50 |
+
torch.abs(kp_driving_gt['fg_kp'] - rec['fg_kp']).mean()
|
| 51 |
+
|
| 52 |
+
loss_dict = {'rec_kp': reconstruction_kp}
|
| 53 |
+
loss = reconstruction_kp
|
| 54 |
+
|
| 55 |
+
loss.backward()
|
| 56 |
+
optimizer.step()
|
| 57 |
+
optimizer.zero_grad()
|
| 58 |
+
|
| 59 |
+
losses = {key: value.mean().detach().data.cpu().numpy() for key, value in loss_dict.items()}
|
| 60 |
+
logger.log_iter(losses=losses)
|
| 61 |
+
|
| 62 |
+
# Visualization
|
| 63 |
+
avd_network.eval()
|
| 64 |
+
with torch.no_grad():
|
| 65 |
+
source = x['source'][:6].cuda()
|
| 66 |
+
driving = torch.cat([x['driving'][[0, 1]].cuda(), source[[2, 3, 2, 1]]], dim=0)
|
| 67 |
+
kp_source = kp_detector(source)
|
| 68 |
+
kp_driving = kp_detector(driving)
|
| 69 |
+
|
| 70 |
+
out = avd_network(kp_source, kp_driving)
|
| 71 |
+
kp_driving = out
|
| 72 |
+
dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
|
| 73 |
+
kp_source=kp_source)
|
| 74 |
+
generated = inpainting_network(source, dense_motion)
|
| 75 |
+
|
| 76 |
+
generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
|
| 77 |
+
|
| 78 |
+
scheduler.step(epoch)
|
| 79 |
+
model_save = {
|
| 80 |
+
'inpainting_network': inpainting_network,
|
| 81 |
+
'dense_motion_network': dense_motion_network,
|
| 82 |
+
'kp_detector': kp_detector,
|
| 83 |
+
'avd_network': avd_network,
|
| 84 |
+
'optimizer_avd': optimizer
|
| 85 |
+
}
|
| 86 |
+
if bg_predictor :
|
| 87 |
+
model_save['bg_predictor'] = bg_predictor
|
| 88 |
+
|
| 89 |
+
logger.log_epoch(epoch, model_save,
|
| 90 |
+
inp={'source': source, 'driving': driving},
|
| 91 |
+
out=generated)
|
app.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import av
|
| 2 |
+
import sys
|
| 3 |
+
import numpy as np
|
| 4 |
+
import cv2
|
| 5 |
+
import streamlit as st
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from streamlit_webrtc import WebRtcMode, webrtc_streamer
|
| 8 |
+
|
| 9 |
+
sys.path.insert(1, "./retinaface")
|
| 10 |
+
sys.path.insert(1, "./TPSMM/pkgs")
|
| 11 |
+
from tpsmm import TPSMM
|
| 12 |
+
from detect import Detect
|
| 13 |
+
from turn import get_ice_servers
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def parse_roi_box_from_bbox(bbox, shape):
|
| 17 |
+
img_h, img_w = shape[:2]
|
| 18 |
+
left, top, right, bottom = bbox[:4]
|
| 19 |
+
old_size = (right - left + bottom - top) / 2
|
| 20 |
+
center_x = right - (right - left) / 2.0
|
| 21 |
+
center_y = bottom - (bottom - top) / 2.0 + old_size * 0.14
|
| 22 |
+
|
| 23 |
+
size = int(min((old_size * 2.0) / 2, center_x, img_w-center_x, center_y, img_h-center_y) * 2.0)
|
| 24 |
+
|
| 25 |
+
roi_box = [0] * 4
|
| 26 |
+
roi_box[0] = center_x - size / 2
|
| 27 |
+
roi_box[1] = center_y - size / 2
|
| 28 |
+
roi_box[2] = roi_box[0] + size
|
| 29 |
+
roi_box[3] = roi_box[1] + size
|
| 30 |
+
|
| 31 |
+
return roi_box
|
| 32 |
+
|
| 33 |
+
cache_key = "retinaface"
|
| 34 |
+
if cache_key in st.session_state:
|
| 35 |
+
detector = st.session_state[cache_key]
|
| 36 |
+
else:
|
| 37 |
+
detector = Detect("./retinaface/weights/mobilenet0.25_epoch_842.pth", net_inshape=(486, 864))
|
| 38 |
+
st.session_state[cache_key] = detector
|
| 39 |
+
|
| 40 |
+
cache_key = "tpsmm"
|
| 41 |
+
if cache_key in st.session_state:
|
| 42 |
+
generator = st.session_state[cache_key]
|
| 43 |
+
else:
|
| 44 |
+
generator = TPSMM()
|
| 45 |
+
st.session_state[cache_key] = generator
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@st.cache_resource # type: ignore
|
| 49 |
+
def get_images():
|
| 50 |
+
images = [
|
| 51 |
+
cv2.imread("assets/0.jpg"),
|
| 52 |
+
cv2.imread("assets/1.jpg"),
|
| 53 |
+
cv2.imread("assets/2.jpg"),
|
| 54 |
+
cv2.imread("assets/3.jpg"),
|
| 55 |
+
]
|
| 56 |
+
item_list = [str(i) for i in range(len(images))]
|
| 57 |
+
images = [generator.process_source(src_img) for src_img in images]
|
| 58 |
+
|
| 59 |
+
return dict(zip(item_list, images))
|
| 60 |
+
images = get_images()
|
| 61 |
+
user_option = st.selectbox("Choose an item", list(images.keys()))
|
| 62 |
+
|
| 63 |
+
uploaded_file = st.file_uploader("Or upload your file here...", type=['png', 'jpeg', 'jpg'])
|
| 64 |
+
@st.cache_resource
|
| 65 |
+
def process_file(uploaded_file):
|
| 66 |
+
img = Image.open(uploaded_file)
|
| 67 |
+
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
| 68 |
+
dets = detector(img)
|
| 69 |
+
for i, b in enumerate(dets):
|
| 70 |
+
bbox = parse_roi_box_from_bbox(b[:4], img.shape)
|
| 71 |
+
bbox = [int(i) for i in bbox]
|
| 72 |
+
|
| 73 |
+
face_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy()
|
| 74 |
+
# cv2.imwrite("./tmp.jpg", face_img)
|
| 75 |
+
return generator.process_source(face_img)
|
| 76 |
+
|
| 77 |
+
return None
|
| 78 |
+
if uploaded_file is not None:
|
| 79 |
+
uploaded_file = process_file(uploaded_file)
|
| 80 |
+
|
| 81 |
+
def callback(frame: av.VideoFrame) -> av.VideoFrame:
|
| 82 |
+
img = frame.to_ndarray(format="bgr24")
|
| 83 |
+
|
| 84 |
+
try:
|
| 85 |
+
dets = detector(img)
|
| 86 |
+
output = None
|
| 87 |
+
for i, b in enumerate(dets):
|
| 88 |
+
text = "{:.4f}".format(b[4])
|
| 89 |
+
b = b.astype(np.int32)
|
| 90 |
+
cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
|
| 91 |
+
bbox = parse_roi_box_from_bbox(b[:4], img.shape)
|
| 92 |
+
bbox = [int(i) for i in bbox]
|
| 93 |
+
cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255, 0, 0), 2)
|
| 94 |
+
|
| 95 |
+
face_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy()
|
| 96 |
+
if uploaded_file is None:
|
| 97 |
+
source_tensor, kp_source = images[user_option]
|
| 98 |
+
else:
|
| 99 |
+
source_tensor, kp_source = uploaded_file
|
| 100 |
+
output = generator.gen_image(face_img, source_tensor, kp_source)
|
| 101 |
+
|
| 102 |
+
landm = b[5:15]
|
| 103 |
+
landm = landm.reshape((5, 2))
|
| 104 |
+
cv2.circle(img, tuple(landm[0]), 1, (0, 0, 255), 2)
|
| 105 |
+
cv2.circle(img, tuple(landm[1]), 1, (0, 255, 255), 2)
|
| 106 |
+
cv2.circle(img, tuple(landm[2]), 1, (255, 0, 255), 2)
|
| 107 |
+
cv2.circle(img, tuple(landm[3]), 1, (0, 255, 0), 2)
|
| 108 |
+
cv2.circle(img, tuple(landm[4]), 1, (255, 0, 0), 2)
|
| 109 |
+
|
| 110 |
+
if output is not None:
|
| 111 |
+
img[:256, :256] = output
|
| 112 |
+
except Exception as e:
|
| 113 |
+
print(e)
|
| 114 |
+
|
| 115 |
+
return av.VideoFrame.from_ndarray(img, format="bgr24")
|
| 116 |
+
|
| 117 |
+
webrtc_streamer(
|
| 118 |
+
key="sample",
|
| 119 |
+
rtc_configuration={"iceServers": get_ice_servers()},
|
| 120 |
+
video_frame_callback=callback,
|
| 121 |
+
media_stream_constraints={"video": True, "audio": False},
|
| 122 |
+
)
|
assets/0.jpg
ADDED
|
assets/1.jpg
ADDED
|
assets/2.jpg
ADDED
|
assets/3.jpg
ADDED
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit-webrtc
|
| 2 |
+
twilio
|
| 3 |
+
altair<5
|
| 4 |
+
numpy==1.23.1
|
| 5 |
+
opencv-python==4.8.0.74
|
| 6 |
+
imutils
|
| 7 |
+
scikit-image==0.21.0
|
| 8 |
+
matplotlib==3.7.1
|
| 9 |
+
pyaml==23.5.9
|
| 10 |
+
tqdm
|
| 11 |
+
torch
|
| 12 |
+
torchvision
|
retinaface/change_batch_onnx.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import onnx
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
model = onnx.load('weights/faceDetector_243_432_b1_sim.onnx')
|
| 5 |
+
|
| 6 |
+
# # for fixed batchsize
|
| 7 |
+
# model.graph.input[0].type.tensor_type.shape.dim[0].dim_value = 32
|
| 8 |
+
# model.graph.output[0].type.tensor_type.shape.dim[0].dim_value = 32
|
| 9 |
+
# model.graph.output[1].type.tensor_type.shape.dim[0].dim_value = 32
|
| 10 |
+
# model.graph.output[2].type.tensor_type.shape.dim[0].dim_value = 32
|
| 11 |
+
# onnx.save(model, 'faceDetector_640_b32.onnx')
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'batch' # for dynamic batchsize
|
| 15 |
+
model.graph.output[0].type.tensor_type.shape.dim[0].dim_param = 'batch'
|
| 16 |
+
model.graph.output[1].type.tensor_type.shape.dim[0].dim_param = 'batch'
|
| 17 |
+
model.graph.output[2].type.tensor_type.shape.dim[0].dim_param = 'batch'
|
| 18 |
+
model.graph.output[3].type.tensor_type.shape.dim[0].dim_param = 'batch'
|
| 19 |
+
model.graph.output[4].type.tensor_type.shape.dim[0].dim_param = 'batch'
|
| 20 |
+
model.graph.output[5].type.tensor_type.shape.dim[0].dim_param = 'batch'
|
| 21 |
+
model.graph.output[6].type.tensor_type.shape.dim[0].dim_param = 'batch'
|
| 22 |
+
model.graph.output[7].type.tensor_type.shape.dim[0].dim_param = 'batch'
|
| 23 |
+
model.graph.output[8].type.tensor_type.shape.dim[0].dim_param = 'batch'
|
| 24 |
+
onnx.save(model, 'weights/faceDetector_243_432_batch_sim.onnx')
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
####################################################
|
| 28 |
+
# SHow model onnx
|
| 29 |
+
|
| 30 |
+
# import onnxruntime as rt
|
| 31 |
+
# ort_session = rt.InferenceSession("faceDetector_180_320_batch_sim.onnx")
|
| 32 |
+
# print("====INPUT====")
|
| 33 |
+
# for i in ort_session.get_inputs():
|
| 34 |
+
# print("Name: {}, Shape: {}, Dtype: {}".format(i.name, i.shape, i.type))
|
| 35 |
+
# print("====OUTPUT====")
|
| 36 |
+
# for i in ort_session.get_outputs():
|
| 37 |
+
# print("Name: {}, Shape: {}, Dtype: {}".format(i.name, i.shape, i.type))
|
| 38 |
+
|
| 39 |
+
# import numpy as np
|
| 40 |
+
# input_name = ort_session.get_inputs()[0].name
|
| 41 |
+
# img = np.random.randn(4, 3, 180, 320).astype(np.float32)
|
| 42 |
+
# data = ort_session.run(None, {input_name: img})
|
| 43 |
+
# print("Done")
|
retinaface/convert_to_onnx.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# python convert_to_onnx.py --network mobile0.25 --trained_model weights/mobilenet0.25_Final.pth
|
| 2 |
+
from __future__ import print_function
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
import torch
|
| 6 |
+
import torch.backends.cudnn as cudnn
|
| 7 |
+
import numpy as np
|
| 8 |
+
from data import cfg_mnet, cfg_slim, cfg_rfb
|
| 9 |
+
from layers.functions.prior_box import PriorBox
|
| 10 |
+
from utils.nms.py_cpu_nms import py_cpu_nms
|
| 11 |
+
import cv2
|
| 12 |
+
from models.retinaface import RetinaFace
|
| 13 |
+
from models.net_slim import Slim
|
| 14 |
+
from models.net_rfb import RFB
|
| 15 |
+
from utils.box_utils import decode, decode_landm
|
| 16 |
+
from utils.timer import Timer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
parser = argparse.ArgumentParser(description='Test')
|
| 20 |
+
parser.add_argument('-m', '--trained_model', default='./weights/RBF_Final.pth',
|
| 21 |
+
type=str, help='Trained state_dict file path to open')
|
| 22 |
+
parser.add_argument('--network', default='RFB', help='Backbone network mobile0.25 or slim or RFB')
|
| 23 |
+
parser.add_argument('--long_side', default=320, help='when origin_size is false, long_side is scaled size(320 or 640 for long side)')
|
| 24 |
+
parser.add_argument('--cpu', action="store_true", help='Use cpu inference')
|
| 25 |
+
|
| 26 |
+
args = parser.parse_args()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def check_keys(model, pretrained_state_dict):
|
| 30 |
+
ckpt_keys = set(pretrained_state_dict.keys())
|
| 31 |
+
model_keys = set(model.state_dict().keys())
|
| 32 |
+
used_pretrained_keys = model_keys & ckpt_keys
|
| 33 |
+
unused_pretrained_keys = ckpt_keys - model_keys
|
| 34 |
+
missing_keys = model_keys - ckpt_keys
|
| 35 |
+
print('Missing keys:{}'.format(len(missing_keys)))
|
| 36 |
+
print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
|
| 37 |
+
print('Used keys:{}'.format(len(used_pretrained_keys)))
|
| 38 |
+
assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
|
| 39 |
+
return True
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def remove_prefix(state_dict, prefix):
|
| 43 |
+
''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
|
| 44 |
+
print('remove prefix \'{}\''.format(prefix))
|
| 45 |
+
f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
|
| 46 |
+
return {f(key): value for key, value in state_dict.items()}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_model(model, pretrained_path, load_to_cpu):
|
| 50 |
+
print('Loading pretrained model from {}'.format(pretrained_path))
|
| 51 |
+
if load_to_cpu:
|
| 52 |
+
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
|
| 53 |
+
else:
|
| 54 |
+
device = torch.cuda.current_device()
|
| 55 |
+
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
|
| 56 |
+
if "state_dict" in pretrained_dict.keys():
|
| 57 |
+
pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
|
| 58 |
+
else:
|
| 59 |
+
pretrained_dict = remove_prefix(pretrained_dict, 'module.')
|
| 60 |
+
check_keys(model, pretrained_dict)
|
| 61 |
+
model.load_state_dict(pretrained_dict, strict=False)
|
| 62 |
+
return model
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if __name__ == '__main__':
|
| 66 |
+
torch.set_grad_enabled(False)
|
| 67 |
+
|
| 68 |
+
cfg = None
|
| 69 |
+
net = None
|
| 70 |
+
# long_side = int(args.long_side)
|
| 71 |
+
net_inshape = (243, 432)
|
| 72 |
+
device = torch.device("cpu" if args.cpu else "cuda")
|
| 73 |
+
print(device)
|
| 74 |
+
if args.network == "mobile0.25":
|
| 75 |
+
cfg = cfg_mnet
|
| 76 |
+
# net_inshape = (long_side, long_side) # h, w
|
| 77 |
+
priorbox = PriorBox(cfg, image_size=net_inshape)
|
| 78 |
+
priors = priorbox.forward()
|
| 79 |
+
prior_data = priors.to(device)
|
| 80 |
+
net = RetinaFace(cfg=cfg, phase='test')
|
| 81 |
+
elif args.network == "slim":
|
| 82 |
+
cfg = cfg_slim
|
| 83 |
+
net = Slim(cfg = cfg, phase = 'test')
|
| 84 |
+
elif args.network == "RFB":
|
| 85 |
+
cfg = cfg_rfb
|
| 86 |
+
net = RFB(cfg = cfg, phase = 'test')
|
| 87 |
+
else:
|
| 88 |
+
print("Don't support network!")
|
| 89 |
+
exit(0)
|
| 90 |
+
|
| 91 |
+
# load weight
|
| 92 |
+
net = load_model(net, args.trained_model, args.cpu)
|
| 93 |
+
net.eval()
|
| 94 |
+
print('Finished loading model!')
|
| 95 |
+
print(net)
|
| 96 |
+
net = net.to(device)
|
| 97 |
+
|
| 98 |
+
##################export###############
|
| 99 |
+
output_onnx = f'weights/faceDetector_{net_inshape[0]}_{net_inshape[1]}_b1.onnx'
|
| 100 |
+
print("==> Exporting model to ONNX format at '{}'".format(output_onnx))
|
| 101 |
+
input_names = ['input_1']
|
| 102 |
+
output_names = ['box_1', 'box_2', 'box_3']
|
| 103 |
+
|
| 104 |
+
# import torch.onnx.symbolic_opset9 as onnx_symbolic
|
| 105 |
+
# def upsample_nearest2d(g, input, output_size, *args):
|
| 106 |
+
# # Currently, TRT 5.1/6.0/7.0 ONNX Parser does not support all ONNX ops
|
| 107 |
+
# # needed to support dynamic upsampling ONNX forumlation
|
| 108 |
+
# # Here we hardcode scale=2 as a temporary workaround
|
| 109 |
+
# scales = g.op("Constant", value_t=torch.tensor([1., 1., 2., 2.]))
|
| 110 |
+
# return g.op("Resize", input, scales, mode_s="nearest")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# onnx_symbolic.upsample_nearest2d = upsample_nearest2d
|
| 114 |
+
|
| 115 |
+
# import io
|
| 116 |
+
# onnx_bytes = io.BytesIO()
|
| 117 |
+
# zero_input = torch.zeros([1, 3, net_inshape[0], net_inshape[1]]).cuda()
|
| 118 |
+
# dynamic_axes = {input_names[0]: {0:'batch'}}
|
| 119 |
+
# for _, name in enumerate(output_names):
|
| 120 |
+
# dynamic_axes[name] = dynamic_axes[input_names[0]]
|
| 121 |
+
# extra_args = {'opset_version': 10, 'verbose': False,
|
| 122 |
+
# 'input_names': input_names, 'output_names': output_names,
|
| 123 |
+
# 'dynamic_axes': dynamic_axes}
|
| 124 |
+
# torch.onnx.export(net, zero_input, onnx_bytes, **extra_args)
|
| 125 |
+
# with open(output_onnx, 'wb') as out:
|
| 126 |
+
# out.write(onnx_bytes.getvalue())
|
| 127 |
+
|
| 128 |
+
inputs = torch.randn(1, 3, net_inshape[0], net_inshape[1]).to(device)
|
| 129 |
+
torch_out = torch.onnx._export(net, inputs, output_onnx, export_params=True, verbose=False, opset_version=9,
|
| 130 |
+
input_names=input_names, output_names=output_names)
|
| 131 |
+
################end###############
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
|
retinaface/data/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .wider_face import WiderFaceDetection, detection_collate
|
| 2 |
+
from .data_augment import *
|
| 3 |
+
from .config import *
|
retinaface/data/config.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.py
|
| 2 |
+
cfg_mnet = {
|
| 3 |
+
'name': 'mobilenet0.25',
|
| 4 |
+
'min_sizes': [[10, 20], [32, 64], [128, 256]],
|
| 5 |
+
'steps': [8, 16, 32],
|
| 6 |
+
'variance': [0.1, 0.2],
|
| 7 |
+
'clip': False,
|
| 8 |
+
'loc_weight': 2.0,
|
| 9 |
+
'gpu_train': True,
|
| 10 |
+
'batch_size': 32,
|
| 11 |
+
'ngpu': 1,
|
| 12 |
+
'epoch': 250,
|
| 13 |
+
'decay1': 190,
|
| 14 |
+
'decay2': 220,
|
| 15 |
+
'image_size': 300,
|
| 16 |
+
"net_inshape": (320, 320), # h, w
|
| 17 |
+
'pretrain': False,
|
| 18 |
+
'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
|
| 19 |
+
'in_channel': 32,
|
| 20 |
+
'out_channel': 64
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
cfg_slim = {
|
| 24 |
+
'name': 'slim',
|
| 25 |
+
'min_sizes': [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]],
|
| 26 |
+
'steps': [8, 16, 32, 64],
|
| 27 |
+
'variance': [0.1, 0.2],
|
| 28 |
+
'clip': False,
|
| 29 |
+
'loc_weight': 2.0,
|
| 30 |
+
'gpu_train': True,
|
| 31 |
+
'batch_size': 32,
|
| 32 |
+
'ngpu': 1,
|
| 33 |
+
'epoch': 250,
|
| 34 |
+
'decay1': 190,
|
| 35 |
+
'decay2': 220,
|
| 36 |
+
'image_size': 300
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
cfg_rfb = {
|
| 40 |
+
'name': 'RFB',
|
| 41 |
+
'min_sizes': [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]],
|
| 42 |
+
'steps': [8, 16, 32, 64],
|
| 43 |
+
'variance': [0.1, 0.2],
|
| 44 |
+
'clip': False,
|
| 45 |
+
'loc_weight': 2.0,
|
| 46 |
+
'gpu_train': True,
|
| 47 |
+
'batch_size': 32,
|
| 48 |
+
'ngpu': 1,
|
| 49 |
+
'epoch': 250,
|
| 50 |
+
'decay1': 190,
|
| 51 |
+
'decay2': 220,
|
| 52 |
+
'image_size': 300
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
retinaface/data/data_augment.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
from utils.box_utils import matrix_iof
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _crop(image, boxes, labels, landm, img_dim):
|
| 8 |
+
height, width, _ = image.shape
|
| 9 |
+
pad_image_flag = True
|
| 10 |
+
|
| 11 |
+
for _ in range(250):
|
| 12 |
+
if random.uniform(0, 1) <= 0.2:
|
| 13 |
+
scale = 1.0
|
| 14 |
+
else:
|
| 15 |
+
scale = random.uniform(0.3, 1.0)
|
| 16 |
+
# PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0]
|
| 17 |
+
# scale = random.choice(PRE_SCALES)
|
| 18 |
+
short_side = min(width, height)
|
| 19 |
+
w = int(scale * short_side)
|
| 20 |
+
h = w
|
| 21 |
+
|
| 22 |
+
if width == w:
|
| 23 |
+
l = 0
|
| 24 |
+
else:
|
| 25 |
+
l = random.randrange(width - w)
|
| 26 |
+
if height == h:
|
| 27 |
+
t = 0
|
| 28 |
+
else:
|
| 29 |
+
t = random.randrange(height - h)
|
| 30 |
+
roi = np.array((l, t, l + w, t + h))
|
| 31 |
+
|
| 32 |
+
value = matrix_iof(boxes, roi[np.newaxis])
|
| 33 |
+
flag = (value >= 1)
|
| 34 |
+
if not flag.any():
|
| 35 |
+
continue
|
| 36 |
+
|
| 37 |
+
centers = (boxes[:, :2] + boxes[:, 2:]) / 2
|
| 38 |
+
mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
|
| 39 |
+
boxes_t = boxes[mask_a].copy()
|
| 40 |
+
labels_t = labels[mask_a].copy()
|
| 41 |
+
landms_t = landm[mask_a].copy()
|
| 42 |
+
landms_t = landms_t.reshape([-1, 5, 2])
|
| 43 |
+
|
| 44 |
+
if boxes_t.shape[0] == 0:
|
| 45 |
+
continue
|
| 46 |
+
|
| 47 |
+
image_t = image[roi[1]:roi[3], roi[0]:roi[2]]
|
| 48 |
+
|
| 49 |
+
boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
|
| 50 |
+
boxes_t[:, :2] -= roi[:2]
|
| 51 |
+
boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
|
| 52 |
+
boxes_t[:, 2:] -= roi[:2]
|
| 53 |
+
|
| 54 |
+
# landm
|
| 55 |
+
landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2]
|
| 56 |
+
landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0]))
|
| 57 |
+
landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2])
|
| 58 |
+
landms_t = landms_t.reshape([-1, 10])
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# make sure that the cropped image contains at least one face > 16 pixel at training image scale
|
| 62 |
+
b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
|
| 63 |
+
b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
|
| 64 |
+
mask_b = np.minimum(b_w_t, b_h_t) > 5
|
| 65 |
+
boxes_t = boxes_t[mask_b]
|
| 66 |
+
labels_t = labels_t[mask_b]
|
| 67 |
+
landms_t = landms_t[mask_b]
|
| 68 |
+
|
| 69 |
+
if boxes_t.shape[0] == 0:
|
| 70 |
+
continue
|
| 71 |
+
|
| 72 |
+
pad_image_flag = False
|
| 73 |
+
|
| 74 |
+
return image_t, boxes_t, labels_t, landms_t, pad_image_flag
|
| 75 |
+
return image, boxes, labels, landm, pad_image_flag
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _distort(image):
|
| 79 |
+
|
| 80 |
+
def _convert(image, alpha=1, beta=0):
|
| 81 |
+
tmp = image.astype(float) * alpha + beta
|
| 82 |
+
tmp[tmp < 0] = 0
|
| 83 |
+
tmp[tmp > 255] = 255
|
| 84 |
+
image[:] = tmp
|
| 85 |
+
|
| 86 |
+
image = image.copy()
|
| 87 |
+
|
| 88 |
+
if random.randrange(2):
|
| 89 |
+
|
| 90 |
+
#brightness distortion
|
| 91 |
+
if random.randrange(2):
|
| 92 |
+
_convert(image, beta=random.uniform(-32, 32))
|
| 93 |
+
|
| 94 |
+
#contrast distortion
|
| 95 |
+
if random.randrange(2):
|
| 96 |
+
_convert(image, alpha=random.uniform(0.5, 1.5))
|
| 97 |
+
|
| 98 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
| 99 |
+
|
| 100 |
+
#saturation distortion
|
| 101 |
+
if random.randrange(2):
|
| 102 |
+
_convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
|
| 103 |
+
|
| 104 |
+
#hue distortion
|
| 105 |
+
if random.randrange(2):
|
| 106 |
+
tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
|
| 107 |
+
tmp %= 180
|
| 108 |
+
image[:, :, 0] = tmp
|
| 109 |
+
|
| 110 |
+
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
|
| 111 |
+
|
| 112 |
+
else:
|
| 113 |
+
|
| 114 |
+
#brightness distortion
|
| 115 |
+
if random.randrange(2):
|
| 116 |
+
_convert(image, beta=random.uniform(-32, 32))
|
| 117 |
+
|
| 118 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
| 119 |
+
|
| 120 |
+
#saturation distortion
|
| 121 |
+
if random.randrange(2):
|
| 122 |
+
_convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
|
| 123 |
+
|
| 124 |
+
#hue distortion
|
| 125 |
+
if random.randrange(2):
|
| 126 |
+
tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
|
| 127 |
+
tmp %= 180
|
| 128 |
+
image[:, :, 0] = tmp
|
| 129 |
+
|
| 130 |
+
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
|
| 131 |
+
|
| 132 |
+
#contrast distortion
|
| 133 |
+
if random.randrange(2):
|
| 134 |
+
_convert(image, alpha=random.uniform(0.5, 1.5))
|
| 135 |
+
|
| 136 |
+
return image
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _expand(image, boxes, fill, p):
|
| 140 |
+
if random.randrange(2):
|
| 141 |
+
return image, boxes
|
| 142 |
+
|
| 143 |
+
height, width, depth = image.shape
|
| 144 |
+
|
| 145 |
+
scale = random.uniform(1, p)
|
| 146 |
+
w = int(scale * width)
|
| 147 |
+
h = int(scale * height)
|
| 148 |
+
|
| 149 |
+
left = random.randint(0, w - width)
|
| 150 |
+
top = random.randint(0, h - height)
|
| 151 |
+
|
| 152 |
+
boxes_t = boxes.copy()
|
| 153 |
+
boxes_t[:, :2] += (left, top)
|
| 154 |
+
boxes_t[:, 2:] += (left, top)
|
| 155 |
+
expand_image = np.empty(
|
| 156 |
+
(h, w, depth),
|
| 157 |
+
dtype=image.dtype)
|
| 158 |
+
expand_image[:, :] = fill
|
| 159 |
+
expand_image[top:top + height, left:left + width] = image
|
| 160 |
+
image = expand_image
|
| 161 |
+
|
| 162 |
+
return image, boxes_t
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _mirror(image, boxes, landms):
|
| 166 |
+
_, width, _ = image.shape
|
| 167 |
+
if random.randrange(2):
|
| 168 |
+
image = image[:, ::-1]
|
| 169 |
+
boxes = boxes.copy()
|
| 170 |
+
boxes[:, 0::2] = width - boxes[:, 2::-2]
|
| 171 |
+
|
| 172 |
+
# landm
|
| 173 |
+
landms = landms.copy()
|
| 174 |
+
landms = landms.reshape([-1, 5, 2])
|
| 175 |
+
landms[:, :, 0] = width - landms[:, :, 0]
|
| 176 |
+
tmp = landms[:, 1, :].copy()
|
| 177 |
+
landms[:, 1, :] = landms[:, 0, :]
|
| 178 |
+
landms[:, 0, :] = tmp
|
| 179 |
+
tmp1 = landms[:, 4, :].copy()
|
| 180 |
+
landms[:, 4, :] = landms[:, 3, :]
|
| 181 |
+
landms[:, 3, :] = tmp1
|
| 182 |
+
landms = landms.reshape([-1, 10])
|
| 183 |
+
|
| 184 |
+
return image, boxes, landms
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _pad_to_square(image, rgb_mean, pad_image_flag):
|
| 188 |
+
if not pad_image_flag:
|
| 189 |
+
return image
|
| 190 |
+
height, width, _ = image.shape
|
| 191 |
+
long_side = max(width, height)
|
| 192 |
+
image_t = np.empty((long_side, long_side, 3), dtype=image.dtype)
|
| 193 |
+
image_t[:, :] = rgb_mean
|
| 194 |
+
image_t[0:0 + height, 0:0 + width] = image
|
| 195 |
+
return image_t
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _resize_subtract_mean(image, insize, rgb_mean):
|
| 199 |
+
interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
|
| 200 |
+
interp_method = interp_methods[random.randrange(5)]
|
| 201 |
+
image = cv2.resize(image, (insize, insize), interpolation=interp_method)
|
| 202 |
+
image = image.astype(np.float32)
|
| 203 |
+
image -= rgb_mean
|
| 204 |
+
return image.transpose(2, 0, 1)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class preproc(object):
|
| 208 |
+
|
| 209 |
+
def __init__(self, img_dim, rgb_means):
|
| 210 |
+
self.img_dim = img_dim
|
| 211 |
+
self.rgb_means = rgb_means
|
| 212 |
+
|
| 213 |
+
def __call__(self, image, targets):
|
| 214 |
+
assert targets.shape[0] > 0, "this image does not have gt"
|
| 215 |
+
|
| 216 |
+
boxes = targets[:, :4].copy()
|
| 217 |
+
labels = targets[:, -1].copy()
|
| 218 |
+
landm = targets[:, 4:-1].copy()
|
| 219 |
+
|
| 220 |
+
image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim)
|
| 221 |
+
image_t = _distort(image_t)
|
| 222 |
+
image_t = _pad_to_square(image_t,self.rgb_means, pad_image_flag)
|
| 223 |
+
image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t)
|
| 224 |
+
height, width, _ = image_t.shape
|
| 225 |
+
image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means)
|
| 226 |
+
boxes_t[:, 0::2] /= width
|
| 227 |
+
boxes_t[:, 1::2] /= height
|
| 228 |
+
|
| 229 |
+
landm_t[:, 0::2] /= width
|
| 230 |
+
landm_t[:, 1::2] /= height
|
| 231 |
+
|
| 232 |
+
labels_t = np.expand_dims(labels_t, 1)
|
| 233 |
+
targets_t = np.hstack((boxes_t, landm_t, labels_t))
|
| 234 |
+
|
| 235 |
+
return image_t, targets_t
|
retinaface/data/wider_face.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path
|
| 3 |
+
import sys
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils.data as data
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
class WiderFaceDetection(data.Dataset):
|
| 10 |
+
def __init__(self, txt_path, preproc=None):
|
| 11 |
+
self.preproc = preproc
|
| 12 |
+
self.imgs_path = []
|
| 13 |
+
self.words = []
|
| 14 |
+
f = open(txt_path,'r')
|
| 15 |
+
lines = f.readlines()
|
| 16 |
+
isFirst = True
|
| 17 |
+
labels = []
|
| 18 |
+
for line in lines:
|
| 19 |
+
line = line.rstrip()
|
| 20 |
+
if line.startswith('#'):
|
| 21 |
+
if isFirst is True:
|
| 22 |
+
isFirst = False
|
| 23 |
+
else:
|
| 24 |
+
labels_copy = labels.copy()
|
| 25 |
+
self.words.append(labels_copy)
|
| 26 |
+
labels.clear()
|
| 27 |
+
path = line[2:]
|
| 28 |
+
path = txt_path.replace('label.txt','images/') + path
|
| 29 |
+
self.imgs_path.append(path)
|
| 30 |
+
else:
|
| 31 |
+
line = line.split(' ')
|
| 32 |
+
label = [float(x) for x in line]
|
| 33 |
+
labels.append(label)
|
| 34 |
+
|
| 35 |
+
self.words.append(labels)
|
| 36 |
+
|
| 37 |
+
def __len__(self):
|
| 38 |
+
return len(self.imgs_path)
|
| 39 |
+
|
| 40 |
+
def __getitem__(self, index):
|
| 41 |
+
img = cv2.imread(self.imgs_path[index])
|
| 42 |
+
height, width, _ = img.shape
|
| 43 |
+
|
| 44 |
+
labels = self.words[index]
|
| 45 |
+
annotations = np.zeros((0, 15))
|
| 46 |
+
if len(labels) == 0:
|
| 47 |
+
return annotations
|
| 48 |
+
for idx, label in enumerate(labels):
|
| 49 |
+
annotation = np.zeros((1, 15))
|
| 50 |
+
# bbox
|
| 51 |
+
annotation[0, 0] = label[0] # x1
|
| 52 |
+
annotation[0, 1] = label[1] # y1
|
| 53 |
+
annotation[0, 2] = label[0] + label[2] # x2
|
| 54 |
+
annotation[0, 3] = label[1] + label[3] # y2
|
| 55 |
+
|
| 56 |
+
# landmarks
|
| 57 |
+
annotation[0, 4] = label[4] # l0_x
|
| 58 |
+
annotation[0, 5] = label[5] # l0_y
|
| 59 |
+
annotation[0, 6] = label[7] # l1_x
|
| 60 |
+
annotation[0, 7] = label[8] # l1_y
|
| 61 |
+
annotation[0, 8] = label[10] # l2_x
|
| 62 |
+
annotation[0, 9] = label[11] # l2_y
|
| 63 |
+
annotation[0, 10] = label[13] # l3_x
|
| 64 |
+
annotation[0, 11] = label[14] # l3_y
|
| 65 |
+
annotation[0, 12] = label[16] # l4_x
|
| 66 |
+
annotation[0, 13] = label[17] # l4_y
|
| 67 |
+
if (annotation[0, 4]<0):
|
| 68 |
+
annotation[0, 14] = -1
|
| 69 |
+
else:
|
| 70 |
+
annotation[0, 14] = 1
|
| 71 |
+
|
| 72 |
+
annotations = np.append(annotations, annotation, axis=0)
|
| 73 |
+
target = np.array(annotations)
|
| 74 |
+
if self.preproc is not None:
|
| 75 |
+
img, target = self.preproc(img, target)
|
| 76 |
+
|
| 77 |
+
return torch.from_numpy(img), target
|
| 78 |
+
|
| 79 |
+
def detection_collate(batch):
|
| 80 |
+
"""Custom collate fn for dealing with batches of images that have a different
|
| 81 |
+
number of associated object annotations (bounding boxes).
|
| 82 |
+
|
| 83 |
+
Arguments:
|
| 84 |
+
batch: (tuple) A tuple of tensor images and lists of annotations
|
| 85 |
+
|
| 86 |
+
Return:
|
| 87 |
+
A tuple containing:
|
| 88 |
+
1) (tensor) batch of images stacked on their 0 dim
|
| 89 |
+
2) (list of tensors) annotations for a given image are stacked on 0 dim
|
| 90 |
+
"""
|
| 91 |
+
targets = []
|
| 92 |
+
imgs = []
|
| 93 |
+
for _, sample in enumerate(batch):
|
| 94 |
+
for _, tup in enumerate(sample):
|
| 95 |
+
if torch.is_tensor(tup):
|
| 96 |
+
imgs.append(tup)
|
| 97 |
+
elif isinstance(tup, type(np.empty(0))):
|
| 98 |
+
annos = torch.from_numpy(tup).float()
|
| 99 |
+
targets.append(annos)
|
| 100 |
+
|
| 101 |
+
return (torch.stack(imgs, 0), targets)
|
retinaface/detect.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import timeit
|
| 6 |
+
|
| 7 |
+
import imutils
|
| 8 |
+
from utils.infer_utils import load_model
|
| 9 |
+
from data import cfg_mnet as cfg
|
| 10 |
+
from models.retinaface import RetinaFace
|
| 11 |
+
from layers.functions.prior_box import PriorBox
|
| 12 |
+
from utils.box_utils import decode, decode_landm
|
| 13 |
+
from utils.nms.py_cpu_nms import py_cpu_nms
|
| 14 |
+
from utils.infer_utils import align_face
|
| 15 |
+
torch.set_grad_enabled(False)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Detect:
|
| 19 |
+
def __init__(self, weight_path, net_inshape=(180, 320)):
|
| 20 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 21 |
+
self.net_inshape = net_inshape
|
| 22 |
+
im_height, im_width = net_inshape
|
| 23 |
+
self.box_scale = np.array([im_width, im_height] * 2)
|
| 24 |
+
self.lmk_scale = np.array([im_width, im_height] * 5)
|
| 25 |
+
|
| 26 |
+
priorbox = PriorBox(cfg, image_size=net_inshape)
|
| 27 |
+
priors = priorbox.forward()
|
| 28 |
+
self.prior_data = priors.to(self.device)
|
| 29 |
+
self.net = RetinaFace(cfg=cfg, phase='test')
|
| 30 |
+
self.net = load_model(self.net, weight_path, False)
|
| 31 |
+
self.net.eval()
|
| 32 |
+
self.net = self.net.to(self.device)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _preprocess(self, image):
|
| 36 |
+
rgb_mean = (104, 117, 123) # bgr order
|
| 37 |
+
h, w = image.shape[:2]
|
| 38 |
+
dx = int(self.net_inshape[1] * h / self.net_inshape[0] - w)
|
| 39 |
+
dy = 0
|
| 40 |
+
if dx < 0:
|
| 41 |
+
dx = 0
|
| 42 |
+
dy = int(self.net_inshape[0] * w / self.net_inshape[1] - h)
|
| 43 |
+
img = cv2.copyMakeBorder(image, 0, dy, 0, dx, borderType=cv2.BORDER_CONSTANT, value=rgb_mean)
|
| 44 |
+
img = cv2.copyMakeBorder(img, 0, img.shape[0], 0, img.shape[1], borderType=cv2.BORDER_CONSTANT, value=rgb_mean)
|
| 45 |
+
|
| 46 |
+
h, w = img.shape[:2]
|
| 47 |
+
resize = float(self.net_inshape[1]) / float(w)
|
| 48 |
+
img = cv2.resize(img, self.net_inshape[::-1])
|
| 49 |
+
img = np.float32(img)
|
| 50 |
+
img -= rgb_mean
|
| 51 |
+
|
| 52 |
+
return img, resize
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def __call__(self, img, verbose=False):
|
| 56 |
+
'''
|
| 57 |
+
bgr image
|
| 58 |
+
'''
|
| 59 |
+
t0 = timeit.default_timer()
|
| 60 |
+
img, resize = self._preprocess(img)
|
| 61 |
+
img = img.transpose(2, 0, 1)
|
| 62 |
+
img = torch.from_numpy(img).unsqueeze(0)
|
| 63 |
+
img = img.to(self.device)
|
| 64 |
+
|
| 65 |
+
t1 = timeit.default_timer()
|
| 66 |
+
loc, conf, landms = self.net(img)
|
| 67 |
+
loc = [i.permute(0,2,3,1).contiguous().view(i.shape[0], -1, 4) for i in loc]
|
| 68 |
+
loc = torch.cat(loc, dim=1)
|
| 69 |
+
conf = [i.permute(0,2,3,1).contiguous().view(i.shape[0], -1, 3) for i in conf]
|
| 70 |
+
conf = torch.cat(conf, dim=1)
|
| 71 |
+
landms = [i.permute(0,2,3,1).contiguous().view(i.shape[0], -1, 10) for i in landms]
|
| 72 |
+
landms = torch.cat(landms, dim=1)
|
| 73 |
+
conf = F.softmax(conf, dim=-1)
|
| 74 |
+
|
| 75 |
+
t2 = timeit.default_timer()
|
| 76 |
+
conf = conf[0]
|
| 77 |
+
scores = conf.squeeze(0).detach().cpu().numpy()[:, 1:]
|
| 78 |
+
scores = np.amax(scores, axis=1)
|
| 79 |
+
|
| 80 |
+
boxes = decode(loc[0], self.prior_data, cfg['variance']) # loc[0]
|
| 81 |
+
boxes = boxes.detach().cpu().numpy()
|
| 82 |
+
boxes = boxes * self.box_scale / resize
|
| 83 |
+
|
| 84 |
+
landms = decode_landm(landms[0], self.prior_data, cfg['variance'])
|
| 85 |
+
landms = landms.detach().cpu().numpy()
|
| 86 |
+
landms = landms * self.lmk_scale / resize
|
| 87 |
+
|
| 88 |
+
# ignore low scores
|
| 89 |
+
inds = np.where(scores > 0.02)[0]
|
| 90 |
+
boxes = boxes[inds]
|
| 91 |
+
landms = landms[inds]
|
| 92 |
+
scores = scores[inds]
|
| 93 |
+
|
| 94 |
+
# keep top-K before NMS
|
| 95 |
+
order = scores.argsort()[::-1][:5000]
|
| 96 |
+
boxes = boxes[order]
|
| 97 |
+
landms = landms[order]
|
| 98 |
+
scores = scores[order]
|
| 99 |
+
|
| 100 |
+
# do NMS
|
| 101 |
+
dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
|
| 102 |
+
keep = py_cpu_nms(dets, 0.4)
|
| 103 |
+
# keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)
|
| 104 |
+
dets = dets[keep, :]
|
| 105 |
+
landms = landms[keep]
|
| 106 |
+
|
| 107 |
+
# keep top-K faster NMS
|
| 108 |
+
dets = dets[:750, :]
|
| 109 |
+
landms = landms[:750, :]
|
| 110 |
+
|
| 111 |
+
dets = np.concatenate((dets, landms), axis=1)
|
| 112 |
+
dets = dets[dets[:, 4] > 0.5]
|
| 113 |
+
dets = dets[np.argsort(dets, axis=0)[:, 0]]
|
| 114 |
+
|
| 115 |
+
t3 = timeit.default_timer()
|
| 116 |
+
if verbose:
|
| 117 |
+
print(t1 - t0, t2 - t1, t3 - t2)
|
| 118 |
+
|
| 119 |
+
return dets # (n, 15), box=0-3, cls=4, lmk=5-10
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
if __name__ == "__main__":
|
| 123 |
+
net_inshape = (486, 864) # h, w
|
| 124 |
+
model = Detect("/mnt/nvme0n1p2/ExternalHardrive/research/object_detection/face/Face-Detector-1MB-with-landmark-clear/weights/mobilenet0.25_epoch_842.pth", net_inshape=net_inshape)
|
| 125 |
+
image_path = "/mnt/nvme0n1p2/datasets/face/dyno/mytelpay230626/mytelpay230626_raw/data_2nd/၁၀ကတန(နိုင်)၀၀၁၀၀၁/mytel_ekyc_1m2_65160cc1802b6183d87fca091cab4c2faa93a9b1614106b5911ca778_front_image.jpg"
|
| 126 |
+
img = cv2.imread(image_path)
|
| 127 |
+
|
| 128 |
+
dets = model(img)
|
| 129 |
+
for i, b in enumerate(dets):
|
| 130 |
+
text = "{:.4f}".format(b[4])
|
| 131 |
+
b = b.astype(np.int32)
|
| 132 |
+
landm = b[5:15]
|
| 133 |
+
landm = landm.reshape((5, 2))
|
| 134 |
+
|
| 135 |
+
alighed_face = align_face(img, landm.copy())
|
| 136 |
+
# cv2.imshow(str(i), alighed_face)
|
| 137 |
+
|
| 138 |
+
# landms
|
| 139 |
+
landm = landm.astype(np.int32)
|
| 140 |
+
cv2.circle(img, tuple(landm[0]), 1, (0, 0, 255), 2)
|
| 141 |
+
cv2.circle(img, tuple(landm[1]), 1, (0, 255, 255), 2)
|
| 142 |
+
cv2.circle(img, tuple(landm[2]), 1, (255, 0, 255), 2)
|
| 143 |
+
cv2.circle(img, tuple(landm[3]), 1, (0, 255, 0), 2)
|
| 144 |
+
cv2.circle(img, tuple(landm[4]), 1, (255, 0, 0), 2)
|
| 145 |
+
|
| 146 |
+
cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
|
| 147 |
+
cx = b[0]
|
| 148 |
+
cy = b[1] + 20
|
| 149 |
+
cv2.putText(img, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 1.3, (255, 255, 255))
|
| 150 |
+
|
| 151 |
+
cv2.imwrite("./output.jpg", img)
|
| 152 |
+
|
retinaface/detect_video_raw.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import imutils
|
| 4 |
+
|
| 5 |
+
from utils.fps import FPS
|
| 6 |
+
from utils.infer_utils import LoadStream, align_face
|
| 7 |
+
from detect import Detect
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
net_inshape = (486, 864) # h, w
|
| 11 |
+
model = Detect("/mnt/nvme0n1p2/ExternalHardrive/research/object_detection/face/Face-Detector-1MB-with-landmark-clear/weights/mobilenet0.25_epoch_842.pth", net_inshape=net_inshape)
|
| 12 |
+
# dataloader = LoadStream("rtsp://admin:meditech123@192.168.100.90:555/")
|
| 13 |
+
dataloader = LoadStream("../30Shine_1.mp4")
|
| 14 |
+
fps = FPS().start()
|
| 15 |
+
|
| 16 |
+
for frame in dataloader:
|
| 17 |
+
# frame = imutils.resize(frame, width=640)
|
| 18 |
+
frame = frame.copy()
|
| 19 |
+
frame_raw = frame.copy()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
dets = model(frame)
|
| 23 |
+
for i, b in enumerate(dets):
|
| 24 |
+
text = "{:.4f}".format(b[4])
|
| 25 |
+
b = b.astype(np.int32)
|
| 26 |
+
landm = b[5:15]
|
| 27 |
+
landm = landm.reshape((5, 2))
|
| 28 |
+
|
| 29 |
+
alighed_face = align_face(frame, landm.copy())
|
| 30 |
+
# cv2.imshow(str(i), alighed_face)
|
| 31 |
+
|
| 32 |
+
# landms
|
| 33 |
+
landm = landm.astype(np.int32)
|
| 34 |
+
cv2.circle(frame, tuple(landm[0]), 1, (0, 0, 255), 2)
|
| 35 |
+
cv2.circle(frame, tuple(landm[1]), 1, (0, 255, 255), 2)
|
| 36 |
+
cv2.circle(frame, tuple(landm[2]), 1, (255, 0, 255), 2)
|
| 37 |
+
cv2.circle(frame, tuple(landm[3]), 1, (0, 255, 0), 2)
|
| 38 |
+
cv2.circle(frame, tuple(landm[4]), 1, (255, 0, 0), 2)
|
| 39 |
+
|
| 40 |
+
cv2.rectangle(frame, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
|
| 41 |
+
cx = b[0]
|
| 42 |
+
cy = b[1] + 20
|
| 43 |
+
cv2.putText(frame, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 1.3, (255, 255, 255))
|
| 44 |
+
|
| 45 |
+
fps.update()
|
| 46 |
+
text_fps = "FPS: {:.3f}".format(fps.get_fps_n())
|
| 47 |
+
cv2.putText(frame, text_fps, (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
|
| 48 |
+
cv2.imshow("frame", imutils.resize(frame, width=1700))
|
| 49 |
+
key = cv2.waitKey(1) & 0xff
|
| 50 |
+
if key == ord("q"):
|
| 51 |
+
break
|
| 52 |
+
elif key == ord("c"):
|
| 53 |
+
while True:
|
| 54 |
+
cv2.imshow("frame", imutils.resize(frame, width=1700))
|
| 55 |
+
key = cv2.waitKey(1) & 0xff
|
| 56 |
+
if key == ord("q"):
|
| 57 |
+
break
|
| 58 |
+
# cv2.imwrite(f"{i}.jpg", alighed_face)
|
| 59 |
+
# i += 1
|
| 60 |
+
# # break
|
| 61 |
+
|
| 62 |
+
print(text_fps)
|
| 63 |
+
cv2.destroyAllWindows()
|
| 64 |
+
fps.stop()
|
| 65 |
+
print("Total FPS: {}".format(fps.fps()))
|
| 66 |
+
dataloader.close()
|
retinaface/layers/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .functions import *
|
| 2 |
+
from .modules import *
|
retinaface/layers/functions/prior_box.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from itertools import product as product
|
| 3 |
+
import numpy as np
|
| 4 |
+
from math import ceil
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PriorBox(object):
|
| 8 |
+
def __init__(self, cfg, image_size=None, phase='train'):
|
| 9 |
+
super(PriorBox, self).__init__()
|
| 10 |
+
self.min_sizes = cfg['min_sizes']
|
| 11 |
+
self.steps = cfg['steps']
|
| 12 |
+
self.clip = cfg['clip']
|
| 13 |
+
self.image_size = image_size
|
| 14 |
+
self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
|
| 15 |
+
|
| 16 |
+
def forward(self):
|
| 17 |
+
anchors = []
|
| 18 |
+
for k, f in enumerate(self.feature_maps):
|
| 19 |
+
min_sizes = self.min_sizes[k]
|
| 20 |
+
for i, j in product(range(f[0]), range(f[1])):
|
| 21 |
+
for min_size in min_sizes:
|
| 22 |
+
s_kx = min_size / self.image_size[1]
|
| 23 |
+
s_ky = min_size / self.image_size[0]
|
| 24 |
+
dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
|
| 25 |
+
dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
|
| 26 |
+
for cy, cx in product(dense_cy, dense_cx):
|
| 27 |
+
anchors += [cx, cy, s_kx, s_ky]
|
| 28 |
+
|
| 29 |
+
# back to torch land
|
| 30 |
+
output = torch.Tensor(anchors).view(-1, 4)
|
| 31 |
+
if self.clip:
|
| 32 |
+
output.clamp_(max=1, min=0)
|
| 33 |
+
return output
|
retinaface/layers/modules/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .multibox_loss import MultiBoxLoss
|
| 2 |
+
|
| 3 |
+
__all__ = ['MultiBoxLoss']
|
retinaface/layers/modules/multibox_loss.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.autograd import Variable
|
| 5 |
+
from utils.box_utils import match, log_sum_exp
|
| 6 |
+
from data import cfg_mnet
|
| 7 |
+
GPU = cfg_mnet['gpu_train']
|
| 8 |
+
|
| 9 |
+
class MultiBoxLoss(nn.Module):
|
| 10 |
+
"""SSD Weighted Loss Function
|
| 11 |
+
Compute Targets:
|
| 12 |
+
1) Produce Confidence Target Indices by matching ground truth boxes
|
| 13 |
+
with (default) 'priorboxes' that have jaccard index > threshold parameter
|
| 14 |
+
(default threshold: 0.5).
|
| 15 |
+
2) Produce localization target by 'encoding' variance into offsets of ground
|
| 16 |
+
truth boxes and their matched 'priorboxes'.
|
| 17 |
+
3) Hard negative mining to filter the excessive number of negative examples
|
| 18 |
+
that comes with using a large number of default bounding boxes.
|
| 19 |
+
(default negative:positive ratio 3:1)
|
| 20 |
+
Objective Loss:
|
| 21 |
+
L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
|
| 22 |
+
Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
|
| 23 |
+
weighted by α which is set to 1 by cross val.
|
| 24 |
+
Args:
|
| 25 |
+
c: class confidences,
|
| 26 |
+
l: predicted boxes,
|
| 27 |
+
g: ground truth boxes
|
| 28 |
+
N: number of matched default boxes
|
| 29 |
+
See: https://arxiv.org/pdf/1512.02325.pdf for more details.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
|
| 33 |
+
super(MultiBoxLoss, self).__init__()
|
| 34 |
+
self.num_classes = num_classes
|
| 35 |
+
self.threshold = overlap_thresh
|
| 36 |
+
self.background_label = bkg_label
|
| 37 |
+
self.encode_target = encode_target
|
| 38 |
+
self.use_prior_for_matching = prior_for_matching
|
| 39 |
+
self.do_neg_mining = neg_mining
|
| 40 |
+
self.negpos_ratio = neg_pos
|
| 41 |
+
self.neg_overlap = neg_overlap
|
| 42 |
+
self.variance = [0.1, 0.2]
|
| 43 |
+
|
| 44 |
+
def forward(self, predictions, priors, targets):
|
| 45 |
+
"""Multibox Loss
|
| 46 |
+
Args:
|
| 47 |
+
predictions (tuple): A tuple containing loc preds, conf preds,
|
| 48 |
+
and prior boxes from SSD net.
|
| 49 |
+
conf shape: torch.size(batch_size,num_priors,num_classes)
|
| 50 |
+
loc shape: torch.size(batch_size,num_priors,4)
|
| 51 |
+
priors shape: torch.size(num_priors,4)
|
| 52 |
+
|
| 53 |
+
ground_truth (tensor): Ground truth boxes and labels for a batch,
|
| 54 |
+
shape: [batch_size,num_objs,5] (last idx is the label).
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
loc_data, conf_data, landm_data = predictions
|
| 58 |
+
priors = priors
|
| 59 |
+
num = loc_data.size(0)
|
| 60 |
+
num_priors = (priors.size(0))
|
| 61 |
+
|
| 62 |
+
# match priors (default boxes) and ground truth boxes
|
| 63 |
+
loc_t = torch.Tensor(num, num_priors, 4)
|
| 64 |
+
landm_t = torch.Tensor(num, num_priors, 10)
|
| 65 |
+
conf_t = torch.LongTensor(num, num_priors)
|
| 66 |
+
for idx in range(num):
|
| 67 |
+
truths = targets[idx][:, :4].data
|
| 68 |
+
labels = targets[idx][:, -1].data
|
| 69 |
+
landms = targets[idx][:, 4:14].data
|
| 70 |
+
defaults = priors.data
|
| 71 |
+
match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
|
| 72 |
+
if GPU:
|
| 73 |
+
loc_t = loc_t.cuda()
|
| 74 |
+
conf_t = conf_t.cuda()
|
| 75 |
+
landm_t = landm_t.cuda()
|
| 76 |
+
|
| 77 |
+
zeros = torch.tensor(0).cuda()
|
| 78 |
+
# landm Loss (Smooth L1)
|
| 79 |
+
# Shape: [batch,num_priors,10]
|
| 80 |
+
pos1 = conf_t > zeros
|
| 81 |
+
num_pos_landm = pos1.long().sum(1, keepdim=True)
|
| 82 |
+
N1 = max(num_pos_landm.data.sum().float(), 1)
|
| 83 |
+
pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
|
| 84 |
+
landm_p = landm_data[pos_idx1].view(-1, 10)
|
| 85 |
+
landm_t = landm_t[pos_idx1].view(-1, 10)
|
| 86 |
+
loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
pos = conf_t != zeros
|
| 90 |
+
conf_t[pos] = 1
|
| 91 |
+
|
| 92 |
+
# Localization Loss (Smooth L1)
|
| 93 |
+
# Shape: [batch,num_priors,4]
|
| 94 |
+
pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
|
| 95 |
+
loc_p = loc_data[pos_idx].view(-1, 4)
|
| 96 |
+
loc_t = loc_t[pos_idx].view(-1, 4)
|
| 97 |
+
loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
|
| 98 |
+
|
| 99 |
+
# Compute max conf across batch for hard negative mining
|
| 100 |
+
batch_conf = conf_data.view(-1, self.num_classes)
|
| 101 |
+
loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
|
| 102 |
+
|
| 103 |
+
# Hard Negative Mining
|
| 104 |
+
loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now
|
| 105 |
+
loss_c = loss_c.view(num, -1)
|
| 106 |
+
_, loss_idx = loss_c.sort(1, descending=True)
|
| 107 |
+
_, idx_rank = loss_idx.sort(1)
|
| 108 |
+
num_pos = pos.long().sum(1, keepdim=True)
|
| 109 |
+
num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
|
| 110 |
+
neg = idx_rank < num_neg.expand_as(idx_rank)
|
| 111 |
+
|
| 112 |
+
# Confidence Loss Including Positive and Negative Examples
|
| 113 |
+
pos_idx = pos.unsqueeze(2).expand_as(conf_data)
|
| 114 |
+
neg_idx = neg.unsqueeze(2).expand_as(conf_data)
|
| 115 |
+
conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes)
|
| 116 |
+
targets_weighted = conf_t[(pos+neg).gt(0)]
|
| 117 |
+
loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
|
| 118 |
+
|
| 119 |
+
# Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
|
| 120 |
+
N = max(num_pos.data.sum().float(), 1)
|
| 121 |
+
loss_l /= N
|
| 122 |
+
loss_c /= N
|
| 123 |
+
loss_landm /= N1
|
| 124 |
+
|
| 125 |
+
return loss_l, loss_c, loss_landm
|