diff --git a/TPSMM/LICENSE b/TPSMM/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..cde95a5110a80285e68865acd7ffda00faebed8b
--- /dev/null
+++ b/TPSMM/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 yoyo-nb
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/TPSMM/README.md b/TPSMM/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..812a4d99249e970aab5321c9d542c0cc0975c544
--- /dev/null
+++ b/TPSMM/README.md
@@ -0,0 +1,98 @@
+# [CVPR2022] Thin-Plate Spline Motion Model for Image Animation
+
+[](LICENSE)
+
+
+
+Source code of the CVPR'2022 paper "Thin-Plate Spline Motion Model for Image Animation"
+
+[**Paper**](https://arxiv.org/abs/2203.14367) **|** [**Supp**](https://cloud.tsinghua.edu.cn/f/f7b8573bb5b04583949f/?dl=1)
+
+### Example animation
+
+
+
+
+**PS**: The paper trains the model for 100 epochs for a fair comparison. You can use more data and train for more epochs to get better performance.
+
+
+### Web demo for animation
+- Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [](https://huggingface.co/spaces/CVPR/Image-Animation-using-Thin-Plate-Spline-Motion-Model)
+- Try the web demo for animation here: [](https://replicate.com/yoyo-nb/thin-plate-spline-motion-model)
+- Google Colab: [](https://colab.research.google.com/drive/1DREfdpnaBhqISg0fuQlAAIwyGVn1loH_?usp=sharing)
+
+### Pre-trained models
+- [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/30ab8765da364fefa101/)
+- [Google Drive](https://drive.google.com/drive/folders/1pNDo1ODQIb5HVObRtCmubqJikmR7VVLT?usp=sharing)
+
+### Installation
+
+We support ```python3```.(Recommended version is Python 3.9).
+To install the dependencies run:
+```bash
+pip install -r requirements.txt
+```
+
+
+### YAML configs
+
+There are several configuration files one for each `dataset` in the `config` folder named as ```config/dataset_name.yaml```.
+
+See description of the parameters in the ```config/taichi-256.yaml```.
+
+### Datasets
+
+1) **MGif**. Follow [Monkey-Net](https://github.com/AliaksandrSiarohin/monkey-net).
+
+2) **TaiChiHD** and **VoxCeleb**. Follow instructions from [video-preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing).
+
+3) **TED-talks**. Follow instructions from [MRAA](https://github.com/snap-research/articulated-animation).
+
+
+### Training
+To train a model on specific dataset run:
+```
+CUDA_VISIBLE_DEVICES=0,1 python run.py --config config/dataset_name.yaml --device_ids 0,1
+```
+A log folder named after the timestamp will be created. Checkpoints, loss values, reconstruction results will be saved to this folder.
+
+
+#### Training AVD network
+To train a model on specific dataset run:
+```
+CUDA_VISIBLE_DEVICES=0 python run.py --mode train_avd --checkpoint '{checkpoint_folder}/checkpoint.pth.tar' --config config/dataset_name.yaml
+```
+Checkpoints, loss values, reconstruction results will be saved to `{checkpoint_folder}`.
+
+
+
+### Evaluation on video reconstruction
+
+To evaluate the reconstruction performance run:
+```
+CUDA_VISIBLE_DEVICES=0 python run.py --mode reconstruction --config config/dataset_name.yaml --checkpoint '{checkpoint_folder}/checkpoint.pth.tar'
+```
+The `reconstruction` subfolder will be created in `{checkpoint_folder}`.
+The generated video will be stored to this folder, also generated videos will be stored in ```png``` subfolder in loss-less '.png' format for evaluation.
+To compute metrics, follow instructions from [pose-evaluation](https://github.com/AliaksandrSiarohin/pose-evaluation).
+
+
+### Image animation demo
+- notebook: `demo.ipynb`, edit the config cell and run for image animation.
+- python:
+```bash
+CUDA_VISIBLE_DEVICES=0 python demo.py --config config/vox-256.yaml --checkpoint checkpoints/vox.pth.tar --source_image ./source.jpg --driving_video ./driving.mp4
+```
+
+# Acknowledgments
+The main code is based upon [FOMM](https://github.com/AliaksandrSiarohin/first-order-model) and [MRAA](https://github.com/snap-research/articulated-animation)
+
+Thanks for the excellent works!
+
+And Thanks to:
+
+- [@chenxwh](https://github.com/chenxwh): Add Web Demo & Docker environment [](https://replicate.com/yoyo-nb/thin-plate-spline-motion-model)
+
+- [@TalkUHulk](https://github.com/TalkUHulk): The C++/Python demo is provided in [Image-Animation-Turbo-Boost](https://github.com/TalkUHulk/Image-Animation-Turbo-Boost)
+
+- [@AK391](https://github.com/AK391): Add huggingface web demo [](https://huggingface.co/spaces/CVPR/Image-Animation-using-Thin-Plate-Spline-Motion-Model)
\ No newline at end of file
diff --git a/TPSMM/assets/source.png b/TPSMM/assets/source.png
new file mode 100644
index 0000000000000000000000000000000000000000..241ebbd038565e5a78cb5c5e738f6a8820b18c47
Binary files /dev/null and b/TPSMM/assets/source.png differ
diff --git a/TPSMM/assets/source1.png b/TPSMM/assets/source1.png
new file mode 100644
index 0000000000000000000000000000000000000000..c402a3391ecfba923ef842a8fab9097b31289992
Binary files /dev/null and b/TPSMM/assets/source1.png differ
diff --git a/TPSMM/assets/source2.jpg b/TPSMM/assets/source2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0f41c2eab8b498fb40fe81d5e2558d29b0cbf365
Binary files /dev/null and b/TPSMM/assets/source2.jpg differ
diff --git a/TPSMM/augmentation.py b/TPSMM/augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..df77004a1b7093c0992c970ed0a337b073ddfe86
--- /dev/null
+++ b/TPSMM/augmentation.py
@@ -0,0 +1,344 @@
+"""
+Code from https://github.com/hassony2/torch_videovision
+"""
+
+import numbers
+
+import random
+import numpy as np
+import PIL
+
+from skimage.transform import resize, rotate
+import torchvision
+
+import warnings
+
+from skimage import img_as_ubyte, img_as_float
+
+
+def crop_clip(clip, min_h, min_w, h, w):
+ if isinstance(clip[0], np.ndarray):
+ cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
+
+ elif isinstance(clip[0], PIL.Image.Image):
+ cropped = [
+ img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
+ ]
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ 'but got list of {0}'.format(type(clip[0])))
+ return cropped
+
+
+def pad_clip(clip, h, w):
+ im_h, im_w = clip[0].shape[:2]
+ pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
+ pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
+
+ return np.pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
+
+
+def resize_clip(clip, size, interpolation='bilinear'):
+ if isinstance(clip[0], np.ndarray):
+ if isinstance(size, numbers.Number):
+ im_h, im_w, im_c = clip[0].shape
+ # Min spatial dim already matches minimal size
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
+ and im_h == size):
+ return clip
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
+ size = (new_w, new_h)
+ else:
+ size = size[1], size[0]
+
+ scaled = [
+ resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
+ mode='constant', anti_aliasing=True) for img in clip
+ ]
+ elif isinstance(clip[0], PIL.Image.Image):
+ if isinstance(size, numbers.Number):
+ im_w, im_h = clip[0].size
+ # Min spatial dim already matches minimal size
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
+ and im_h == size):
+ return clip
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
+ size = (new_w, new_h)
+ else:
+ size = size[1], size[0]
+ if interpolation == 'bilinear':
+ pil_inter = PIL.Image.NEAREST
+ else:
+ pil_inter = PIL.Image.BILINEAR
+ scaled = [img.resize(size, pil_inter) for img in clip]
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ 'but got list of {0}'.format(type(clip[0])))
+ return scaled
+
+
+def get_resize_sizes(im_h, im_w, size):
+ if im_w < im_h:
+ ow = size
+ oh = int(size * im_h / im_w)
+ else:
+ oh = size
+ ow = int(size * im_w / im_h)
+ return oh, ow
+
+
+class RandomFlip(object):
+ def __init__(self, time_flip=False, horizontal_flip=False):
+ self.time_flip = time_flip
+ self.horizontal_flip = horizontal_flip
+
+ def __call__(self, clip):
+ if random.random() < 0.5 and self.time_flip:
+ return clip[::-1]
+ if random.random() < 0.5 and self.horizontal_flip:
+ return [np.fliplr(img) for img in clip]
+
+ return clip
+
+
+class RandomResize(object):
+ """Resizes a list of (H x W x C) numpy.ndarray to the final size
+ The larger the original image is, the more times it takes to
+ interpolate
+ Args:
+ interpolation (str): Can be one of 'nearest', 'bilinear'
+ defaults to nearest
+ size (tuple): (widht, height)
+ """
+
+ def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
+ self.ratio = ratio
+ self.interpolation = interpolation
+
+ def __call__(self, clip):
+ scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
+
+ if isinstance(clip[0], np.ndarray):
+ im_h, im_w, im_c = clip[0].shape
+ elif isinstance(clip[0], PIL.Image.Image):
+ im_w, im_h = clip[0].size
+
+ new_w = int(im_w * scaling_factor)
+ new_h = int(im_h * scaling_factor)
+ new_size = (new_w, new_h)
+ resized = resize_clip(
+ clip, new_size, interpolation=self.interpolation)
+
+ return resized
+
+
+class RandomCrop(object):
+ """Extract random crop at the same location for a list of videos
+ Args:
+ size (sequence or int): Desired output size for the
+ crop in format (h, w)
+ """
+
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ size = (size, size)
+
+ self.size = size
+
+ def __call__(self, clip):
+ """
+ Args:
+ img (PIL.Image or numpy.ndarray): List of videos to be cropped
+ in format (h, w, c) in numpy.ndarray
+ Returns:
+ PIL.Image or numpy.ndarray: Cropped list of videos
+ """
+ h, w = self.size
+ if isinstance(clip[0], np.ndarray):
+ im_h, im_w, im_c = clip[0].shape
+ elif isinstance(clip[0], PIL.Image.Image):
+ im_w, im_h = clip[0].size
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ 'but got list of {0}'.format(type(clip[0])))
+
+ clip = pad_clip(clip, h, w)
+ im_h, im_w = clip.shape[1:3]
+ x1 = 0 if h == im_h else random.randint(0, im_w - w)
+ y1 = 0 if w == im_w else random.randint(0, im_h - h)
+ cropped = crop_clip(clip, y1, x1, h, w)
+
+ return cropped
+
+
+class RandomRotation(object):
+ """Rotate entire clip randomly by a random angle within
+ given bounds
+ Args:
+ degrees (sequence or int): Range of degrees to select from
+ If degrees is a number instead of sequence like (min, max),
+ the range of degrees, will be (-degrees, +degrees).
+ """
+
+ def __init__(self, degrees):
+ if isinstance(degrees, numbers.Number):
+ if degrees < 0:
+ raise ValueError('If degrees is a single number,'
+ 'must be positive')
+ degrees = (-degrees, degrees)
+ else:
+ if len(degrees) != 2:
+ raise ValueError('If degrees is a sequence,'
+ 'it must be of len 2.')
+
+ self.degrees = degrees
+
+ def __call__(self, clip):
+ """
+ Args:
+ img (PIL.Image or numpy.ndarray): List of videos to be cropped
+ in format (h, w, c) in numpy.ndarray
+ Returns:
+ PIL.Image or numpy.ndarray: Cropped list of videos
+ """
+ angle = random.uniform(self.degrees[0], self.degrees[1])
+ if isinstance(clip[0], np.ndarray):
+ rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
+ elif isinstance(clip[0], PIL.Image.Image):
+ rotated = [img.rotate(angle) for img in clip]
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ 'but got list of {0}'.format(type(clip[0])))
+
+ return rotated
+
+
+class ColorJitter(object):
+ """Randomly change the brightness, contrast and saturation and hue of the clip
+ Args:
+ brightness (float): How much to jitter brightness. brightness_factor
+ is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
+ contrast (float): How much to jitter contrast. contrast_factor
+ is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
+ saturation (float): How much to jitter saturation. saturation_factor
+ is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
+ hue(float): How much to jitter hue. hue_factor is chosen uniformly from
+ [-hue, hue]. Should be >=0 and <= 0.5.
+ """
+
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+ self.brightness = brightness
+ self.contrast = contrast
+ self.saturation = saturation
+ self.hue = hue
+
+ def get_params(self, brightness, contrast, saturation, hue):
+ if brightness > 0:
+ brightness_factor = random.uniform(
+ max(0, 1 - brightness), 1 + brightness)
+ else:
+ brightness_factor = None
+
+ if contrast > 0:
+ contrast_factor = random.uniform(
+ max(0, 1 - contrast), 1 + contrast)
+ else:
+ contrast_factor = None
+
+ if saturation > 0:
+ saturation_factor = random.uniform(
+ max(0, 1 - saturation), 1 + saturation)
+ else:
+ saturation_factor = None
+
+ if hue > 0:
+ hue_factor = random.uniform(-hue, hue)
+ else:
+ hue_factor = None
+ return brightness_factor, contrast_factor, saturation_factor, hue_factor
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (list): list of PIL.Image
+ Returns:
+ list PIL.Image : list of transformed PIL.Image
+ """
+ if isinstance(clip[0], np.ndarray):
+ brightness, contrast, saturation, hue = self.get_params(
+ self.brightness, self.contrast, self.saturation, self.hue)
+
+ # Create img transform function sequence
+ img_transforms = []
+ if brightness is not None:
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
+ if saturation is not None:
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
+ if hue is not None:
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
+ if contrast is not None:
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
+ random.shuffle(img_transforms)
+ img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
+ img_as_float]
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ jittered_clip = []
+ for img in clip:
+ jittered_img = img
+ for func in img_transforms:
+ jittered_img = func(jittered_img)
+ jittered_clip.append(jittered_img.astype('float32'))
+ elif isinstance(clip[0], PIL.Image.Image):
+ brightness, contrast, saturation, hue = self.get_params(
+ self.brightness, self.contrast, self.saturation, self.hue)
+
+ # Create img transform function sequence
+ img_transforms = []
+ if brightness is not None:
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
+ if saturation is not None:
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
+ if hue is not None:
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
+ if contrast is not None:
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
+ random.shuffle(img_transforms)
+
+ # Apply to all videos
+ jittered_clip = []
+ for img in clip:
+ for func in img_transforms:
+ jittered_img = func(img)
+ jittered_clip.append(jittered_img)
+
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ 'but got list of {0}'.format(type(clip[0])))
+ return jittered_clip
+
+
+class AllAugmentationTransform:
+ def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
+ self.transforms = []
+
+ if flip_param is not None:
+ self.transforms.append(RandomFlip(**flip_param))
+
+ if rotation_param is not None:
+ self.transforms.append(RandomRotation(**rotation_param))
+
+ if resize_param is not None:
+ self.transforms.append(RandomResize(**resize_param))
+
+ if crop_param is not None:
+ self.transforms.append(RandomCrop(**crop_param))
+
+ if jitter_param is not None:
+ self.transforms.append(ColorJitter(**jitter_param))
+
+ def __call__(self, clip):
+ for t in self.transforms:
+ clip = t(clip)
+ return clip
diff --git a/TPSMM/cog.yaml b/TPSMM/cog.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..203a577ae7a37e77e9a349946cfa4e4fb9638590
--- /dev/null
+++ b/TPSMM/cog.yaml
@@ -0,0 +1,40 @@
+build:
+ cuda: "11.0"
+ gpu: true
+ python_version: "3.8"
+ system_packages:
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ - "ninja-build"
+ python_packages:
+ - "ipython==7.21.0"
+ - "torch==1.10.1"
+ - "torchvision==0.11.2"
+ - "cffi==1.14.6"
+ - "cycler==0.10.0"
+ - "decorator==5.1.0"
+ - "face-alignment==1.3.5"
+ - "imageio==2.9.0"
+ - "imageio-ffmpeg==0.4.5"
+ - "kiwisolver==1.3.2"
+ - "matplotlib==3.4.3"
+ - "networkx==2.6.3"
+ - "numpy==1.20.3"
+ - "pandas==1.3.3"
+ - "Pillow==8.3.2"
+ - "pycparser==2.20"
+ - "pyparsing==2.4.7"
+ - "python-dateutil==2.8.2"
+ - "pytz==2021.1"
+ - "PyWavelets==1.1.1"
+ - "PyYAML==5.4.1"
+ - "scikit-image==0.18.3"
+ - "scikit-learn==1.0"
+ - "scipy==1.7.1"
+ - "six==1.16.0"
+ - "tqdm==4.62.3"
+ - "cmake==3.21.3"
+ run:
+ - pip install dlib
+
+predict: "predict.py:Predictor"
diff --git a/TPSMM/config/mgif-256.yaml b/TPSMM/config/mgif-256.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4cd2011114f0b8ec69a76b4a210e134b48dea2cc
--- /dev/null
+++ b/TPSMM/config/mgif-256.yaml
@@ -0,0 +1,75 @@
+dataset_params:
+ root_dir: ../moving-gif
+ frame_shape: null
+ id_sampling: False
+ augmentation_params:
+ flip_param:
+ horizontal_flip: True
+ time_flip: True
+ crop_param:
+ size: [256, 256]
+ resize_param:
+ ratio: [0.9, 1.1]
+ jitter_param:
+ hue: 0.5
+
+model_params:
+ common_params:
+ num_tps: 10
+ num_channels: 3
+ bg: False
+ multi_mask: True
+ generator_params:
+ block_expansion: 64
+ max_features: 512
+ num_down_blocks: 3
+ dense_motion_params:
+ block_expansion: 64
+ max_features: 1024
+ num_blocks: 5
+ scale_factor: 0.25
+ avd_network_params:
+ id_bottle_size: 128
+ pose_bottle_size: 128
+
+
+train_params:
+ num_epochs: 100
+ num_repeats: 50
+ epoch_milestones: [70, 90]
+ lr_generator: 2.0e-4
+ batch_size: 28
+ scales: [1, 0.5, 0.25, 0.125]
+ dataloader_workers: 12
+ checkpoint_freq: 50
+ dropout_epoch: 35
+ dropout_maxp: 0.5
+ dropout_startp: 0.2
+ dropout_inc_epoch: 10
+ bg_start: 0
+ transform_params:
+ sigma_affine: 0.05
+ sigma_tps: 0.005
+ points_tps: 5
+ loss_weights:
+ perceptual: [10, 10, 10, 10, 10]
+ equivariance_value: 10
+ warp_loss: 10
+ bg: 10
+
+train_avd_params:
+ num_epochs: 100
+ num_repeats: 50
+ batch_size: 256
+ dataloader_workers: 24
+ checkpoint_freq: 10
+ epoch_milestones: [70, 90]
+ lr: 1.0e-3
+ lambda_shift: 1
+ lambda_affine: 1
+ random_scale: 0.25
+
+visualizer_params:
+ kp_size: 5
+ draw_border: True
+ colormap: 'gist_rainbow'
\ No newline at end of file
diff --git a/TPSMM/config/taichi-256.yaml b/TPSMM/config/taichi-256.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2eb1d3a74634b5d1fc78cc5d6d368c7528d45750
--- /dev/null
+++ b/TPSMM/config/taichi-256.yaml
@@ -0,0 +1,134 @@
+# Dataset parameters
+# Each dataset should contain 2 folders train and test
+# Each video can be represented as:
+# - an image of concatenated frames
+# - '.mp4' or '.gif'
+# - folder with all frames from a specific video
+# In case of Taichi. Same (youtube) video can be splitted in many parts (chunks). Each part has a following
+# format (id)#other#info.mp4. For example '12335#adsbf.mp4' has an id 12335. In case of TaiChi id stands for youtube
+# video id.
+dataset_params:
+ # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
+ root_dir: ../taichi
+ # Image shape, needed for staked .png format.
+ frame_shape: null
+ # In case of TaiChi single video can be splitted in many chunks, or the maybe several videos for single person.
+ # In this case epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False)
+ # If the name of the video '12335#adsbf.mp4' the id is assumed to be 12335
+ id_sampling: True
+ # Augmentation parameters see augmentation.py for all posible augmentations
+ augmentation_params:
+ flip_param:
+ horizontal_flip: True
+ time_flip: True
+ jitter_param:
+ brightness: 0.1
+ contrast: 0.1
+ saturation: 0.1
+ hue: 0.1
+
+# Defines model architecture
+model_params:
+ common_params:
+ # Number of TPS transformation
+ num_tps: 10
+ # Number of channels per image
+ num_channels: 3
+ # Whether to estimate affine background transformation
+ bg: True
+ # Whether to estimate the multi-resolution occlusion masks
+ multi_mask: True
+ generator_params:
+ # Number of features mutliplier
+ block_expansion: 64
+ # Maximum allowed number of features
+ max_features: 512
+ # Number of downsampling blocks and Upsampling blocks.
+ num_down_blocks: 3
+ dense_motion_params:
+ # Number of features mutliplier
+ block_expansion: 64
+ # Maximum allowed number of features
+ max_features: 1024
+ # Number of block in Unet.
+ num_blocks: 5
+ # Optical flow is predicted on smaller images for better performance,
+ # scale_factor=0.25 means that 256x256 image will be resized to 64x64
+ scale_factor: 0.25
+ avd_network_params:
+ # Bottleneck for identity branch
+ id_bottle_size: 128
+ # Bottleneck for pose branch
+ pose_bottle_size: 128
+
+# Parameters of training
+train_params:
+ # Number of training epochs
+ num_epochs: 100
+ # For better i/o performance when number of videos is small number of epochs can be multiplied by this number.
+ # Thus effectivlly with num_repeats=100 each epoch is 100 times larger.
+ num_repeats: 150
+ # Drop learning rate by 10 times after this epochs
+ epoch_milestones: [70, 90]
+ # Initial learing rate for all modules
+ lr_generator: 2.0e-4
+ batch_size: 28
+ # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,
+ # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32.
+ scales: [1, 0.5, 0.25, 0.125]
+ # Dataset preprocessing cpu workers
+ dataloader_workers: 12
+ # Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs.
+ checkpoint_freq: 50
+ # Parameters of dropout
+ # The first dropout_epoch training uses dropout operation
+ dropout_epoch: 35
+ # The probability P will linearly increase from dropout_startp to dropout_maxp in dropout_inc_epoch epochs
+ dropout_maxp: 0.7
+ dropout_startp: 0.0
+ dropout_inc_epoch: 10
+ # Estimate affine background transformation from the bg_start epoch.
+ bg_start: 0
+ # Parameters of random TPS transformation for equivariance loss
+ transform_params:
+ # Sigma for affine part
+ sigma_affine: 0.05
+ # Sigma for deformation part
+ sigma_tps: 0.005
+ # Number of point in the deformation grid
+ points_tps: 5
+ loss_weights:
+ # Weights for perceptual loss.
+ perceptual: [10, 10, 10, 10, 10]
+ # Weights for value equivariance.
+ equivariance_value: 10
+ # Weights for warp loss.
+ warp_loss: 10
+ # Weights for bg loss.
+ bg: 10
+
+# Parameters of training (animation-via-disentanglement)
+train_avd_params:
+ # Number of training epochs, visualization is produced after each epoch.
+ num_epochs: 100
+ # For better i/o performance when number of videos is small number of epochs can be multiplied by this number.
+ # Thus effectively with num_repeats=100 each epoch is 100 times larger.
+ num_repeats: 150
+ # Batch size.
+ batch_size: 256
+ # Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs.
+ checkpoint_freq: 10
+ # Dataset preprocessing cpu workers
+ dataloader_workers: 24
+ # Drop learning rate 10 times after this epochs
+ epoch_milestones: [70, 90]
+ # Initial learning rate
+ lr: 1.0e-3
+ # Weights for equivariance loss.
+ lambda_shift: 1
+ random_scale: 0.25
+
+visualizer_params:
+ kp_size: 5
+ draw_border: True
+ colormap: 'gist_rainbow'
\ No newline at end of file
diff --git a/TPSMM/config/ted-384.yaml b/TPSMM/config/ted-384.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..007b9126823229c70459f51b173773f66f9f2be5
--- /dev/null
+++ b/TPSMM/config/ted-384.yaml
@@ -0,0 +1,73 @@
+dataset_params:
+ root_dir: ../TED384-v2
+ frame_shape: null
+ id_sampling: True
+ augmentation_params:
+ flip_param:
+ horizontal_flip: True
+ time_flip: True
+ jitter_param:
+ brightness: 0.1
+ contrast: 0.1
+ saturation: 0.1
+ hue: 0.1
+
+model_params:
+ common_params:
+ num_tps: 10
+ num_channels: 3
+ bg: True
+ multi_mask: True
+ generator_params:
+ block_expansion: 64
+ max_features: 512
+ num_down_blocks: 3
+ dense_motion_params:
+ block_expansion: 64
+ max_features: 1024
+ num_blocks: 5
+ scale_factor: 0.25
+ avd_network_params:
+ id_bottle_size: 128
+ pose_bottle_size: 128
+
+
+train_params:
+ num_epochs: 100
+ num_repeats: 150
+ epoch_milestones: [70, 90]
+ lr_generator: 2.0e-4
+ batch_size: 12
+ scales: [1, 0.5, 0.25, 0.125]
+ dataloader_workers: 6
+ checkpoint_freq: 50
+ dropout_epoch: 35
+ dropout_maxp: 0.5
+ dropout_startp: 0.0
+ dropout_inc_epoch: 10
+ bg_start: 0
+ transform_params:
+ sigma_affine: 0.05
+ sigma_tps: 0.005
+ points_tps: 5
+ loss_weights:
+ perceptual: [10, 10, 10, 10, 10]
+ equivariance_value: 10
+ warp_loss: 10
+ bg: 10
+
+train_avd_params:
+ num_epochs: 30
+ num_repeats: 500
+ batch_size: 256
+ dataloader_workers: 24
+ checkpoint_freq: 10
+ epoch_milestones: [20, 25]
+ lr: 1.0e-3
+ lambda_shift: 1
+ random_scale: 0.25
+
+visualizer_params:
+ kp_size: 5
+ draw_border: True
+ colormap: 'gist_rainbow'
diff --git a/TPSMM/config/vox-256.yaml b/TPSMM/config/vox-256.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..658a8163e92574f2b52cc2bbc6886fb1603d7c55
--- /dev/null
+++ b/TPSMM/config/vox-256.yaml
@@ -0,0 +1,74 @@
+dataset_params:
+ root_dir: ../vox
+ frame_shape: null
+ id_sampling: True
+ augmentation_params:
+ flip_param:
+ horizontal_flip: True
+ time_flip: True
+ jitter_param:
+ brightness: 0.1
+ contrast: 0.1
+ saturation: 0.1
+ hue: 0.1
+
+
+model_params:
+ common_params:
+ num_tps: 10
+ num_channels: 3
+ bg: True
+ multi_mask: True
+ generator_params:
+ block_expansion: 64
+ max_features: 512
+ num_down_blocks: 3
+ dense_motion_params:
+ block_expansion: 64
+ max_features: 1024
+ num_blocks: 5
+ scale_factor: 0.25
+ avd_network_params:
+ id_bottle_size: 128
+ pose_bottle_size: 128
+
+
+train_params:
+ num_epochs: 100
+ num_repeats: 75
+ epoch_milestones: [70, 90]
+ lr_generator: 2.0e-4
+ batch_size: 28
+ scales: [1, 0.5, 0.25, 0.125]
+ dataloader_workers: 12
+ checkpoint_freq: 50
+ dropout_epoch: 35
+ dropout_maxp: 0.3
+ dropout_startp: 0.1
+ dropout_inc_epoch: 10
+ bg_start: 10
+ transform_params:
+ sigma_affine: 0.05
+ sigma_tps: 0.005
+ points_tps: 5
+ loss_weights:
+ perceptual: [10, 10, 10, 10, 10]
+ equivariance_value: 10
+ warp_loss: 10
+ bg: 10
+
+train_avd_params:
+ num_epochs: 200
+ num_repeats: 300
+ batch_size: 256
+ dataloader_workers: 24
+ checkpoint_freq: 50
+ epoch_milestones: [140, 180]
+ lr: 1.0e-3
+ lambda_shift: 1
+ random_scale: 0.25
+
+visualizer_params:
+ kp_size: 5
+ draw_border: True
+ colormap: 'gist_rainbow'
\ No newline at end of file
diff --git a/TPSMM/demo.ipynb b/TPSMM/demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..8bfccd9bfa013a06a2301373684ed82ead07246f
--- /dev/null
+++ b/TPSMM/demo.ipynb
@@ -0,0 +1,5113 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Config**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "# edit the config\n",
+ "device = torch.device('cuda:0')\n",
+ "dataset_name = 'vox' # ['vox', 'taichi', 'ted', 'mgif']\n",
+ "source_image_path = './assets/source.png'\n",
+ "driving_video_path = './assets/driving.mp4'\n",
+ "output_video_path = './generated.mp4'\n",
+ "config_path = 'config/vox-256.yaml'\n",
+ "checkpoint_path = 'checkpoints/vox.pth.tar'\n",
+ "predict_mode = 'relative' # ['standard', 'relative', 'avd']\n",
+ "find_best_frame = False # when use the relative mode to animate a face, use 'find_best_frame=True' can get better quality result\n",
+ "\n",
+ "pixel = 256 # for vox, taichi and mgif, the resolution is 256*256\n",
+ "if(dataset_name == 'ted'): # for ted, the resolution is 384*384\n",
+ " pixel = 384\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Read image and video**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 453
+ },
+ "id": "Oxi6-riLOgnm",
+ "outputId": "d38a8850-9eb1-4de4-9bf2-24cbd847ca1f"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import imageio\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "import matplotlib.animation as animation\n",
+ "from skimage.transform import resize\n",
+ "from IPython.display import HTML\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "source_image = imageio.imread(source_image_path)\n",
+ "reader = imageio.get_reader(driving_video_path)\n",
+ "\n",
+ "\n",
+ "source_image = resize(source_image, (pixel, pixel))[..., :3]\n",
+ "\n",
+ "fps = reader.get_meta_data()['fps']\n",
+ "driving_video = []\n",
+ "try:\n",
+ " for im in reader:\n",
+ " driving_video.append(im)\n",
+ "except RuntimeError:\n",
+ " pass\n",
+ "reader.close()\n",
+ "\n",
+ "driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]\n",
+ "\n",
+ "def display(source, driving, generated=None):\n",
+ " fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))\n",
+ "\n",
+ " ims = []\n",
+ " for i in range(len(driving)):\n",
+ " cols = [source]\n",
+ " cols.append(driving[i])\n",
+ " if generated is not None:\n",
+ " cols.append(generated[i])\n",
+ " im = plt.imshow(np.concatenate(cols, axis=1), animated=True)\n",
+ " plt.axis('off')\n",
+ " ims.append([im])\n",
+ "\n",
+ " ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)\n",
+ " plt.close()\n",
+ " return ani\n",
+ " \n",
+ "\n",
+ "HTML(display(source_image, driving_video).to_html5_video())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xjM7ubVfWrwT"
+ },
+ "source": [
+ "**Create a model and load checkpoints**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "3FQiXqQPWt5B"
+ },
+ "outputs": [],
+ "source": [
+ "from demo import load_checkpoints\n",
+ "inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = config_path, checkpoint_path = checkpoint_path, device = device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fdFdasHEj3t7"
+ },
+ "source": [
+ "**Perform image animation**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 471
+ },
+ "id": "SB12II11kF4c",
+ "outputId": "9e2274aa-fd55-4eed-cb50-bec72fcfb8b9"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 169/169 [00:10<00:00, 15.69it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from demo import make_animation\n",
+ "from skimage import img_as_ubyte\n",
+ "\n",
+ "if predict_mode=='relative' and find_best_frame:\n",
+ " from demo import find_best_frame as _find\n",
+ " i = _find(source_image, driving_video, device.type=='cpu')\n",
+ " print (\"Best frame: \" + str(i))\n",
+ " driving_forward = driving_video[i:]\n",
+ " driving_backward = driving_video[:(i+1)][::-1]\n",
+ " predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n",
+ " predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n",
+ " predictions = predictions_backward[::-1] + predictions_forward[1:]\n",
+ "else:\n",
+ " predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n",
+ "\n",
+ "#save resulting video\n",
+ "imageio.mimsave(output_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)\n",
+ "\n",
+ "HTML(display(source_image, driving_video, predictions).to_html5_video())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "name": "first-order-model-demo.ipynb",
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/TPSMM/demo.py b/TPSMM/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..011092e02b9a9221444f244a9e4a3fedcffa4cbf
--- /dev/null
+++ b/TPSMM/demo.py
@@ -0,0 +1,180 @@
+import matplotlib
+matplotlib.use('Agg')
+import sys
+import yaml
+from argparse import ArgumentParser
+from tqdm import tqdm
+from scipy.spatial import ConvexHull
+import numpy as np
+import imageio
+from skimage.transform import resize
+from skimage import img_as_ubyte
+import torch
+from modules.inpainting_network import InpaintingNetwork
+from modules.keypoint_detector import KPDetector
+from modules.dense_motion import DenseMotionNetwork
+from modules.avd_network import AVDNetwork
+
+if sys.version_info[0] < 3:
+ raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")
+
+def relative_kp(kp_source, kp_driving, kp_driving_initial):
+
+ source_area = ConvexHull(kp_source['fg_kp'][0].data.cpu().numpy()).volume
+ driving_area = ConvexHull(kp_driving_initial['fg_kp'][0].data.cpu().numpy()).volume
+ adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
+
+ kp_new = {k: v for k, v in kp_driving.items()}
+
+ kp_value_diff = (kp_driving['fg_kp'] - kp_driving_initial['fg_kp'])
+ kp_value_diff *= adapt_movement_scale
+ kp_new['fg_kp'] = kp_value_diff + kp_source['fg_kp']
+
+ return kp_new
+
+def load_checkpoints(config_path, checkpoint_path, device):
+ with open(config_path) as f:
+ config = yaml.full_load(f)
+
+ inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
+ **config['model_params']['common_params'])
+ kp_detector = KPDetector(**config['model_params']['common_params'])
+ dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
+ **config['model_params']['dense_motion_params'])
+ avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
+ **config['model_params']['avd_network_params'])
+ kp_detector.to(device)
+ dense_motion_network.to(device)
+ inpainting.to(device)
+ avd_network.to(device)
+
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+
+ inpainting.load_state_dict(checkpoint['inpainting_network'])
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
+ dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
+ if 'avd_network' in checkpoint:
+ avd_network.load_state_dict(checkpoint['avd_network'])
+
+ inpainting.eval()
+ kp_detector.eval()
+ dense_motion_network.eval()
+ avd_network.eval()
+
+ return inpainting, kp_detector, dense_motion_network, avd_network
+
+
+def make_animation(source_image, driving_video, inpainting_network, kp_detector, dense_motion_network, avd_network, device, mode = 'relative'):
+ assert mode in ['standard', 'relative', 'avd']
+ with torch.no_grad():
+ predictions = []
+ source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
+ source = source.to(device)
+ driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device)
+ kp_source = kp_detector(source)
+ kp_driving_initial = kp_detector(driving[:, :, 0])
+
+ for frame_idx in tqdm(range(driving.shape[2])):
+ driving_frame = driving[:, :, frame_idx]
+ driving_frame = driving_frame.to(device)
+ kp_driving = kp_detector(driving_frame)
+ if mode == 'standard':
+ kp_norm = kp_driving
+ elif mode=='relative':
+ kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving,
+ kp_driving_initial=kp_driving_initial)
+ elif mode == 'avd':
+ kp_norm = avd_network(kp_source, kp_driving)
+ dense_motion = dense_motion_network(source_image=source, kp_driving=kp_norm,
+ kp_source=kp_source, bg_param = None,
+ dropout_flag = False)
+ out = inpainting_network(source, dense_motion)
+
+ predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
+ return predictions
+
+
+def find_best_frame(source, driving, cpu):
+ import face_alignment
+
+ def normalize_kp(kp):
+ kp = kp - kp.mean(axis=0, keepdims=True)
+ area = ConvexHull(kp[:, :2]).volume
+ area = np.sqrt(area)
+ kp[:, :2] = kp[:, :2] / area
+ return kp
+
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
+ device= 'cpu' if cpu else 'cuda')
+ kp_source = fa.get_landmarks(255 * source)[0]
+ kp_source = normalize_kp(kp_source)
+ norm = float('inf')
+ frame_num = 0
+ for i, image in tqdm(enumerate(driving)):
+ try:
+ kp_driving = fa.get_landmarks(255 * image)[0]
+ kp_driving = normalize_kp(kp_driving)
+ new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
+ if new_norm < norm:
+ norm = new_norm
+ frame_num = i
+ except:
+ pass
+ return frame_num
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument("--config", required=True, help="path to config")
+ parser.add_argument("--checkpoint", default='checkpoints/vox.pth.tar', help="path to checkpoint to restore")
+
+ parser.add_argument("--source_image", default='./assets/source.png', help="path to source image")
+ parser.add_argument("--driving_video", default='./assets/driving.mp4', help="path to driving video")
+ parser.add_argument("--result_video", default='./result.mp4', help="path to output")
+
+ parser.add_argument("--img_shape", default="256,256", type=lambda x: list(map(int, x.split(','))),
+ help='Shape of image, that the model was trained on.')
+
+ parser.add_argument("--mode", default='relative', choices=['standard', 'relative', 'avd'], help="Animate mode: ['standard', 'relative', 'avd'], when use the relative mode to animate a face, use '--find_best_frame' can get better quality result")
+
+ parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
+ help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)")
+
+ parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
+
+ opt = parser.parse_args()
+
+ source_image = imageio.imread(opt.source_image)
+ reader = imageio.get_reader(opt.driving_video)
+ fps = reader.get_meta_data()['fps']
+ driving_video = []
+ try:
+ for im in reader:
+ driving_video.append(im)
+ except RuntimeError:
+ pass
+ reader.close()
+
+ if opt.cpu:
+ device = torch.device('cpu')
+ else:
+ device = torch.device('cuda')
+
+ source_image = resize(source_image, opt.img_shape)[..., :3]
+ driving_video = [resize(frame, opt.img_shape)[..., :3] for frame in driving_video]
+ inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = opt.config, checkpoint_path = opt.checkpoint, device = device)
+
+ if opt.find_best_frame:
+ i = find_best_frame(source_image, driving_video, opt.cpu)
+ print ("Best frame: " + str(i))
+ driving_forward = driving_video[i:]
+ driving_backward = driving_video[:(i+1)][::-1]
+ predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = opt.mode)
+ predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = opt.mode)
+ predictions = predictions_backward[::-1] + predictions_forward[1:]
+ else:
+ predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = opt.mode)
+
+ predictions =[img_as_ubyte(frame) for frame in predictions]
+ imageio.mimsave(opt.result_video, predictions, fps=fps)
+
diff --git a/TPSMM/frames_dataset.py b/TPSMM/frames_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2328c6a0af028518fe5412ebc75b2b8a019d77c8
--- /dev/null
+++ b/TPSMM/frames_dataset.py
@@ -0,0 +1,173 @@
+import os
+from skimage import io, img_as_float32
+from skimage.color import gray2rgb
+from sklearn.model_selection import train_test_split
+from imageio import mimread
+from skimage.transform import resize
+import numpy as np
+from torch.utils.data import Dataset
+from augmentation import AllAugmentationTransform
+import glob
+from functools import partial
+
+
+def read_video(name, frame_shape):
+ """
+ Read video which can be:
+ - an image of concatenated frames
+ - '.mp4' and'.gif'
+ - folder with videos
+ """
+
+ if os.path.isdir(name):
+ frames = sorted(os.listdir(name))
+ num_frames = len(frames)
+ video_array = np.array(
+ [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
+ elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
+ image = io.imread(name)
+
+ if len(image.shape) == 2 or image.shape[2] == 1:
+ image = gray2rgb(image)
+
+ if image.shape[2] == 4:
+ image = image[..., :3]
+
+ image = img_as_float32(image)
+
+ video_array = np.moveaxis(image, 1, 0)
+
+ video_array = video_array.reshape((-1,) + frame_shape)
+ video_array = np.moveaxis(video_array, 1, 2)
+ elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
+ video = mimread(name)
+ if len(video[0].shape) == 2:
+ video = [gray2rgb(frame) for frame in video]
+ if frame_shape is not None:
+ video = np.array([resize(frame, frame_shape) for frame in video])
+ video = np.array(video)
+ if video.shape[-1] == 4:
+ video = video[..., :3]
+ video_array = img_as_float32(video)
+ else:
+ raise Exception("Unknown file extensions %s" % name)
+
+ return video_array
+
+
+class FramesDataset(Dataset):
+ """
+ Dataset of videos, each video can be represented as:
+ - an image of concatenated frames
+ - '.mp4' or '.gif'
+ - folder with all frames
+ """
+
+ def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
+ random_seed=0, pairs_list=None, augmentation_params=None):
+ self.root_dir = root_dir
+ self.videos = os.listdir(root_dir)
+ self.frame_shape = frame_shape
+ print(self.frame_shape)
+ self.pairs_list = pairs_list
+ self.id_sampling = id_sampling
+
+ if os.path.exists(os.path.join(root_dir, 'train')):
+ assert os.path.exists(os.path.join(root_dir, 'test'))
+ print("Use predefined train-test split.")
+ if id_sampling:
+ train_videos = {os.path.basename(video).split('#')[0] for video in
+ os.listdir(os.path.join(root_dir, 'train'))}
+ train_videos = list(train_videos)
+ else:
+ train_videos = os.listdir(os.path.join(root_dir, 'train'))
+ test_videos = os.listdir(os.path.join(root_dir, 'test'))
+ self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
+ else:
+ print("Use random train-test split.")
+ train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
+
+ if is_train:
+ self.videos = train_videos
+ else:
+ self.videos = test_videos
+
+ self.is_train = is_train
+
+ if self.is_train:
+ self.transform = AllAugmentationTransform(**augmentation_params)
+ else:
+ self.transform = None
+
+ def __len__(self):
+ return len(self.videos)
+
+ def __getitem__(self, idx):
+
+ if self.is_train and self.id_sampling:
+ name = self.videos[idx]
+ path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
+ else:
+ name = self.videos[idx]
+ path = os.path.join(self.root_dir, name)
+
+ video_name = os.path.basename(path)
+ if self.is_train and os.path.isdir(path):
+
+ frames = os.listdir(path)
+ num_frames = len(frames)
+ frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
+
+ if self.frame_shape is not None:
+ resize_fn = partial(resize, output_shape=self.frame_shape)
+ else:
+ resize_fn = img_as_float32
+
+ if type(frames[0]) is bytes:
+ video_array = [resize_fn(io.imread(os.path.join(path, frames[idx].decode('utf-8')))) for idx in
+ frame_idx]
+ else:
+ video_array = [resize_fn(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
+ else:
+
+ video_array = read_video(path, frame_shape=self.frame_shape)
+
+ num_frames = len(video_array)
+ frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
+ num_frames)
+ video_array = video_array[frame_idx]
+
+
+ if self.transform is not None:
+ video_array = self.transform(video_array)
+
+ out = {}
+ if self.is_train:
+ source = np.array(video_array[0], dtype='float32')
+ driving = np.array(video_array[1], dtype='float32')
+
+ out['driving'] = driving.transpose((2, 0, 1))
+ out['source'] = source.transpose((2, 0, 1))
+ else:
+ video = np.array(video_array, dtype='float32')
+ out['video'] = video.transpose((3, 0, 1, 2))
+
+ out['name'] = video_name
+ return out
+
+
+class DatasetRepeater(Dataset):
+ """
+ Pass several times over the same dataset for better i/o performance
+ """
+
+ def __init__(self, dataset, num_repeats=100):
+ self.dataset = dataset
+ self.num_repeats = num_repeats
+
+ def __len__(self):
+ return self.num_repeats * self.dataset.__len__()
+
+ def __getitem__(self, idx):
+ return self.dataset[idx % self.dataset.__len__()]
+
diff --git a/TPSMM/logger.py b/TPSMM/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..825fbe7b20e023db239ead5e4adc06368f78f76d
--- /dev/null
+++ b/TPSMM/logger.py
@@ -0,0 +1,212 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+import imageio
+
+import os
+from skimage.draw import circle
+
+import matplotlib.pyplot as plt
+import collections
+
+
+class Logger:
+ def __init__(self, log_dir, checkpoint_freq=50, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
+
+ self.loss_list = []
+ self.cpk_dir = log_dir
+ self.visualizations_dir = os.path.join(log_dir, 'train-vis')
+ if not os.path.exists(self.visualizations_dir):
+ os.makedirs(self.visualizations_dir)
+ self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
+ self.zfill_num = zfill_num
+ self.visualizer = Visualizer(**visualizer_params)
+ self.checkpoint_freq = checkpoint_freq
+ self.epoch = 0
+ self.best_loss = float('inf')
+ self.names = None
+
+ def log_scores(self, loss_names):
+ loss_mean = np.array(self.loss_list).mean(axis=0)
+
+ loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
+ loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
+
+ print(loss_string, file=self.log_file)
+ self.loss_list = []
+ self.log_file.flush()
+
+ def visualize_rec(self, inp, out):
+ image = self.visualizer.visualize(inp['driving'], inp['source'], out)
+ imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
+
+ def save_cpk(self, emergent=False):
+ cpk = {k: v.state_dict() for k, v in self.models.items()}
+ cpk['epoch'] = self.epoch
+ cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
+ if not (os.path.exists(cpk_path) and emergent):
+ torch.save(cpk, cpk_path)
+
+ @staticmethod
+ def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network =None, kp_detector=None,
+ bg_predictor=None, avd_network=None, optimizer=None, optimizer_bg_predictor=None,
+ optimizer_avd=None):
+ checkpoint = torch.load(checkpoint_path)
+ if inpainting_network is not None:
+ inpainting_network.load_state_dict(checkpoint['inpainting_network'])
+ if kp_detector is not None:
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
+ if bg_predictor is not None and 'bg_predictor' in checkpoint:
+ bg_predictor.load_state_dict(checkpoint['bg_predictor'])
+ if dense_motion_network is not None:
+ dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
+ if avd_network is not None:
+ if 'avd_network' in checkpoint:
+ avd_network.load_state_dict(checkpoint['avd_network'])
+ if optimizer_bg_predictor is not None and 'optimizer_bg_predictor' in checkpoint:
+ optimizer_bg_predictor.load_state_dict(checkpoint['optimizer_bg_predictor'])
+ if optimizer is not None and 'optimizer' in checkpoint:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ if optimizer_avd is not None:
+ if 'optimizer_avd' in checkpoint:
+ optimizer_avd.load_state_dict(checkpoint['optimizer_avd'])
+ epoch = -1
+ if 'epoch' in checkpoint:
+ epoch = checkpoint['epoch']
+ return epoch
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ if 'models' in self.__dict__:
+ self.save_cpk()
+ self.log_file.close()
+
+ def log_iter(self, losses):
+ losses = collections.OrderedDict(losses.items())
+ self.names = list(losses.keys())
+ self.loss_list.append(list(losses.values()))
+
+ def log_epoch(self, epoch, models, inp, out):
+ self.epoch = epoch
+ self.models = models
+ if (self.epoch + 1) % self.checkpoint_freq == 0:
+ self.save_cpk()
+ self.log_scores(self.names)
+ self.visualize_rec(inp, out)
+
+
+class Visualizer:
+ def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
+ self.kp_size = kp_size
+ self.draw_border = draw_border
+ self.colormap = plt.get_cmap(colormap)
+
+ def draw_image_with_kp(self, image, kp_array):
+ image = np.copy(image)
+ spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
+ kp_array = spatial_size * (kp_array + 1) / 2
+ num_kp = kp_array.shape[0]
+ for kp_ind, kp in enumerate(kp_array):
+ rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
+ image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
+ return image
+
+ def create_image_column_with_kp(self, images, kp):
+ image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
+ return self.create_image_column(image_array)
+
+ def create_image_column(self, images):
+ if self.draw_border:
+ images = np.copy(images)
+ images[:, :, [0, -1]] = (1, 1, 1)
+ images[:, :, [0, -1]] = (1, 1, 1)
+ return np.concatenate(list(images), axis=0)
+
+ def create_image_grid(self, *args):
+ out = []
+ for arg in args:
+ if type(arg) == tuple:
+ out.append(self.create_image_column_with_kp(arg[0], arg[1]))
+ else:
+ out.append(self.create_image_column(arg))
+ return np.concatenate(out, axis=1)
+
+ def visualize(self, driving, source, out):
+ images = []
+
+ # Source image with keypoints
+ source = source.data.cpu()
+ kp_source = out['kp_source']['fg_kp'].data.cpu().numpy()
+ source = np.transpose(source, [0, 2, 3, 1])
+ images.append((source, kp_source))
+
+ # Equivariance visualization
+ if 'transformed_frame' in out:
+ transformed = out['transformed_frame'].data.cpu().numpy()
+ transformed = np.transpose(transformed, [0, 2, 3, 1])
+ transformed_kp = out['transformed_kp']['fg_kp'].data.cpu().numpy()
+ images.append((transformed, transformed_kp))
+
+ # Driving image with keypoints
+ kp_driving = out['kp_driving']['fg_kp'].data.cpu().numpy()
+ driving = driving.data.cpu().numpy()
+ driving = np.transpose(driving, [0, 2, 3, 1])
+ images.append((driving, kp_driving))
+
+ # Deformed image
+ if 'deformed' in out:
+ deformed = out['deformed'].data.cpu().numpy()
+ deformed = np.transpose(deformed, [0, 2, 3, 1])
+ images.append(deformed)
+
+ # Result with and without keypoints
+ prediction = out['prediction'].data.cpu().numpy()
+ prediction = np.transpose(prediction, [0, 2, 3, 1])
+ if 'kp_norm' in out:
+ kp_norm = out['kp_norm']['fg_kp'].data.cpu().numpy()
+ images.append((prediction, kp_norm))
+ images.append(prediction)
+
+
+ ## Occlusion map
+ if 'occlusion_map' in out:
+ for i in range(len(out['occlusion_map'])):
+ occlusion_map = out['occlusion_map'][i].data.cpu().repeat(1, 3, 1, 1)
+ occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
+ occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
+ images.append(occlusion_map)
+
+ # Deformed images according to each individual transform
+ if 'deformed_source' in out:
+ full_mask = []
+ for i in range(out['deformed_source'].shape[1]):
+ image = out['deformed_source'][:, i].data.cpu()
+ # import ipdb;ipdb.set_trace()
+ image = F.interpolate(image, size=source.shape[1:3])
+ mask = out['contribution_maps'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
+ mask = F.interpolate(mask, size=source.shape[1:3])
+ image = np.transpose(image.numpy(), (0, 2, 3, 1))
+ mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
+
+ if i != 0:
+ color = np.array(self.colormap((i - 1) / (out['deformed_source'].shape[1] - 1)))[:3]
+ else:
+ color = np.array((0, 0, 0))
+
+ color = color.reshape((1, 1, 1, 3))
+
+ images.append(image)
+ if i != 0:
+ images.append(mask * color)
+ else:
+ images.append(mask)
+
+ full_mask.append(mask * color)
+
+ images.append(sum(full_mask))
+
+ image = self.create_image_grid(*images)
+ image = (255 * image).astype(np.uint8)
+ return image
diff --git a/TPSMM/modules/avd_network.py b/TPSMM/modules/avd_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..e62937ebc7d00a09f0e10ab9abda038cdaeaaf54
--- /dev/null
+++ b/TPSMM/modules/avd_network.py
@@ -0,0 +1,65 @@
+
+import torch
+from torch import nn
+
+
+class AVDNetwork(nn.Module):
+ """
+ Animation via Disentanglement network
+ """
+
+ def __init__(self, num_tps, id_bottle_size=64, pose_bottle_size=64):
+ super(AVDNetwork, self).__init__()
+ input_size = 5*2 * num_tps
+ self.num_tps = num_tps
+
+ self.id_encoder = nn.Sequential(
+ nn.Linear(input_size, 256),
+ nn.BatchNorm1d(256),
+ nn.ReLU(inplace=True),
+ nn.Linear(256, 512),
+ nn.BatchNorm1d(512),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, 1024),
+ nn.BatchNorm1d(1024),
+ nn.ReLU(inplace=True),
+ nn.Linear(1024, id_bottle_size)
+ )
+
+ self.pose_encoder = nn.Sequential(
+ nn.Linear(input_size, 256),
+ nn.BatchNorm1d(256),
+ nn.ReLU(inplace=True),
+ nn.Linear(256, 512),
+ nn.BatchNorm1d(512),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, 1024),
+ nn.BatchNorm1d(1024),
+ nn.ReLU(inplace=True),
+ nn.Linear(1024, pose_bottle_size)
+ )
+
+ self.decoder = nn.Sequential(
+ nn.Linear(pose_bottle_size + id_bottle_size, 1024),
+ nn.BatchNorm1d(1024),
+ nn.ReLU(),
+ nn.Linear(1024, 512),
+ nn.BatchNorm1d(512),
+ nn.ReLU(),
+ nn.Linear(512, 256),
+ nn.BatchNorm1d(256),
+ nn.ReLU(),
+ nn.Linear(256, input_size)
+ )
+
+ def forward(self, kp_source, kp_random):
+
+ bs = kp_source['fg_kp'].shape[0]
+
+ pose_emb = self.pose_encoder(kp_random['fg_kp'].view(bs, -1))
+ id_emb = self.id_encoder(kp_source['fg_kp'].view(bs, -1))
+
+ rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1))
+
+ rec = {'fg_kp': rec.view(bs, self.num_tps*5, -1)}
+ return rec
diff --git a/TPSMM/modules/bg_motion_predictor.py b/TPSMM/modules/bg_motion_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..875f4bc5a153545bfa116e0cf42807ce18e01bc2
--- /dev/null
+++ b/TPSMM/modules/bg_motion_predictor.py
@@ -0,0 +1,24 @@
+from torch import nn
+import torch
+from torchvision import models
+
+class BGMotionPredictor(nn.Module):
+ """
+ Module for background estimation, return single transformation, parametrized as 3x3 matrix. The third row is [0 0 1]
+ """
+
+ def __init__(self):
+ super(BGMotionPredictor, self).__init__()
+ self.bg_encoder = models.resnet18(pretrained=False)
+ self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
+ num_features = self.bg_encoder.fc.in_features
+ self.bg_encoder.fc = nn.Linear(num_features, 6)
+ self.bg_encoder.fc.weight.data.zero_()
+ self.bg_encoder.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
+
+ def forward(self, source_image, driving_image):
+ bs = source_image.shape[0]
+ out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type())
+ prediction = self.bg_encoder(torch.cat([source_image, driving_image], dim=1))
+ out[:, :2, :] = prediction.view(bs, 2, 3)
+ return out
diff --git a/TPSMM/modules/dense_motion.py b/TPSMM/modules/dense_motion.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced509a9cc9d435c523b496367862fa0f5974f59
--- /dev/null
+++ b/TPSMM/modules/dense_motion.py
@@ -0,0 +1,164 @@
+from torch import nn
+import torch.nn.functional as F
+import torch
+from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
+from modules.util import to_homogeneous, from_homogeneous, UpBlock2d, TPS
+import math
+
+class DenseMotionNetwork(nn.Module):
+ """
+ Module that estimating an optical flow and multi-resolution occlusion masks
+ from K TPS transformations and an affine transformation.
+ """
+
+ def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_channels,
+ scale_factor=0.25, bg = False, multi_mask = True, kp_variance=0.01):
+ super(DenseMotionNetwork, self).__init__()
+
+ if scale_factor != 1:
+ self.down = AntiAliasInterpolation2d(num_channels, scale_factor)
+ self.scale_factor = scale_factor
+ self.multi_mask = multi_mask
+
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_channels * (num_tps+1) + num_tps*5+1),
+ max_features=max_features, num_blocks=num_blocks)
+
+ hourglass_output_size = self.hourglass.out_channels
+ self.maps = nn.Conv2d(hourglass_output_size[-1], num_tps + 1, kernel_size=(7, 7), padding=(3, 3))
+
+ if multi_mask:
+ up = []
+ self.up_nums = int(math.log(1/scale_factor, 2))
+ self.occlusion_num = 4
+
+ channel = [hourglass_output_size[-1]//(2**i) for i in range(self.up_nums)]
+ for i in range(self.up_nums):
+ up.append(UpBlock2d(channel[i], channel[i]//2, kernel_size=3, padding=1))
+ self.up = nn.ModuleList(up)
+
+ channel = [hourglass_output_size[-i-1] for i in range(self.occlusion_num-self.up_nums)[::-1]]
+ for i in range(self.up_nums):
+ channel.append(hourglass_output_size[-1]//(2**(i+1)))
+ occlusion = []
+
+ for i in range(self.occlusion_num):
+ occlusion.append(nn.Conv2d(channel[i], 1, kernel_size=(7, 7), padding=(3, 3)))
+ self.occlusion = nn.ModuleList(occlusion)
+ else:
+ occlusion = [nn.Conv2d(hourglass_output_size[-1], 1, kernel_size=(7, 7), padding=(3, 3))]
+ self.occlusion = nn.ModuleList(occlusion)
+
+ self.num_tps = num_tps
+ self.bg = bg
+ self.kp_variance = kp_variance
+
+
+ def create_heatmap_representations(self, source_image, kp_driving, kp_source):
+
+ spatial_size = source_image.shape[2:]
+ gaussian_driving = kp2gaussian(kp_driving['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
+ gaussian_source = kp2gaussian(kp_source['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
+ heatmap = gaussian_driving - gaussian_source
+
+ zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()).to(heatmap.device)
+ heatmap = torch.cat([zeros, heatmap], dim=1)
+
+ return heatmap
+
+ def create_transformations(self, source_image, kp_driving, kp_source, bg_param):
+ # K TPS transformaions
+ bs, _, h, w = source_image.shape
+ kp_1 = kp_driving['fg_kp']
+ kp_2 = kp_source['fg_kp']
+ kp_1 = kp_1.view(bs, -1, 5, 2)
+ kp_2 = kp_2.view(bs, -1, 5, 2)
+ trans = TPS(mode = 'kp', bs = bs, kp_1 = kp_1, kp_2 = kp_2)
+ driving_to_source = trans.transform_frame(source_image)
+
+ identity_grid = make_coordinate_grid((h, w), type=kp_1.type()).to(kp_1.device)
+ identity_grid = identity_grid.view(1, 1, h, w, 2)
+ identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
+
+ # affine background transformation
+ if not (bg_param is None):
+ identity_grid = to_homogeneous(identity_grid)
+ identity_grid = torch.matmul(bg_param.view(bs, 1, 1, 1, 3, 3), identity_grid.unsqueeze(-1)).squeeze(-1)
+ identity_grid = from_homogeneous(identity_grid)
+
+ transformations = torch.cat([identity_grid, driving_to_source], dim=1)
+ return transformations
+
+ def create_deformed_source_image(self, source_image, transformations):
+
+ bs, _, h, w = source_image.shape
+ source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_tps + 1, 1, 1, 1, 1)
+ source_repeat = source_repeat.view(bs * (self.num_tps + 1), -1, h, w)
+ transformations = transformations.view((bs * (self.num_tps + 1), h, w, -1))
+ deformed = F.grid_sample(source_repeat, transformations, align_corners=True)
+ deformed = deformed.view((bs, self.num_tps+1, -1, h, w))
+ return deformed
+
+ def dropout_softmax(self, X, P):
+ '''
+ Dropout for TPS transformations. Eq(7) and Eq(8) in the paper.
+ '''
+ drop = (torch.rand(X.shape[0],X.shape[1]) < (1-P)).type(X.type()).to(X.device)
+ drop[..., 0] = 1
+ drop = drop.repeat(X.shape[2],X.shape[3],1,1).permute(2,3,0,1)
+
+ maxx = X.max(1).values.unsqueeze_(1)
+ X = X - maxx
+ X_exp = X.exp()
+ X[:,1:,...] /= (1-P)
+ mask_bool =(drop == 0)
+ X_exp = X_exp.masked_fill(mask_bool, 0)
+ partition = X_exp.sum(dim=1, keepdim=True) + 1e-6
+ return X_exp / partition
+
+ def forward(self, source_image, kp_driving, kp_source, bg_param = None, dropout_flag=False, dropout_p = 0):
+ if self.scale_factor != 1:
+ source_image = self.down(source_image)
+
+ bs, _, h, w = source_image.shape
+
+ out_dict = dict()
+ heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
+ transformations = self.create_transformations(source_image, kp_driving, kp_source, bg_param)
+ deformed_source = self.create_deformed_source_image(source_image, transformations)
+ out_dict['deformed_source'] = deformed_source
+ # out_dict['transformations'] = transformations
+ deformed_source = deformed_source.view(bs,-1,h,w)
+ input = torch.cat([heatmap_representation, deformed_source], dim=1)
+ input = input.view(bs, -1, h, w)
+
+ prediction = self.hourglass(input, mode = 1)
+
+ contribution_maps = self.maps(prediction[-1])
+ if(dropout_flag):
+ contribution_maps = self.dropout_softmax(contribution_maps, dropout_p)
+ else:
+ contribution_maps = F.softmax(contribution_maps, dim=1)
+ out_dict['contribution_maps'] = contribution_maps
+
+ # Combine the K+1 transformations
+ # Eq(6) in the paper
+ contribution_maps = contribution_maps.unsqueeze(2)
+ transformations = transformations.permute(0, 1, 4, 2, 3)
+ deformation = (transformations * contribution_maps).sum(dim=1)
+ deformation = deformation.permute(0, 2, 3, 1)
+
+ out_dict['deformation'] = deformation # Optical Flow
+
+ occlusion_map = []
+ if self.multi_mask:
+ for i in range(self.occlusion_num-self.up_nums):
+ occlusion_map.append(torch.sigmoid(self.occlusion[i](prediction[self.up_nums-self.occlusion_num+i])))
+ prediction = prediction[-1]
+ for i in range(self.up_nums):
+ prediction = self.up[i](prediction)
+ occlusion_map.append(torch.sigmoid(self.occlusion[i+self.occlusion_num-self.up_nums](prediction)))
+ else:
+ occlusion_map.append(torch.sigmoid(self.occlusion[0](prediction[-1])))
+
+ out_dict['occlusion_map'] = occlusion_map # Multi-resolution Occlusion Masks
+ return out_dict
diff --git a/TPSMM/modules/inpainting_network.py b/TPSMM/modules/inpainting_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b873bd21e2868b1959be8b92dee8b361ecbd6d8
--- /dev/null
+++ b/TPSMM/modules/inpainting_network.py
@@ -0,0 +1,127 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
+from modules.dense_motion import DenseMotionNetwork
+
+
+class InpaintingNetwork(nn.Module):
+ """
+ Inpaint the missing regions and reconstruct the Driving image.
+ """
+ def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, multi_mask = True, **kwargs):
+ super(InpaintingNetwork, self).__init__()
+
+ self.num_down_blocks = num_down_blocks
+ self.multi_mask = multi_mask
+ self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
+
+ down_blocks = []
+ up_blocks = []
+ resblock = []
+ for i in range(num_down_blocks):
+ in_features = min(max_features, block_expansion * (2 ** i))
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ decoder_in_feature = out_features * 2
+ if i==num_down_blocks-1:
+ decoder_in_feature = out_features
+ up_blocks.append(UpBlock2d(decoder_in_feature, in_features, kernel_size=(3, 3), padding=(1, 1)))
+ resblock.append(ResBlock2d(decoder_in_feature, kernel_size=(3, 3), padding=(1, 1)))
+ resblock.append(ResBlock2d(decoder_in_feature, kernel_size=(3, 3), padding=(1, 1)))
+ self.down_blocks = nn.ModuleList(down_blocks)
+ self.up_blocks = nn.ModuleList(up_blocks[::-1])
+ self.resblock = nn.ModuleList(resblock[::-1])
+
+ self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
+ self.num_channels = num_channels
+
+ def deform_input(self, inp, deformation):
+ _, h_old, w_old, _ = deformation.shape
+ _, _, h, w = inp.shape
+ if h_old != h or w_old != w:
+ deformation = deformation.permute(0, 3, 1, 2)
+ deformation = F.interpolate(deformation, size=(h, w), mode='bilinear', align_corners=True)
+ deformation = deformation.permute(0, 2, 3, 1)
+ return F.grid_sample(inp, deformation,align_corners=True)
+
+ def occlude_input(self, inp, occlusion_map):
+ if not self.multi_mask:
+ if inp.shape[2] != occlusion_map.shape[2] or inp.shape[3] != occlusion_map.shape[3]:
+ occlusion_map = F.interpolate(occlusion_map, size=inp.shape[2:], mode='bilinear',align_corners=True)
+ out = inp * occlusion_map
+ return out
+
+ def forward(self, source_image, dense_motion):
+ out = self.first(source_image)
+ encoder_map = [out]
+ for i in range(len(self.down_blocks)):
+ out = self.down_blocks[i](out)
+ encoder_map.append(out)
+
+ output_dict = {}
+ output_dict['contribution_maps'] = dense_motion['contribution_maps']
+ output_dict['deformed_source'] = dense_motion['deformed_source']
+
+ occlusion_map = dense_motion['occlusion_map']
+ output_dict['occlusion_map'] = occlusion_map
+
+ deformation = dense_motion['deformation']
+ out_ij = self.deform_input(out.detach(), deformation)
+ out = self.deform_input(out, deformation)
+
+ out_ij = self.occlude_input(out_ij, occlusion_map[0].detach())
+ out = self.occlude_input(out, occlusion_map[0])
+
+ warped_encoder_maps = []
+ warped_encoder_maps.append(out_ij)
+
+ for i in range(self.num_down_blocks):
+
+ out = self.resblock[2*i](out)
+ out = self.resblock[2*i+1](out)
+ out = self.up_blocks[i](out)
+
+ encode_i = encoder_map[-(i+2)]
+ encode_ij = self.deform_input(encode_i.detach(), deformation)
+ encode_i = self.deform_input(encode_i, deformation)
+
+ occlusion_ind = 0
+ if self.multi_mask:
+ occlusion_ind = i+1
+ encode_ij = self.occlude_input(encode_ij, occlusion_map[occlusion_ind].detach())
+ encode_i = self.occlude_input(encode_i, occlusion_map[occlusion_ind])
+ warped_encoder_maps.append(encode_ij)
+
+ if(i==self.num_down_blocks-1):
+ break
+
+ out = torch.cat([out, encode_i], 1)
+
+ deformed_source = self.deform_input(source_image, deformation)
+ output_dict["deformed"] = deformed_source
+ output_dict["warped_encoder_maps"] = warped_encoder_maps
+
+ occlusion_last = occlusion_map[-1]
+ if not self.multi_mask:
+ occlusion_last = F.interpolate(occlusion_last, size=out.shape[2:], mode='bilinear',align_corners=True)
+
+ out = out * (1 - occlusion_last) + encode_i
+ out = self.final(out)
+ out = torch.sigmoid(out)
+ out = out * (1 - occlusion_last) + deformed_source * occlusion_last
+ output_dict["prediction"] = out
+
+ return output_dict
+
+ def get_encode(self, driver_image, occlusion_map):
+ out = self.first(driver_image)
+ encoder_map = []
+ encoder_map.append(self.occlude_input(out.detach(), occlusion_map[-1].detach()))
+ for i in range(len(self.down_blocks)):
+ out = self.down_blocks[i](out.detach())
+ out_mask = self.occlude_input(out.detach(), occlusion_map[2-i].detach())
+ encoder_map.append(out_mask.detach())
+
+ return encoder_map
+
diff --git a/TPSMM/modules/keypoint_detector.py b/TPSMM/modules/keypoint_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..a39a19458c75449c65d3e7810974eededb9d2d67
--- /dev/null
+++ b/TPSMM/modules/keypoint_detector.py
@@ -0,0 +1,27 @@
+from torch import nn
+import torch
+from torchvision import models
+
+class KPDetector(nn.Module):
+ """
+ Predict K*5 keypoints.
+ """
+
+ def __init__(self, num_tps, **kwargs):
+ super(KPDetector, self).__init__()
+ self.num_tps = num_tps
+
+ self.fg_encoder = models.resnet18(pretrained=False)
+ num_features = self.fg_encoder.fc.in_features
+ self.fg_encoder.fc = nn.Linear(num_features, num_tps*5*2)
+
+
+ def forward(self, image):
+
+ fg_kp = self.fg_encoder(image)
+ bs, _, = fg_kp.shape
+ fg_kp = torch.sigmoid(fg_kp)
+ fg_kp = fg_kp * 2 - 1
+ out = {'fg_kp': fg_kp.view(bs, self.num_tps*5, -1)}
+
+ return out
diff --git a/TPSMM/modules/model.py b/TPSMM/modules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8df242b3f33e7c5e01d5e47628a81eb94a3f1964
--- /dev/null
+++ b/TPSMM/modules/model.py
@@ -0,0 +1,182 @@
+from torch import nn
+import torch
+import torch.nn.functional as F
+from modules.util import AntiAliasInterpolation2d, TPS
+from torchvision import models
+import numpy as np
+
+
+class Vgg19(torch.nn.Module):
+ """
+ Vgg19 network for perceptual loss. See Sec 3.3.
+ """
+ def __init__(self, requires_grad=False):
+ super(Vgg19, self).__init__()
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ for x in range(2):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(2, 7):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(7, 12):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(12, 21):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(21, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+
+ self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
+ requires_grad=False)
+ self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
+ requires_grad=False)
+
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ X = (X - self.mean) / self.std
+ h_relu1 = self.slice1(X)
+ h_relu2 = self.slice2(h_relu1)
+ h_relu3 = self.slice3(h_relu2)
+ h_relu4 = self.slice4(h_relu3)
+ h_relu5 = self.slice5(h_relu4)
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
+ return out
+
+
+class ImagePyramide(torch.nn.Module):
+ """
+ Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
+ """
+ def __init__(self, scales, num_channels):
+ super(ImagePyramide, self).__init__()
+ downs = {}
+ for scale in scales:
+ downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
+ self.downs = nn.ModuleDict(downs)
+
+ def forward(self, x):
+ out_dict = {}
+ for scale, down_module in self.downs.items():
+ out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
+ return out_dict
+
+
+def detach_kp(kp):
+ return {key: value.detach() for key, value in kp.items()}
+
+
+class GeneratorFullModel(torch.nn.Module):
+ """
+ Merge all generator related updates into single model for better multi-gpu usage
+ """
+
+ def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, *kwargs):
+ super(GeneratorFullModel, self).__init__()
+ self.kp_extractor = kp_extractor
+ self.inpainting_network = inpainting_network
+ self.dense_motion_network = dense_motion_network
+
+ self.bg_predictor = None
+ if bg_predictor:
+ self.bg_predictor = bg_predictor
+ self.bg_start = train_params['bg_start']
+
+ self.train_params = train_params
+ self.scales = train_params['scales']
+
+ self.pyramid = ImagePyramide(self.scales, inpainting_network.num_channels)
+ if torch.cuda.is_available():
+ self.pyramid = self.pyramid.cuda()
+
+ self.loss_weights = train_params['loss_weights']
+ self.dropout_epoch = train_params['dropout_epoch']
+ self.dropout_maxp = train_params['dropout_maxp']
+ self.dropout_inc_epoch = train_params['dropout_inc_epoch']
+ self.dropout_startp =train_params['dropout_startp']
+
+ if sum(self.loss_weights['perceptual']) != 0:
+ self.vgg = Vgg19()
+ if torch.cuda.is_available():
+ self.vgg = self.vgg.cuda()
+
+
+ def forward(self, x, epoch):
+ kp_source = self.kp_extractor(x['source'])
+ kp_driving = self.kp_extractor(x['driving'])
+ bg_param = None
+ if self.bg_predictor:
+ if(epoch>=self.bg_start):
+ bg_param = self.bg_predictor(x['source'], x['driving'])
+
+ if(epoch>=self.dropout_epoch):
+ dropout_flag = False
+ dropout_p = 0
+ else:
+ # dropout_p will linearly increase from dropout_startp to dropout_maxp
+ dropout_flag = True
+ dropout_p = min(epoch/self.dropout_inc_epoch * self.dropout_maxp + self.dropout_startp, self.dropout_maxp)
+
+ dense_motion = self.dense_motion_network(source_image=x['source'], kp_driving=kp_driving,
+ kp_source=kp_source, bg_param = bg_param,
+ dropout_flag = dropout_flag, dropout_p = dropout_p)
+ generated = self.inpainting_network(x['source'], dense_motion)
+ generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
+
+ loss_values = {}
+
+ pyramide_real = self.pyramid(x['driving'])
+ pyramide_generated = self.pyramid(generated['prediction'])
+
+ # reconstruction loss
+ if sum(self.loss_weights['perceptual']) != 0:
+ value_total = 0
+ for scale in self.scales:
+ x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
+ y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
+
+ for i, weight in enumerate(self.loss_weights['perceptual']):
+ value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
+ value_total += self.loss_weights['perceptual'][i] * value
+ loss_values['perceptual'] = value_total
+
+ # equivariance loss
+ if self.loss_weights['equivariance_value'] != 0:
+ transform_random = TPS(mode = 'random', bs = x['driving'].shape[0], **self.train_params['transform_params'])
+ transform_grid = transform_random.transform_frame(x['driving'])
+ transformed_frame = F.grid_sample(x['driving'], transform_grid, padding_mode="reflection",align_corners=True)
+ transformed_kp = self.kp_extractor(transformed_frame)
+
+ generated['transformed_frame'] = transformed_frame
+ generated['transformed_kp'] = transformed_kp
+
+ warped = transform_random.warp_coordinates(transformed_kp['fg_kp'])
+ kp_d = kp_driving['fg_kp']
+ value = torch.abs(kp_d - warped).mean()
+ loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
+
+ # warp loss
+ if self.loss_weights['warp_loss'] != 0:
+ occlusion_map = generated['occlusion_map']
+ encode_map = self.inpainting_network.get_encode(x['driving'], occlusion_map)
+ decode_map = generated['warped_encoder_maps']
+ value = 0
+ for i in range(len(encode_map)):
+ value += torch.abs(encode_map[i]-decode_map[-i-1]).mean()
+
+ loss_values['warp_loss'] = self.loss_weights['warp_loss'] * value
+
+ # bg loss
+ if self.bg_predictor and epoch >= self.bg_start and self.loss_weights['bg'] != 0:
+ bg_param_reverse = self.bg_predictor(x['driving'], x['source'])
+ value = torch.matmul(bg_param, bg_param_reverse)
+ eye = torch.eye(3).view(1, 1, 3, 3).type(value.type())
+ value = torch.abs(eye - value).mean()
+ loss_values['bg'] = self.loss_weights['bg'] * value
+
+ return loss_values, generated
diff --git a/TPSMM/modules/util.py b/TPSMM/modules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a869916831d6804624009282196f6b2391dc280
--- /dev/null
+++ b/TPSMM/modules/util.py
@@ -0,0 +1,349 @@
+from torch import nn
+import torch.nn.functional as F
+import torch
+
+
+class TPS:
+ '''
+ TPS transformation, mode 'kp' for Eq(2) in the paper, mode 'random' for equivariance loss.
+ '''
+ def __init__(self, mode, bs, **kwargs):
+ self.bs = bs
+ self.mode = mode
+ if mode == 'random':
+ noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
+ self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
+ self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
+ self.control_points = self.control_points.unsqueeze(0)
+ self.control_params = torch.normal(mean=0,
+ std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
+ elif mode == 'kp':
+ kp_1 = kwargs["kp_1"]
+ kp_2 = kwargs["kp_2"]
+ device = kp_1.device
+ kp_type = kp_1.type()
+ self.gs = kp_1.shape[1]
+ n = kp_1.shape[2]
+ K = torch.norm(kp_1[:,:,:, None]-kp_1[:,:, None, :], dim=4, p=2)
+ K = K**2
+ K = K * torch.log(K+1e-9)
+
+ one1 = torch.ones(self.bs, kp_1.shape[1], kp_1.shape[2], 1).to(device).type(kp_type)
+ kp_1p = torch.cat([kp_1,one1], 3)
+
+ zero = torch.zeros(self.bs, kp_1.shape[1], 3, 3).to(device).type(kp_type)
+ P = torch.cat([kp_1p, zero],2)
+ L = torch.cat([K,kp_1p.permute(0,1,3,2)],2)
+ L = torch.cat([L,P],3)
+
+ zero = torch.zeros(self.bs, kp_1.shape[1], 3, 2).to(device).type(kp_type)
+ Y = torch.cat([kp_2, zero], 2)
+ one = torch.eye(L.shape[2]).expand(L.shape).to(device).type(kp_type)*0.01
+ L = L + one
+
+ param = torch.matmul(torch.inverse(L),Y)
+ self.theta = param[:,:,n:,:].permute(0,1,3,2)
+
+ self.control_points = kp_1
+ self.control_params = param[:,:,:n,:]
+ else:
+ raise Exception("Error TPS mode")
+
+ def transform_frame(self, frame):
+ grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0).to(frame.device)
+ grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
+ shape = [self.bs, frame.shape[2], frame.shape[3], 2]
+ if self.mode == 'kp':
+ shape.insert(1, self.gs)
+ grid = self.warp_coordinates(grid).view(*shape)
+ return grid
+
+ def warp_coordinates(self, coordinates):
+ theta = self.theta.type(coordinates.type()).to(coordinates.device)
+ control_points = self.control_points.type(coordinates.type()).to(coordinates.device)
+ control_params = self.control_params.type(coordinates.type()).to(coordinates.device)
+
+ if self.mode == 'kp':
+ transformed = torch.matmul(theta[:, :, :, :2], coordinates.permute(0, 2, 1)) + theta[:, :, :, 2:]
+
+ distances = coordinates.view(coordinates.shape[0], 1, 1, -1, 2) - control_points.view(self.bs, control_points.shape[1], -1, 1, 2)
+
+ distances = distances ** 2
+ result = distances.sum(-1)
+ result = result * torch.log(result + 1e-9)
+ result = torch.matmul(result.permute(0, 1, 3, 2), control_params)
+ transformed = transformed.permute(0, 1, 3, 2) + result
+
+ elif self.mode == 'random':
+ theta = theta.unsqueeze(1)
+ transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
+ transformed = transformed.squeeze(-1)
+ ances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
+ distances = ances ** 2
+
+ result = distances.sum(-1)
+ result = result * torch.log(result + 1e-9)
+ result = result * control_params
+ result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
+ transformed = transformed + result
+ else:
+ raise Exception("Error TPS mode")
+
+ return transformed
+
+
+def kp2gaussian(kp, spatial_size, kp_variance):
+ """
+ Transform a keypoint into gaussian like representation
+ """
+
+ coordinate_grid = make_coordinate_grid(spatial_size, kp.type()).to(kp.device)
+ number_of_leading_dimensions = len(kp.shape) - 1
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
+ coordinate_grid = coordinate_grid.view(*shape)
+ repeats = kp.shape[:number_of_leading_dimensions] + (1, 1, 1)
+ coordinate_grid = coordinate_grid.repeat(*repeats)
+
+ # Preprocess kp shape
+ shape = kp.shape[:number_of_leading_dimensions] + (1, 1, 2)
+ kp = kp.view(*shape)
+
+ mean_sub = (coordinate_grid - kp)
+
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
+
+ return out
+
+
+def make_coordinate_grid(spatial_size, type):
+ """
+ Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
+ """
+ h, w = spatial_size
+ x = torch.arange(w).type(type)
+ y = torch.arange(h).type(type)
+
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+
+ yy = y.view(-1, 1).repeat(1, w)
+ xx = x.view(1, -1).repeat(h, 1)
+
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+
+ return meshed
+
+
+class ResBlock2d(nn.Module):
+ """
+ Res block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, kernel_size, padding):
+ super(ResBlock2d, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.norm1 = nn.InstanceNorm2d(in_features, affine=True)
+ self.norm2 = nn.InstanceNorm2d(in_features, affine=True)
+
+ def forward(self, x):
+ out = self.norm1(x)
+ out = F.relu(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out += x
+ return out
+
+
+class UpBlock2d(nn.Module):
+ """
+ Upsampling block for use in decoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(UpBlock2d, self).__init__()
+
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
+
+ def forward(self, x):
+ out = F.interpolate(x, scale_factor=2)
+ out = self.conv(out)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+
+class DownBlock2d(nn.Module):
+ """
+ Downsampling block for use in encoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(DownBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = F.relu(out)
+ out = self.pool(out)
+ return out
+
+
+class SameBlock2d(nn.Module):
+ """
+ Simple block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
+ super(SameBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
+ kernel_size=kernel_size, padding=padding, groups=groups)
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+
+class Encoder(nn.Module):
+ """
+ Hourglass Encoder
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Encoder, self).__init__()
+
+ down_blocks = []
+ for i in range(num_blocks):
+ down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+ min(max_features, block_expansion * (2 ** (i + 1))),
+ kernel_size=3, padding=1))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ def forward(self, x):
+ outs = [x]
+ #print('encoder:' ,outs[-1].shape)
+ for down_block in self.down_blocks:
+ outs.append(down_block(outs[-1]))
+ #print('encoder:' ,outs[-1].shape)
+ return outs
+
+
+class Decoder(nn.Module):
+ """
+ Hourglass Decoder
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Decoder, self).__init__()
+
+ up_blocks = []
+ self.out_channels = []
+ for i in range(num_blocks)[::-1]:
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
+ self.out_channels.append(in_filters)
+ out_filters = min(max_features, block_expansion * (2 ** i))
+ up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
+
+ self.up_blocks = nn.ModuleList(up_blocks)
+ self.out_channels.append(block_expansion + in_features)
+ # self.out_filters = block_expansion + in_features
+
+ def forward(self, x, mode = 0):
+ out = x.pop()
+ outs = []
+ for up_block in self.up_blocks:
+ out = up_block(out)
+ skip = x.pop()
+ out = torch.cat([out, skip], dim=1)
+ outs.append(out)
+ if(mode == 0):
+ return out
+ else:
+ return outs
+
+
+class Hourglass(nn.Module):
+ """
+ Hourglass architecture.
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Hourglass, self).__init__()
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
+ self.out_channels = self.decoder.out_channels
+ # self.out_filters = self.decoder.out_filters
+
+ def forward(self, x, mode = 0):
+ return self.decoder(self.encoder(x), mode)
+
+
+class AntiAliasInterpolation2d(nn.Module):
+ """
+ Band-limited downsampling, for better preservation of the input signal.
+ """
+ def __init__(self, channels, scale):
+ super(AntiAliasInterpolation2d, self).__init__()
+ sigma = (1 / scale - 1) / 2
+ kernel_size = 2 * round(sigma * 4) + 1
+ self.ka = kernel_size // 2
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
+
+ kernel_size = [kernel_size, kernel_size]
+ sigma = [sigma, sigma]
+ # The gaussian kernel is the product of the
+ # gaussian function of each dimension.
+ kernel = 1
+ meshgrids = torch.meshgrid(
+ [
+ torch.arange(size, dtype=torch.float32)
+ for size in kernel_size
+ ]
+ )
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+ mean = (size - 1) / 2
+ kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
+
+ # Make sure sum of values in gaussian kernel equals 1.
+ kernel = kernel / torch.sum(kernel)
+ # Reshape to depthwise convolutional weight
+ kernel = kernel.view(1, 1, *kernel.size())
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+ self.register_buffer('weight', kernel)
+ self.groups = channels
+ self.scale = scale
+
+ def forward(self, input):
+ if self.scale == 1.0:
+ return input
+
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
+ out = F.interpolate(out, scale_factor=(self.scale, self.scale))
+
+ return out
+
+
+def to_homogeneous(coordinates):
+ ones_shape = list(coordinates.shape)
+ ones_shape[-1] = 1
+ ones = torch.ones(ones_shape).type(coordinates.type())
+
+ return torch.cat([coordinates, ones], dim=-1)
+
+def from_homogeneous(coordinates):
+ return coordinates[..., :2] / coordinates[..., 2:3]
\ No newline at end of file
diff --git a/TPSMM/pkgs/tpsmm.py b/TPSMM/pkgs/tpsmm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5616ef0dc31a43b9229be3f58de73c1c48a8e31
--- /dev/null
+++ b/TPSMM/pkgs/tpsmm.py
@@ -0,0 +1,80 @@
+import os
+import sys
+import torch
+import cv2
+import numpy as np
+from skimage import img_as_ubyte
+from skimage.transform import resize
+pwd = os.path.dirname(os.path.realpath(__file__))
+sys.path.insert(1, os.path.join(pwd, ".."))
+
+from demo import relative_kp, load_checkpoints
+
+
+class TPSMM:
+ def __init__(self):
+ self.device = torch.device("cuda")
+ self.inpainting, self.kp_detector, self.dense_motion_network, self.avd_network = load_checkpoints(
+ config_path=os.path.join(pwd, "../config/vox-256.yaml"),
+ checkpoint_path=os.path.join(pwd, "../pretrained/vox.pth.tar"),
+ device=self.device
+ )
+ self.kp_driving_initial = None
+
+ def process_source(self, src_img):
+ with torch.no_grad():
+ src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
+ src_img = resize(src_img, (256, 256))
+ source_tensor = torch.tensor(src_img[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(self.device)
+ kp_source = self.kp_detector(source_tensor)
+
+ return source_tensor, kp_source
+
+
+ def gen_image(self, driving_img, source_tensor, kp_source):
+ with torch.no_grad():
+ driving_img = cv2.cvtColor(driving_img, cv2.COLOR_BGR2RGB)
+ driving_img = resize(driving_img, (256, 256))
+ driving_frame = torch.tensor(driving_img[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(self.device)
+
+ kp_driving = self.kp_detector(driving_frame)
+ if self.kp_driving_initial is None:
+ self.kp_driving_initial = kp_driving
+ kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving,
+ kp_driving_initial=self.kp_driving_initial)
+ dense_motion = self.dense_motion_network(source_image=source_tensor,
+ kp_driving=kp_norm,
+ kp_source=kp_source, bg_param=None,
+ dropout_flag=False)
+ out = self.inpainting(source_tensor, dense_motion)
+ out = np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]
+ out = img_as_ubyte(out)
+ out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
+
+ return out
+
+
+if __name__ == "__main__":
+ tpsmm = TPSMM()
+ source_image = cv2.imread(os.path.join(pwd, "../assets/source1.png"))
+ cap = cv2.VideoCapture("/research/GAN/git/CVPR2022-DaGAN/assets/video1.mp4")
+
+ source_tensor, kp_source = tpsmm.process_source(source_image)
+
+ while True:
+ ret, frame = cap.read()
+ if frame is None:
+ break
+
+ output = tpsmm.gen_image(frame, source_tensor, kp_source)
+ cv2.imshow("output", output)
+ key = cv2.waitKey(1) & 0xFF
+ if key == ord("q"):
+ break
+ # cv2.imwrite("./tmp.jpg", output)
+
+ cv2.destroyAllWindows()
+
+
+
+
\ No newline at end of file
diff --git a/TPSMM/predict.py b/TPSMM/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..d615236a2d41a4f18bcba969f74e89cd3df05f50
--- /dev/null
+++ b/TPSMM/predict.py
@@ -0,0 +1,125 @@
+import os
+import sys
+sys.path.insert(0, "stylegan-encoder")
+import tempfile
+import warnings
+import imageio
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+from skimage.transform import resize
+from skimage import img_as_ubyte
+import torch
+import torchvision.transforms as transforms
+import dlib
+from cog import BasePredictor, Path, Input
+
+from demo import load_checkpoints
+from demo import make_animation
+from ffhq_dataset.face_alignment import image_align
+from ffhq_dataset.landmarks_detector import LandmarksDetector
+
+
+warnings.filterwarnings("ignore")
+
+
+PREDICTOR = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
+LANDMARKS_DETECTOR = LandmarksDetector("shape_predictor_68_face_landmarks.dat")
+
+
+class Predictor(BasePredictor):
+ def setup(self):
+
+ self.device = torch.device("cuda:0")
+ datasets = ["vox", "taichi", "ted", "mgif"]
+ (
+ self.inpainting,
+ self.kp_detector,
+ self.dense_motion_network,
+ self.avd_network,
+ ) = ({}, {}, {}, {})
+ for d in datasets:
+ (
+ self.inpainting[d],
+ self.kp_detector[d],
+ self.dense_motion_network[d],
+ self.avd_network[d],
+ ) = load_checkpoints(
+ config_path=f"config/{d}-384.yaml"
+ if d == "ted"
+ else f"config/{d}-256.yaml",
+ checkpoint_path=f"checkpoints/{d}.pth.tar",
+ device=self.device,
+ )
+
+ def predict(
+ self,
+ source_image: Path = Input(
+ description="Input source image.",
+ ),
+ driving_video: Path = Input(
+ description="Choose a micromotion.",
+ ),
+ dataset_name: str = Input(
+ choices=["vox", "taichi", "ted", "mgif"],
+ default="vox",
+ description="Choose a dataset.",
+ ),
+ ) -> Path:
+
+ predict_mode = "relative" # ['standard', 'relative', 'avd']
+ # find_best_frame = False
+
+ pixel = 384 if dataset_name == "ted" else 256
+
+ if dataset_name == "vox":
+ # first run face alignment
+ align_image(str(source_image), 'aligned.png')
+ source_image = imageio.imread('aligned.png')
+ else:
+ source_image = imageio.imread(str(source_image))
+ reader = imageio.get_reader(str(driving_video))
+ fps = reader.get_meta_data()["fps"]
+ source_image = resize(source_image, (pixel, pixel))[..., :3]
+
+ driving_video = []
+ try:
+ for im in reader:
+ driving_video.append(im)
+ except RuntimeError:
+ pass
+ reader.close()
+
+ driving_video = [
+ resize(frame, (pixel, pixel))[..., :3] for frame in driving_video
+ ]
+
+ inpainting, kp_detector, dense_motion_network, avd_network = (
+ self.inpainting[dataset_name],
+ self.kp_detector[dataset_name],
+ self.dense_motion_network[dataset_name],
+ self.avd_network[dataset_name],
+ )
+
+ predictions = make_animation(
+ source_image,
+ driving_video,
+ inpainting,
+ kp_detector,
+ dense_motion_network,
+ avd_network,
+ device="cuda:0",
+ mode=predict_mode,
+ )
+
+ # save resulting video
+ out_path = Path(tempfile.mkdtemp()) / "output.mp4"
+ imageio.mimsave(
+ str(out_path), [img_as_ubyte(frame) for frame in predictions], fps=fps
+ )
+ return out_path
+
+
+def align_image(raw_img_path, aligned_face_path):
+ for i, face_landmarks in enumerate(LANDMARKS_DETECTOR.get_landmarks(raw_img_path), start=1):
+ image_align(raw_img_path, aligned_face_path, face_landmarks)
diff --git a/TPSMM/pretrained/vox.pth.tar b/TPSMM/pretrained/vox.pth.tar
new file mode 100644
index 0000000000000000000000000000000000000000..5cb58244930878bdde245b9fce6eda59ac5fbe9b
--- /dev/null
+++ b/TPSMM/pretrained/vox.pth.tar
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52ad8c848e2a1d91b621de96fea83faf57ce3b8c1c06424e317f4df1d3998204
+size 350993469
diff --git a/TPSMM/reconstruction.py b/TPSMM/reconstruction.py
new file mode 100644
index 0000000000000000000000000000000000000000..40d4cf466339aa87935b3d488f759a066d753a4e
--- /dev/null
+++ b/TPSMM/reconstruction.py
@@ -0,0 +1,69 @@
+import os
+from tqdm import tqdm
+import torch
+from torch.utils.data import DataLoader
+from logger import Logger, Visualizer
+import numpy as np
+import imageio
+
+
+def reconstruction(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
+ png_dir = os.path.join(log_dir, 'reconstruction/png')
+ log_dir = os.path.join(log_dir, 'reconstruction')
+
+ if checkpoint is not None:
+ Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
+ bg_predictor=bg_predictor, dense_motion_network=dense_motion_network)
+ else:
+ raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
+
+ if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+ if not os.path.exists(png_dir):
+ os.makedirs(png_dir)
+
+ loss_list = []
+
+ inpainting_network.eval()
+ kp_detector.eval()
+ dense_motion_network.eval()
+ if bg_predictor:
+ bg_predictor.eval()
+
+ for it, x in tqdm(enumerate(dataloader)):
+ with torch.no_grad():
+ predictions = []
+ visualizations = []
+ if torch.cuda.is_available():
+ x['video'] = x['video'].cuda()
+ kp_source = kp_detector(x['video'][:, :, 0])
+ for frame_idx in range(x['video'].shape[2]):
+ source = x['video'][:, :, 0]
+ driving = x['video'][:, :, frame_idx]
+ kp_driving = kp_detector(driving)
+ bg_params = None
+ if bg_predictor:
+ bg_params = bg_predictor(source, driving)
+
+ dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
+ kp_source=kp_source, bg_param = bg_params,
+ dropout_flag = False)
+ out = inpainting_network(source, dense_motion)
+ out['kp_source'] = kp_source
+ out['kp_driving'] = kp_driving
+
+ predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
+
+ visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
+ driving=driving, out=out)
+ visualizations.append(visualization)
+ loss = torch.abs(out['prediction'] - driving).mean().cpu().numpy()
+
+ loss_list.append(loss)
+ # print(np.mean(loss_list))
+ predictions = np.concatenate(predictions, axis=1)
+ imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
+
+ print("Reconstruction loss: %s" % np.mean(loss_list))
diff --git a/TPSMM/requirements.txt b/TPSMM/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..20beb09852a9ae07b1d6cae6567b2b61dcf33add
--- /dev/null
+++ b/TPSMM/requirements.txt
@@ -0,0 +1,25 @@
+cffi==1.14.6
+cycler==0.10.0
+decorator==5.1.0
+face-alignment==1.3.5
+imageio==2.9.0
+imageio-ffmpeg==0.4.5
+kiwisolver==1.3.2
+matplotlib==3.4.3
+networkx==2.6.3
+numpy==1.20.3
+pandas==1.3.3
+Pillow==8.3.2
+pycparser==2.20
+pyparsing==2.4.7
+python-dateutil==2.8.2
+pytz==2021.1
+PyWavelets==1.1.1
+PyYAML==5.4.1
+scikit-image==0.18.3
+scikit-learn==1.0
+scipy==1.7.1
+six==1.16.0
+torch==1.10.0+cu113
+torchvision==0.11.0+cu113
+tqdm==4.62.3
\ No newline at end of file
diff --git a/TPSMM/run.py b/TPSMM/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..6120213fe79c670212b2fc79e0ddb105fb178c45
--- /dev/null
+++ b/TPSMM/run.py
@@ -0,0 +1,89 @@
+import matplotlib
+matplotlib.use('Agg')
+
+import os, sys
+import yaml
+from argparse import ArgumentParser
+from time import gmtime, strftime
+from shutil import copy
+from frames_dataset import FramesDataset
+
+from modules.inpainting_network import InpaintingNetwork
+from modules.keypoint_detector import KPDetector
+from modules.bg_motion_predictor import BGMotionPredictor
+from modules.dense_motion import DenseMotionNetwork
+from modules.avd_network import AVDNetwork
+import torch
+from train import train
+from train_avd import train_avd
+from reconstruction import reconstruction
+import os
+
+
+if __name__ == "__main__":
+
+ if sys.version_info[0] < 3:
+ raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")
+
+ parser = ArgumentParser()
+ parser.add_argument("--config", default="config/vox-256.yaml", help="path to config")
+ parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"])
+ parser.add_argument("--log_dir", default='log', help="path to log into")
+ parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
+ parser.add_argument("--device_ids", default="0,1", type=lambda x: list(map(int, x.split(','))),
+ help="Names of the devices comma separated.")
+
+ opt = parser.parse_args()
+ with open(opt.config) as f:
+ config = yaml.load(f)
+
+ if opt.checkpoint is not None:
+ log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
+ else:
+ log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
+ log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())
+
+ inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
+ **config['model_params']['common_params'])
+
+ if torch.cuda.is_available():
+ cuda_device = torch.device('cuda:'+str(opt.device_ids[0]))
+ inpainting.to(cuda_device)
+
+ kp_detector = KPDetector(**config['model_params']['common_params'])
+ dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
+ **config['model_params']['dense_motion_params'])
+
+ if torch.cuda.is_available():
+ kp_detector.to(opt.device_ids[0])
+ dense_motion_network.to(opt.device_ids[0])
+
+ bg_predictor = None
+ if (config['model_params']['common_params']['bg']):
+ bg_predictor = BGMotionPredictor()
+ if torch.cuda.is_available():
+ bg_predictor.to(opt.device_ids[0])
+
+ avd_network = None
+ if opt.mode == "train_avd":
+ avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
+ **config['model_params']['avd_network_params'])
+ if torch.cuda.is_available():
+ avd_network.to(opt.device_ids[0])
+
+ dataset = FramesDataset(is_train=(opt.mode.startswith('train')), **config['dataset_params'])
+
+ if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+ if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
+ copy(opt.config, log_dir)
+
+ if opt.mode == 'train':
+ print("Training...")
+ train(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset)
+ elif opt.mode == 'train_avd':
+ print("Training Animation via Disentaglement...")
+ train_avd(config, inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network, opt.checkpoint, log_dir, dataset)
+ elif opt.mode == 'reconstruction':
+ print("Reconstruction...")
+ reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset)
diff --git a/TPSMM/tmp.jpg b/TPSMM/tmp.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cd953a5825c6a18bd2b1d6d6ace77536ee0cc894
Binary files /dev/null and b/TPSMM/tmp.jpg differ
diff --git a/TPSMM/tmp.py b/TPSMM/tmp.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c91bbf096307c9b0ba4168ed65b79af70adc0bc
--- /dev/null
+++ b/TPSMM/tmp.py
@@ -0,0 +1,14 @@
+import cv2
+
+
+cap = cv2.VideoCapture("/research/GAN/git/CVPR2022-DaGAN/assets/video1.mp4")
+while True:
+ ret, frame = cap.read()
+ if frame is None:
+ break
+ cv2.imshow("output", frame)
+ key = cv2.waitKey(1) & 0xff
+ if key == ord("q"):
+ break
+
+cv2.destroyAllWindows()
\ No newline at end of file
diff --git a/TPSMM/train.py b/TPSMM/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..06ce3be20bc4fcbc5395c596b042c1bf2bdad8b8
--- /dev/null
+++ b/TPSMM/train.py
@@ -0,0 +1,94 @@
+from tqdm import trange
+import torch
+from torch.utils.data import DataLoader
+from logger import Logger
+from modules.model import GeneratorFullModel
+from torch.optim.lr_scheduler import MultiStepLR
+from torch.nn.utils import clip_grad_norm_
+from frames_dataset import DatasetRepeater
+import math
+
+def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
+ train_params = config['train_params']
+ optimizer = torch.optim.Adam(
+ [{'params': list(inpainting_network.parameters()) +
+ list(dense_motion_network.parameters()) +
+ list(kp_detector.parameters()), 'initial_lr': train_params['lr_generator']}],lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4)
+
+ optimizer_bg_predictor = None
+ if bg_predictor:
+ optimizer_bg_predictor = torch.optim.Adam(
+ [{'params':bg_predictor.parameters(),'initial_lr': train_params['lr_generator']}],
+ lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4)
+
+ if checkpoint is not None:
+ start_epoch = Logger.load_cpk(
+ checkpoint, inpainting_network = inpainting_network, dense_motion_network = dense_motion_network,
+ kp_detector = kp_detector, bg_predictor = bg_predictor,
+ optimizer = optimizer, optimizer_bg_predictor = optimizer_bg_predictor)
+ print('load success:', start_epoch)
+ start_epoch += 1
+ else:
+ start_epoch = 0
+
+ scheduler_optimizer = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1,
+ last_epoch=start_epoch - 1)
+ if bg_predictor:
+ scheduler_bg_predictor = MultiStepLR(optimizer_bg_predictor, train_params['epoch_milestones'],
+ gamma=0.1, last_epoch=start_epoch - 1)
+
+ if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
+ dataset = DatasetRepeater(dataset, train_params['num_repeats'])
+ dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True,
+ num_workers=train_params['dataloader_workers'], drop_last=True)
+
+ generator_full = GeneratorFullModel(kp_detector, bg_predictor, dense_motion_network, inpainting_network, train_params)
+
+ if torch.cuda.is_available():
+ generator_full = torch.nn.DataParallel(generator_full).cuda()
+
+ bg_start = train_params['bg_start']
+
+ with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'],
+ checkpoint_freq=train_params['checkpoint_freq']) as logger:
+ for epoch in trange(start_epoch, train_params['num_epochs']):
+ for x in dataloader:
+ if(torch.cuda.is_available()):
+ x['driving'] = x['driving'].cuda()
+ x['source'] = x['source'].cuda()
+
+ losses_generator, generated = generator_full(x, epoch)
+ loss_values = [val.mean() for val in losses_generator.values()]
+ loss = sum(loss_values)
+ loss.backward()
+
+ clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type = math.inf)
+ clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type = math.inf)
+ if bg_predictor and epoch>=bg_start:
+ clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type = math.inf)
+
+ optimizer.step()
+ optimizer.zero_grad()
+ if bg_predictor and epoch>=bg_start:
+ optimizer_bg_predictor.step()
+ optimizer_bg_predictor.zero_grad()
+
+ losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
+ logger.log_iter(losses=losses)
+
+ scheduler_optimizer.step()
+ if bg_predictor:
+ scheduler_bg_predictor.step()
+
+ model_save = {
+ 'inpainting_network': inpainting_network,
+ 'dense_motion_network': dense_motion_network,
+ 'kp_detector': kp_detector,
+ 'optimizer': optimizer,
+ }
+ if bg_predictor and epoch>=bg_start:
+ model_save['bg_predictor'] = bg_predictor
+ model_save['optimizer_bg_predictor'] = optimizer_bg_predictor
+
+ logger.log_epoch(epoch, model_save, inp=x, out=generated)
+
diff --git a/TPSMM/train_avd.py b/TPSMM/train_avd.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc6794c322e01d980acc2f3b3aab2b192576900d
--- /dev/null
+++ b/TPSMM/train_avd.py
@@ -0,0 +1,91 @@
+from tqdm import trange
+import torch
+from torch.utils.data import DataLoader
+from logger import Logger
+from torch.optim.lr_scheduler import MultiStepLR
+from frames_dataset import DatasetRepeater
+
+
+def random_scale(kp_params, scale):
+ theta = torch.rand(kp_params['fg_kp'].shape[0], 2) * (2 * scale) + (1 - scale)
+ theta = torch.diag_embed(theta).unsqueeze(1).type(kp_params['fg_kp'].type())
+ new_kp_params = {'fg_kp': torch.matmul(theta, kp_params['fg_kp'].unsqueeze(-1)).squeeze(-1)}
+ return new_kp_params
+
+
+def train_avd(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network,
+ avd_network, checkpoint, log_dir, dataset):
+ train_params = config['train_avd_params']
+
+ optimizer = torch.optim.Adam(avd_network.parameters(), lr=train_params['lr'], betas=(0.5, 0.999))
+
+ if checkpoint is not None:
+ Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
+ bg_predictor=bg_predictor, avd_network=avd_network,
+ dense_motion_network= dense_motion_network,optimizer_avd=optimizer)
+ start_epoch = 0
+ else:
+ raise AttributeError("Checkpoint should be specified for mode='train_avd'.")
+
+ scheduler = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1)
+
+ if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
+ dataset = DatasetRepeater(dataset, train_params['num_repeats'])
+
+ dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True,
+ num_workers=train_params['dataloader_workers'], drop_last=True)
+
+ with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'],
+ checkpoint_freq=train_params['checkpoint_freq']) as logger:
+ for epoch in trange(start_epoch, train_params['num_epochs']):
+ avd_network.train()
+ for x in dataloader:
+ with torch.no_grad():
+ kp_source = kp_detector(x['source'].cuda())
+ kp_driving_gt = kp_detector(x['driving'].cuda())
+ kp_driving_random = random_scale(kp_driving_gt, scale=train_params['random_scale'])
+ rec = avd_network(kp_source, kp_driving_random)
+
+ reconstruction_kp = train_params['lambda_shift'] * \
+ torch.abs(kp_driving_gt['fg_kp'] - rec['fg_kp']).mean()
+
+ loss_dict = {'rec_kp': reconstruction_kp}
+ loss = reconstruction_kp
+
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+ losses = {key: value.mean().detach().data.cpu().numpy() for key, value in loss_dict.items()}
+ logger.log_iter(losses=losses)
+
+ # Visualization
+ avd_network.eval()
+ with torch.no_grad():
+ source = x['source'][:6].cuda()
+ driving = torch.cat([x['driving'][[0, 1]].cuda(), source[[2, 3, 2, 1]]], dim=0)
+ kp_source = kp_detector(source)
+ kp_driving = kp_detector(driving)
+
+ out = avd_network(kp_source, kp_driving)
+ kp_driving = out
+ dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
+ kp_source=kp_source)
+ generated = inpainting_network(source, dense_motion)
+
+ generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
+
+ scheduler.step(epoch)
+ model_save = {
+ 'inpainting_network': inpainting_network,
+ 'dense_motion_network': dense_motion_network,
+ 'kp_detector': kp_detector,
+ 'avd_network': avd_network,
+ 'optimizer_avd': optimizer
+ }
+ if bg_predictor :
+ model_save['bg_predictor'] = bg_predictor
+
+ logger.log_epoch(epoch, model_save,
+ inp={'source': source, 'driving': driving},
+ out=generated)
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..74ab3f624f38c957738523167a9d414322e98f1b
--- /dev/null
+++ b/app.py
@@ -0,0 +1,122 @@
+import av
+import sys
+import numpy as np
+import cv2
+import streamlit as st
+from PIL import Image
+from streamlit_webrtc import WebRtcMode, webrtc_streamer
+
+sys.path.insert(1, "./retinaface")
+sys.path.insert(1, "./TPSMM/pkgs")
+from tpsmm import TPSMM
+from detect import Detect
+from turn import get_ice_servers
+
+
+def parse_roi_box_from_bbox(bbox, shape):
+ img_h, img_w = shape[:2]
+ left, top, right, bottom = bbox[:4]
+ old_size = (right - left + bottom - top) / 2
+ center_x = right - (right - left) / 2.0
+ center_y = bottom - (bottom - top) / 2.0 + old_size * 0.14
+
+ size = int(min((old_size * 2.0) / 2, center_x, img_w-center_x, center_y, img_h-center_y) * 2.0)
+
+ roi_box = [0] * 4
+ roi_box[0] = center_x - size / 2
+ roi_box[1] = center_y - size / 2
+ roi_box[2] = roi_box[0] + size
+ roi_box[3] = roi_box[1] + size
+
+ return roi_box
+
+cache_key = "retinaface"
+if cache_key in st.session_state:
+ detector = st.session_state[cache_key]
+else:
+ detector = Detect("./retinaface/weights/mobilenet0.25_epoch_842.pth", net_inshape=(486, 864))
+ st.session_state[cache_key] = detector
+
+cache_key = "tpsmm"
+if cache_key in st.session_state:
+ generator = st.session_state[cache_key]
+else:
+ generator = TPSMM()
+ st.session_state[cache_key] = generator
+
+
+@st.cache_resource # type: ignore
+def get_images():
+ images = [
+ cv2.imread("assets/0.jpg"),
+ cv2.imread("assets/1.jpg"),
+ cv2.imread("assets/2.jpg"),
+ cv2.imread("assets/3.jpg"),
+ ]
+ item_list = [str(i) for i in range(len(images))]
+ images = [generator.process_source(src_img) for src_img in images]
+
+ return dict(zip(item_list, images))
+images = get_images()
+user_option = st.selectbox("Choose an item", list(images.keys()))
+
+uploaded_file = st.file_uploader("Or upload your file here...", type=['png', 'jpeg', 'jpg'])
+@st.cache_resource
+def process_file(uploaded_file):
+ img = Image.open(uploaded_file)
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+ dets = detector(img)
+ for i, b in enumerate(dets):
+ bbox = parse_roi_box_from_bbox(b[:4], img.shape)
+ bbox = [int(i) for i in bbox]
+
+ face_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy()
+ # cv2.imwrite("./tmp.jpg", face_img)
+ return generator.process_source(face_img)
+
+ return None
+if uploaded_file is not None:
+ uploaded_file = process_file(uploaded_file)
+
+def callback(frame: av.VideoFrame) -> av.VideoFrame:
+ img = frame.to_ndarray(format="bgr24")
+
+ try:
+ dets = detector(img)
+ output = None
+ for i, b in enumerate(dets):
+ text = "{:.4f}".format(b[4])
+ b = b.astype(np.int32)
+ cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
+ bbox = parse_roi_box_from_bbox(b[:4], img.shape)
+ bbox = [int(i) for i in bbox]
+ cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255, 0, 0), 2)
+
+ face_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy()
+ if uploaded_file is None:
+ source_tensor, kp_source = images[user_option]
+ else:
+ source_tensor, kp_source = uploaded_file
+ output = generator.gen_image(face_img, source_tensor, kp_source)
+
+ landm = b[5:15]
+ landm = landm.reshape((5, 2))
+ cv2.circle(img, tuple(landm[0]), 1, (0, 0, 255), 2)
+ cv2.circle(img, tuple(landm[1]), 1, (0, 255, 255), 2)
+ cv2.circle(img, tuple(landm[2]), 1, (255, 0, 255), 2)
+ cv2.circle(img, tuple(landm[3]), 1, (0, 255, 0), 2)
+ cv2.circle(img, tuple(landm[4]), 1, (255, 0, 0), 2)
+
+ if output is not None:
+ img[:256, :256] = output
+ except Exception as e:
+ print(e)
+
+ return av.VideoFrame.from_ndarray(img, format="bgr24")
+
+webrtc_streamer(
+ key="sample",
+ rtc_configuration={"iceServers": get_ice_servers()},
+ video_frame_callback=callback,
+ media_stream_constraints={"video": True, "audio": False},
+)
\ No newline at end of file
diff --git a/assets/0.jpg b/assets/0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..852cb6d7ab304bcac7f18e1f0fb7d1a229e7f502
Binary files /dev/null and b/assets/0.jpg differ
diff --git a/assets/1.jpg b/assets/1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..20a6e64c8703e54b943fb64a48f63ec90cbb9df1
Binary files /dev/null and b/assets/1.jpg differ
diff --git a/assets/2.jpg b/assets/2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9132972dceff4c0d9fd42421b07e3c4d36a94c97
Binary files /dev/null and b/assets/2.jpg differ
diff --git a/assets/3.jpg b/assets/3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8636aeea387fe8c32f375d21ab12ad02753c707a
Binary files /dev/null and b/assets/3.jpg differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c6537c213c1503c83a083c54b142c4eb0ea3ce5d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+streamlit-webrtc
+twilio
+altair<5
+numpy==1.23.1
+opencv-python==4.8.0.74
+imutils
+scikit-image==0.21.0
+matplotlib==3.7.1
+pyaml==23.5.9
+tqdm
+torch
+torchvision
\ No newline at end of file
diff --git a/retinaface/change_batch_onnx.py b/retinaface/change_batch_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..68cfc02981ea068eaf034ac32c4cb20e9214f5d4
--- /dev/null
+++ b/retinaface/change_batch_onnx.py
@@ -0,0 +1,43 @@
+import onnx
+
+
+model = onnx.load('weights/faceDetector_243_432_b1_sim.onnx')
+
+# # for fixed batchsize
+# model.graph.input[0].type.tensor_type.shape.dim[0].dim_value = 32
+# model.graph.output[0].type.tensor_type.shape.dim[0].dim_value = 32
+# model.graph.output[1].type.tensor_type.shape.dim[0].dim_value = 32
+# model.graph.output[2].type.tensor_type.shape.dim[0].dim_value = 32
+# onnx.save(model, 'faceDetector_640_b32.onnx')
+
+
+model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'batch' # for dynamic batchsize
+model.graph.output[0].type.tensor_type.shape.dim[0].dim_param = 'batch'
+model.graph.output[1].type.tensor_type.shape.dim[0].dim_param = 'batch'
+model.graph.output[2].type.tensor_type.shape.dim[0].dim_param = 'batch'
+model.graph.output[3].type.tensor_type.shape.dim[0].dim_param = 'batch'
+model.graph.output[4].type.tensor_type.shape.dim[0].dim_param = 'batch'
+model.graph.output[5].type.tensor_type.shape.dim[0].dim_param = 'batch'
+model.graph.output[6].type.tensor_type.shape.dim[0].dim_param = 'batch'
+model.graph.output[7].type.tensor_type.shape.dim[0].dim_param = 'batch'
+model.graph.output[8].type.tensor_type.shape.dim[0].dim_param = 'batch'
+onnx.save(model, 'weights/faceDetector_243_432_batch_sim.onnx')
+
+
+####################################################
+# SHow model onnx
+
+# import onnxruntime as rt
+# ort_session = rt.InferenceSession("faceDetector_180_320_batch_sim.onnx")
+# print("====INPUT====")
+# for i in ort_session.get_inputs():
+# print("Name: {}, Shape: {}, Dtype: {}".format(i.name, i.shape, i.type))
+# print("====OUTPUT====")
+# for i in ort_session.get_outputs():
+# print("Name: {}, Shape: {}, Dtype: {}".format(i.name, i.shape, i.type))
+
+# import numpy as np
+# input_name = ort_session.get_inputs()[0].name
+# img = np.random.randn(4, 3, 180, 320).astype(np.float32)
+# data = ort_session.run(None, {input_name: img})
+# print("Done")
\ No newline at end of file
diff --git a/retinaface/convert_to_onnx.py b/retinaface/convert_to_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..da98249ca3412ad2124efea2bc909c6586152038
--- /dev/null
+++ b/retinaface/convert_to_onnx.py
@@ -0,0 +1,135 @@
+# python convert_to_onnx.py --network mobile0.25 --trained_model weights/mobilenet0.25_Final.pth
+from __future__ import print_function
+import os
+import argparse
+import torch
+import torch.backends.cudnn as cudnn
+import numpy as np
+from data import cfg_mnet, cfg_slim, cfg_rfb
+from layers.functions.prior_box import PriorBox
+from utils.nms.py_cpu_nms import py_cpu_nms
+import cv2
+from models.retinaface import RetinaFace
+from models.net_slim import Slim
+from models.net_rfb import RFB
+from utils.box_utils import decode, decode_landm
+from utils.timer import Timer
+
+
+parser = argparse.ArgumentParser(description='Test')
+parser.add_argument('-m', '--trained_model', default='./weights/RBF_Final.pth',
+ type=str, help='Trained state_dict file path to open')
+parser.add_argument('--network', default='RFB', help='Backbone network mobile0.25 or slim or RFB')
+parser.add_argument('--long_side', default=320, help='when origin_size is false, long_side is scaled size(320 or 640 for long side)')
+parser.add_argument('--cpu', action="store_true", help='Use cpu inference')
+
+args = parser.parse_args()
+
+
+def check_keys(model, pretrained_state_dict):
+ ckpt_keys = set(pretrained_state_dict.keys())
+ model_keys = set(model.state_dict().keys())
+ used_pretrained_keys = model_keys & ckpt_keys
+ unused_pretrained_keys = ckpt_keys - model_keys
+ missing_keys = model_keys - ckpt_keys
+ print('Missing keys:{}'.format(len(missing_keys)))
+ print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
+ print('Used keys:{}'.format(len(used_pretrained_keys)))
+ assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
+ return True
+
+
+def remove_prefix(state_dict, prefix):
+ ''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
+ print('remove prefix \'{}\''.format(prefix))
+ f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
+ return {f(key): value for key, value in state_dict.items()}
+
+
+def load_model(model, pretrained_path, load_to_cpu):
+ print('Loading pretrained model from {}'.format(pretrained_path))
+ if load_to_cpu:
+ pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
+ else:
+ device = torch.cuda.current_device()
+ pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
+ if "state_dict" in pretrained_dict.keys():
+ pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
+ else:
+ pretrained_dict = remove_prefix(pretrained_dict, 'module.')
+ check_keys(model, pretrained_dict)
+ model.load_state_dict(pretrained_dict, strict=False)
+ return model
+
+
+if __name__ == '__main__':
+ torch.set_grad_enabled(False)
+
+ cfg = None
+ net = None
+ # long_side = int(args.long_side)
+ net_inshape = (243, 432)
+ device = torch.device("cpu" if args.cpu else "cuda")
+ print(device)
+ if args.network == "mobile0.25":
+ cfg = cfg_mnet
+ # net_inshape = (long_side, long_side) # h, w
+ priorbox = PriorBox(cfg, image_size=net_inshape)
+ priors = priorbox.forward()
+ prior_data = priors.to(device)
+ net = RetinaFace(cfg=cfg, phase='test')
+ elif args.network == "slim":
+ cfg = cfg_slim
+ net = Slim(cfg = cfg, phase = 'test')
+ elif args.network == "RFB":
+ cfg = cfg_rfb
+ net = RFB(cfg = cfg, phase = 'test')
+ else:
+ print("Don't support network!")
+ exit(0)
+
+ # load weight
+ net = load_model(net, args.trained_model, args.cpu)
+ net.eval()
+ print('Finished loading model!')
+ print(net)
+ net = net.to(device)
+
+ ##################export###############
+ output_onnx = f'weights/faceDetector_{net_inshape[0]}_{net_inshape[1]}_b1.onnx'
+ print("==> Exporting model to ONNX format at '{}'".format(output_onnx))
+ input_names = ['input_1']
+ output_names = ['box_1', 'box_2', 'box_3']
+
+ # import torch.onnx.symbolic_opset9 as onnx_symbolic
+ # def upsample_nearest2d(g, input, output_size, *args):
+ # # Currently, TRT 5.1/6.0/7.0 ONNX Parser does not support all ONNX ops
+ # # needed to support dynamic upsampling ONNX forumlation
+ # # Here we hardcode scale=2 as a temporary workaround
+ # scales = g.op("Constant", value_t=torch.tensor([1., 1., 2., 2.]))
+ # return g.op("Resize", input, scales, mode_s="nearest")
+
+
+ # onnx_symbolic.upsample_nearest2d = upsample_nearest2d
+
+ # import io
+ # onnx_bytes = io.BytesIO()
+ # zero_input = torch.zeros([1, 3, net_inshape[0], net_inshape[1]]).cuda()
+ # dynamic_axes = {input_names[0]: {0:'batch'}}
+ # for _, name in enumerate(output_names):
+ # dynamic_axes[name] = dynamic_axes[input_names[0]]
+ # extra_args = {'opset_version': 10, 'verbose': False,
+ # 'input_names': input_names, 'output_names': output_names,
+ # 'dynamic_axes': dynamic_axes}
+ # torch.onnx.export(net, zero_input, onnx_bytes, **extra_args)
+ # with open(output_onnx, 'wb') as out:
+ # out.write(onnx_bytes.getvalue())
+
+ inputs = torch.randn(1, 3, net_inshape[0], net_inshape[1]).to(device)
+ torch_out = torch.onnx._export(net, inputs, output_onnx, export_params=True, verbose=False, opset_version=9,
+ input_names=input_names, output_names=output_names)
+ ################end###############
+
+
+
+
diff --git a/retinaface/data/__init__.py b/retinaface/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea50ebaf88d64e75f4960bc99b14f138a343e575
--- /dev/null
+++ b/retinaface/data/__init__.py
@@ -0,0 +1,3 @@
+from .wider_face import WiderFaceDetection, detection_collate
+from .data_augment import *
+from .config import *
diff --git a/retinaface/data/config.py b/retinaface/data/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d22bdfc205e07b7e88590526c72421d3f54d0ec
--- /dev/null
+++ b/retinaface/data/config.py
@@ -0,0 +1,55 @@
+# config.py
+cfg_mnet = {
+ 'name': 'mobilenet0.25',
+ 'min_sizes': [[10, 20], [32, 64], [128, 256]],
+ 'steps': [8, 16, 32],
+ 'variance': [0.1, 0.2],
+ 'clip': False,
+ 'loc_weight': 2.0,
+ 'gpu_train': True,
+ 'batch_size': 32,
+ 'ngpu': 1,
+ 'epoch': 250,
+ 'decay1': 190,
+ 'decay2': 220,
+ 'image_size': 300,
+ "net_inshape": (320, 320), # h, w
+ 'pretrain': False,
+ 'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
+ 'in_channel': 32,
+ 'out_channel': 64
+}
+
+cfg_slim = {
+ 'name': 'slim',
+ 'min_sizes': [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]],
+ 'steps': [8, 16, 32, 64],
+ 'variance': [0.1, 0.2],
+ 'clip': False,
+ 'loc_weight': 2.0,
+ 'gpu_train': True,
+ 'batch_size': 32,
+ 'ngpu': 1,
+ 'epoch': 250,
+ 'decay1': 190,
+ 'decay2': 220,
+ 'image_size': 300
+}
+
+cfg_rfb = {
+ 'name': 'RFB',
+ 'min_sizes': [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]],
+ 'steps': [8, 16, 32, 64],
+ 'variance': [0.1, 0.2],
+ 'clip': False,
+ 'loc_weight': 2.0,
+ 'gpu_train': True,
+ 'batch_size': 32,
+ 'ngpu': 1,
+ 'epoch': 250,
+ 'decay1': 190,
+ 'decay2': 220,
+ 'image_size': 300
+}
+
+
diff --git a/retinaface/data/data_augment.py b/retinaface/data/data_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..33ec89d8382b758d0ca26520d422c6bd730c4cb2
--- /dev/null
+++ b/retinaface/data/data_augment.py
@@ -0,0 +1,235 @@
+import cv2
+import numpy as np
+import random
+from utils.box_utils import matrix_iof
+
+
+def _crop(image, boxes, labels, landm, img_dim):
+ height, width, _ = image.shape
+ pad_image_flag = True
+
+ for _ in range(250):
+ if random.uniform(0, 1) <= 0.2:
+ scale = 1.0
+ else:
+ scale = random.uniform(0.3, 1.0)
+ # PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0]
+ # scale = random.choice(PRE_SCALES)
+ short_side = min(width, height)
+ w = int(scale * short_side)
+ h = w
+
+ if width == w:
+ l = 0
+ else:
+ l = random.randrange(width - w)
+ if height == h:
+ t = 0
+ else:
+ t = random.randrange(height - h)
+ roi = np.array((l, t, l + w, t + h))
+
+ value = matrix_iof(boxes, roi[np.newaxis])
+ flag = (value >= 1)
+ if not flag.any():
+ continue
+
+ centers = (boxes[:, :2] + boxes[:, 2:]) / 2
+ mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
+ boxes_t = boxes[mask_a].copy()
+ labels_t = labels[mask_a].copy()
+ landms_t = landm[mask_a].copy()
+ landms_t = landms_t.reshape([-1, 5, 2])
+
+ if boxes_t.shape[0] == 0:
+ continue
+
+ image_t = image[roi[1]:roi[3], roi[0]:roi[2]]
+
+ boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
+ boxes_t[:, :2] -= roi[:2]
+ boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
+ boxes_t[:, 2:] -= roi[:2]
+
+ # landm
+ landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2]
+ landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0]))
+ landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2])
+ landms_t = landms_t.reshape([-1, 10])
+
+
+ # make sure that the cropped image contains at least one face > 16 pixel at training image scale
+ b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
+ b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
+ mask_b = np.minimum(b_w_t, b_h_t) > 5
+ boxes_t = boxes_t[mask_b]
+ labels_t = labels_t[mask_b]
+ landms_t = landms_t[mask_b]
+
+ if boxes_t.shape[0] == 0:
+ continue
+
+ pad_image_flag = False
+
+ return image_t, boxes_t, labels_t, landms_t, pad_image_flag
+ return image, boxes, labels, landm, pad_image_flag
+
+
+def _distort(image):
+
+ def _convert(image, alpha=1, beta=0):
+ tmp = image.astype(float) * alpha + beta
+ tmp[tmp < 0] = 0
+ tmp[tmp > 255] = 255
+ image[:] = tmp
+
+ image = image.copy()
+
+ if random.randrange(2):
+
+ #brightness distortion
+ if random.randrange(2):
+ _convert(image, beta=random.uniform(-32, 32))
+
+ #contrast distortion
+ if random.randrange(2):
+ _convert(image, alpha=random.uniform(0.5, 1.5))
+
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+
+ #saturation distortion
+ if random.randrange(2):
+ _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
+
+ #hue distortion
+ if random.randrange(2):
+ tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
+ tmp %= 180
+ image[:, :, 0] = tmp
+
+ image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
+
+ else:
+
+ #brightness distortion
+ if random.randrange(2):
+ _convert(image, beta=random.uniform(-32, 32))
+
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+
+ #saturation distortion
+ if random.randrange(2):
+ _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
+
+ #hue distortion
+ if random.randrange(2):
+ tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
+ tmp %= 180
+ image[:, :, 0] = tmp
+
+ image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
+
+ #contrast distortion
+ if random.randrange(2):
+ _convert(image, alpha=random.uniform(0.5, 1.5))
+
+ return image
+
+
+def _expand(image, boxes, fill, p):
+ if random.randrange(2):
+ return image, boxes
+
+ height, width, depth = image.shape
+
+ scale = random.uniform(1, p)
+ w = int(scale * width)
+ h = int(scale * height)
+
+ left = random.randint(0, w - width)
+ top = random.randint(0, h - height)
+
+ boxes_t = boxes.copy()
+ boxes_t[:, :2] += (left, top)
+ boxes_t[:, 2:] += (left, top)
+ expand_image = np.empty(
+ (h, w, depth),
+ dtype=image.dtype)
+ expand_image[:, :] = fill
+ expand_image[top:top + height, left:left + width] = image
+ image = expand_image
+
+ return image, boxes_t
+
+
+def _mirror(image, boxes, landms):
+ _, width, _ = image.shape
+ if random.randrange(2):
+ image = image[:, ::-1]
+ boxes = boxes.copy()
+ boxes[:, 0::2] = width - boxes[:, 2::-2]
+
+ # landm
+ landms = landms.copy()
+ landms = landms.reshape([-1, 5, 2])
+ landms[:, :, 0] = width - landms[:, :, 0]
+ tmp = landms[:, 1, :].copy()
+ landms[:, 1, :] = landms[:, 0, :]
+ landms[:, 0, :] = tmp
+ tmp1 = landms[:, 4, :].copy()
+ landms[:, 4, :] = landms[:, 3, :]
+ landms[:, 3, :] = tmp1
+ landms = landms.reshape([-1, 10])
+
+ return image, boxes, landms
+
+
+def _pad_to_square(image, rgb_mean, pad_image_flag):
+ if not pad_image_flag:
+ return image
+ height, width, _ = image.shape
+ long_side = max(width, height)
+ image_t = np.empty((long_side, long_side, 3), dtype=image.dtype)
+ image_t[:, :] = rgb_mean
+ image_t[0:0 + height, 0:0 + width] = image
+ return image_t
+
+
+def _resize_subtract_mean(image, insize, rgb_mean):
+ interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
+ interp_method = interp_methods[random.randrange(5)]
+ image = cv2.resize(image, (insize, insize), interpolation=interp_method)
+ image = image.astype(np.float32)
+ image -= rgb_mean
+ return image.transpose(2, 0, 1)
+
+
+class preproc(object):
+
+ def __init__(self, img_dim, rgb_means):
+ self.img_dim = img_dim
+ self.rgb_means = rgb_means
+
+ def __call__(self, image, targets):
+ assert targets.shape[0] > 0, "this image does not have gt"
+
+ boxes = targets[:, :4].copy()
+ labels = targets[:, -1].copy()
+ landm = targets[:, 4:-1].copy()
+
+ image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim)
+ image_t = _distort(image_t)
+ image_t = _pad_to_square(image_t,self.rgb_means, pad_image_flag)
+ image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t)
+ height, width, _ = image_t.shape
+ image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means)
+ boxes_t[:, 0::2] /= width
+ boxes_t[:, 1::2] /= height
+
+ landm_t[:, 0::2] /= width
+ landm_t[:, 1::2] /= height
+
+ labels_t = np.expand_dims(labels_t, 1)
+ targets_t = np.hstack((boxes_t, landm_t, labels_t))
+
+ return image_t, targets_t
diff --git a/retinaface/data/wider_face.py b/retinaface/data/wider_face.py
new file mode 100644
index 0000000000000000000000000000000000000000..22f56efdc221bd4162d22884669ba44a3d4de5cd
--- /dev/null
+++ b/retinaface/data/wider_face.py
@@ -0,0 +1,101 @@
+import os
+import os.path
+import sys
+import torch
+import torch.utils.data as data
+import cv2
+import numpy as np
+
+class WiderFaceDetection(data.Dataset):
+ def __init__(self, txt_path, preproc=None):
+ self.preproc = preproc
+ self.imgs_path = []
+ self.words = []
+ f = open(txt_path,'r')
+ lines = f.readlines()
+ isFirst = True
+ labels = []
+ for line in lines:
+ line = line.rstrip()
+ if line.startswith('#'):
+ if isFirst is True:
+ isFirst = False
+ else:
+ labels_copy = labels.copy()
+ self.words.append(labels_copy)
+ labels.clear()
+ path = line[2:]
+ path = txt_path.replace('label.txt','images/') + path
+ self.imgs_path.append(path)
+ else:
+ line = line.split(' ')
+ label = [float(x) for x in line]
+ labels.append(label)
+
+ self.words.append(labels)
+
+ def __len__(self):
+ return len(self.imgs_path)
+
+ def __getitem__(self, index):
+ img = cv2.imread(self.imgs_path[index])
+ height, width, _ = img.shape
+
+ labels = self.words[index]
+ annotations = np.zeros((0, 15))
+ if len(labels) == 0:
+ return annotations
+ for idx, label in enumerate(labels):
+ annotation = np.zeros((1, 15))
+ # bbox
+ annotation[0, 0] = label[0] # x1
+ annotation[0, 1] = label[1] # y1
+ annotation[0, 2] = label[0] + label[2] # x2
+ annotation[0, 3] = label[1] + label[3] # y2
+
+ # landmarks
+ annotation[0, 4] = label[4] # l0_x
+ annotation[0, 5] = label[5] # l0_y
+ annotation[0, 6] = label[7] # l1_x
+ annotation[0, 7] = label[8] # l1_y
+ annotation[0, 8] = label[10] # l2_x
+ annotation[0, 9] = label[11] # l2_y
+ annotation[0, 10] = label[13] # l3_x
+ annotation[0, 11] = label[14] # l3_y
+ annotation[0, 12] = label[16] # l4_x
+ annotation[0, 13] = label[17] # l4_y
+ if (annotation[0, 4]<0):
+ annotation[0, 14] = -1
+ else:
+ annotation[0, 14] = 1
+
+ annotations = np.append(annotations, annotation, axis=0)
+ target = np.array(annotations)
+ if self.preproc is not None:
+ img, target = self.preproc(img, target)
+
+ return torch.from_numpy(img), target
+
+def detection_collate(batch):
+ """Custom collate fn for dealing with batches of images that have a different
+ number of associated object annotations (bounding boxes).
+
+ Arguments:
+ batch: (tuple) A tuple of tensor images and lists of annotations
+
+ Return:
+ A tuple containing:
+ 1) (tensor) batch of images stacked on their 0 dim
+ 2) (list of tensors) annotations for a given image are stacked on 0 dim
+ """
+ targets = []
+ imgs = []
+ for _, sample in enumerate(batch):
+ for _, tup in enumerate(sample):
+ if torch.is_tensor(tup):
+ imgs.append(tup)
+ elif isinstance(tup, type(np.empty(0))):
+ annos = torch.from_numpy(tup).float()
+ targets.append(annos)
+
+ return (torch.stack(imgs, 0), targets)
diff --git a/retinaface/detect.py b/retinaface/detect.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc622e4deef728d1153254bfa6cf0d4ef50352aa
--- /dev/null
+++ b/retinaface/detect.py
@@ -0,0 +1,152 @@
+import torch
+import torch.nn.functional as F
+import cv2
+import numpy as np
+import timeit
+
+import imutils
+from utils.infer_utils import load_model
+from data import cfg_mnet as cfg
+from models.retinaface import RetinaFace
+from layers.functions.prior_box import PriorBox
+from utils.box_utils import decode, decode_landm
+from utils.nms.py_cpu_nms import py_cpu_nms
+from utils.infer_utils import align_face
+torch.set_grad_enabled(False)
+
+
+class Detect:
+ def __init__(self, weight_path, net_inshape=(180, 320)):
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.net_inshape = net_inshape
+ im_height, im_width = net_inshape
+ self.box_scale = np.array([im_width, im_height] * 2)
+ self.lmk_scale = np.array([im_width, im_height] * 5)
+
+ priorbox = PriorBox(cfg, image_size=net_inshape)
+ priors = priorbox.forward()
+ self.prior_data = priors.to(self.device)
+ self.net = RetinaFace(cfg=cfg, phase='test')
+ self.net = load_model(self.net, weight_path, False)
+ self.net.eval()
+ self.net = self.net.to(self.device)
+
+
+ def _preprocess(self, image):
+ rgb_mean = (104, 117, 123) # bgr order
+ h, w = image.shape[:2]
+ dx = int(self.net_inshape[1] * h / self.net_inshape[0] - w)
+ dy = 0
+ if dx < 0:
+ dx = 0
+ dy = int(self.net_inshape[0] * w / self.net_inshape[1] - h)
+ img = cv2.copyMakeBorder(image, 0, dy, 0, dx, borderType=cv2.BORDER_CONSTANT, value=rgb_mean)
+ img = cv2.copyMakeBorder(img, 0, img.shape[0], 0, img.shape[1], borderType=cv2.BORDER_CONSTANT, value=rgb_mean)
+
+ h, w = img.shape[:2]
+ resize = float(self.net_inshape[1]) / float(w)
+ img = cv2.resize(img, self.net_inshape[::-1])
+ img = np.float32(img)
+ img -= rgb_mean
+
+ return img, resize
+
+
+ def __call__(self, img, verbose=False):
+ '''
+ bgr image
+ '''
+ t0 = timeit.default_timer()
+ img, resize = self._preprocess(img)
+ img = img.transpose(2, 0, 1)
+ img = torch.from_numpy(img).unsqueeze(0)
+ img = img.to(self.device)
+
+ t1 = timeit.default_timer()
+ loc, conf, landms = self.net(img)
+ loc = [i.permute(0,2,3,1).contiguous().view(i.shape[0], -1, 4) for i in loc]
+ loc = torch.cat(loc, dim=1)
+ conf = [i.permute(0,2,3,1).contiguous().view(i.shape[0], -1, 3) for i in conf]
+ conf = torch.cat(conf, dim=1)
+ landms = [i.permute(0,2,3,1).contiguous().view(i.shape[0], -1, 10) for i in landms]
+ landms = torch.cat(landms, dim=1)
+ conf = F.softmax(conf, dim=-1)
+
+ t2 = timeit.default_timer()
+ conf = conf[0]
+ scores = conf.squeeze(0).detach().cpu().numpy()[:, 1:]
+ scores = np.amax(scores, axis=1)
+
+ boxes = decode(loc[0], self.prior_data, cfg['variance']) # loc[0]
+ boxes = boxes.detach().cpu().numpy()
+ boxes = boxes * self.box_scale / resize
+
+ landms = decode_landm(landms[0], self.prior_data, cfg['variance'])
+ landms = landms.detach().cpu().numpy()
+ landms = landms * self.lmk_scale / resize
+
+ # ignore low scores
+ inds = np.where(scores > 0.02)[0]
+ boxes = boxes[inds]
+ landms = landms[inds]
+ scores = scores[inds]
+
+ # keep top-K before NMS
+ order = scores.argsort()[::-1][:5000]
+ boxes = boxes[order]
+ landms = landms[order]
+ scores = scores[order]
+
+ # do NMS
+ dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
+ keep = py_cpu_nms(dets, 0.4)
+ # keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)
+ dets = dets[keep, :]
+ landms = landms[keep]
+
+ # keep top-K faster NMS
+ dets = dets[:750, :]
+ landms = landms[:750, :]
+
+ dets = np.concatenate((dets, landms), axis=1)
+ dets = dets[dets[:, 4] > 0.5]
+ dets = dets[np.argsort(dets, axis=0)[:, 0]]
+
+ t3 = timeit.default_timer()
+ if verbose:
+ print(t1 - t0, t2 - t1, t3 - t2)
+
+ return dets # (n, 15), box=0-3, cls=4, lmk=5-10
+
+
+if __name__ == "__main__":
+ net_inshape = (486, 864) # h, w
+ model = Detect("/mnt/nvme0n1p2/ExternalHardrive/research/object_detection/face/Face-Detector-1MB-with-landmark-clear/weights/mobilenet0.25_epoch_842.pth", net_inshape=net_inshape)
+ image_path = "/mnt/nvme0n1p2/datasets/face/dyno/mytelpay230626/mytelpay230626_raw/data_2nd/၁၀ကတန(နိုင်)၀၀၁၀၀၁/mytel_ekyc_1m2_65160cc1802b6183d87fca091cab4c2faa93a9b1614106b5911ca778_front_image.jpg"
+ img = cv2.imread(image_path)
+
+ dets = model(img)
+ for i, b in enumerate(dets):
+ text = "{:.4f}".format(b[4])
+ b = b.astype(np.int32)
+ landm = b[5:15]
+ landm = landm.reshape((5, 2))
+
+ alighed_face = align_face(img, landm.copy())
+ # cv2.imshow(str(i), alighed_face)
+
+ # landms
+ landm = landm.astype(np.int32)
+ cv2.circle(img, tuple(landm[0]), 1, (0, 0, 255), 2)
+ cv2.circle(img, tuple(landm[1]), 1, (0, 255, 255), 2)
+ cv2.circle(img, tuple(landm[2]), 1, (255, 0, 255), 2)
+ cv2.circle(img, tuple(landm[3]), 1, (0, 255, 0), 2)
+ cv2.circle(img, tuple(landm[4]), 1, (255, 0, 0), 2)
+
+ cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
+ cx = b[0]
+ cy = b[1] + 20
+ cv2.putText(img, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 1.3, (255, 255, 255))
+
+ cv2.imwrite("./output.jpg", img)
+
\ No newline at end of file
diff --git a/retinaface/detect_video_raw.py b/retinaface/detect_video_raw.py
new file mode 100644
index 0000000000000000000000000000000000000000..231a7932b639655d2b38ed3f8f814e249d993c07
--- /dev/null
+++ b/retinaface/detect_video_raw.py
@@ -0,0 +1,66 @@
+import cv2
+import numpy as np
+import imutils
+
+from utils.fps import FPS
+from utils.infer_utils import LoadStream, align_face
+from detect import Detect
+
+
+net_inshape = (486, 864) # h, w
+model = Detect("/mnt/nvme0n1p2/ExternalHardrive/research/object_detection/face/Face-Detector-1MB-with-landmark-clear/weights/mobilenet0.25_epoch_842.pth", net_inshape=net_inshape)
+# dataloader = LoadStream("rtsp://admin:meditech123@192.168.100.90:555/")
+dataloader = LoadStream("../30Shine_1.mp4")
+fps = FPS().start()
+
+for frame in dataloader:
+ # frame = imutils.resize(frame, width=640)
+ frame = frame.copy()
+ frame_raw = frame.copy()
+
+
+ dets = model(frame)
+ for i, b in enumerate(dets):
+ text = "{:.4f}".format(b[4])
+ b = b.astype(np.int32)
+ landm = b[5:15]
+ landm = landm.reshape((5, 2))
+
+ alighed_face = align_face(frame, landm.copy())
+ # cv2.imshow(str(i), alighed_face)
+
+ # landms
+ landm = landm.astype(np.int32)
+ cv2.circle(frame, tuple(landm[0]), 1, (0, 0, 255), 2)
+ cv2.circle(frame, tuple(landm[1]), 1, (0, 255, 255), 2)
+ cv2.circle(frame, tuple(landm[2]), 1, (255, 0, 255), 2)
+ cv2.circle(frame, tuple(landm[3]), 1, (0, 255, 0), 2)
+ cv2.circle(frame, tuple(landm[4]), 1, (255, 0, 0), 2)
+
+ cv2.rectangle(frame, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
+ cx = b[0]
+ cy = b[1] + 20
+ cv2.putText(frame, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 1.3, (255, 255, 255))
+
+ fps.update()
+ text_fps = "FPS: {:.3f}".format(fps.get_fps_n())
+ cv2.putText(frame, text_fps, (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
+ cv2.imshow("frame", imutils.resize(frame, width=1700))
+ key = cv2.waitKey(1) & 0xff
+ if key == ord("q"):
+ break
+ elif key == ord("c"):
+ while True:
+ cv2.imshow("frame", imutils.resize(frame, width=1700))
+ key = cv2.waitKey(1) & 0xff
+ if key == ord("q"):
+ break
+ # cv2.imwrite(f"{i}.jpg", alighed_face)
+ # i += 1
+ # # break
+
+print(text_fps)
+cv2.destroyAllWindows()
+fps.stop()
+print("Total FPS: {}".format(fps.fps()))
+dataloader.close()
\ No newline at end of file
diff --git a/retinaface/layers/__init__.py b/retinaface/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53a3f4b5160995d93bc7911e808b3045d74362c9
--- /dev/null
+++ b/retinaface/layers/__init__.py
@@ -0,0 +1,2 @@
+from .functions import *
+from .modules import *
diff --git a/retinaface/layers/functions/prior_box.py b/retinaface/layers/functions/prior_box.py
new file mode 100644
index 0000000000000000000000000000000000000000..683156a298d6f9440be64c611988b04aea97ec49
--- /dev/null
+++ b/retinaface/layers/functions/prior_box.py
@@ -0,0 +1,33 @@
+import torch
+from itertools import product as product
+import numpy as np
+from math import ceil
+
+
+class PriorBox(object):
+ def __init__(self, cfg, image_size=None, phase='train'):
+ super(PriorBox, self).__init__()
+ self.min_sizes = cfg['min_sizes']
+ self.steps = cfg['steps']
+ self.clip = cfg['clip']
+ self.image_size = image_size
+ self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
+
+ def forward(self):
+ anchors = []
+ for k, f in enumerate(self.feature_maps):
+ min_sizes = self.min_sizes[k]
+ for i, j in product(range(f[0]), range(f[1])):
+ for min_size in min_sizes:
+ s_kx = min_size / self.image_size[1]
+ s_ky = min_size / self.image_size[0]
+ dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
+ dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
+ for cy, cx in product(dense_cy, dense_cx):
+ anchors += [cx, cy, s_kx, s_ky]
+
+ # back to torch land
+ output = torch.Tensor(anchors).view(-1, 4)
+ if self.clip:
+ output.clamp_(max=1, min=0)
+ return output
diff --git a/retinaface/layers/modules/__init__.py b/retinaface/layers/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf24bddbf283f233d0b93fc074a2bac2f5c044a9
--- /dev/null
+++ b/retinaface/layers/modules/__init__.py
@@ -0,0 +1,3 @@
+from .multibox_loss import MultiBoxLoss
+
+__all__ = ['MultiBoxLoss']
diff --git a/retinaface/layers/modules/multibox_loss.py b/retinaface/layers/modules/multibox_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..096620480eba59e9d893c1940899f7e3d6736cae
--- /dev/null
+++ b/retinaface/layers/modules/multibox_loss.py
@@ -0,0 +1,125 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from utils.box_utils import match, log_sum_exp
+from data import cfg_mnet
+GPU = cfg_mnet['gpu_train']
+
+class MultiBoxLoss(nn.Module):
+ """SSD Weighted Loss Function
+ Compute Targets:
+ 1) Produce Confidence Target Indices by matching ground truth boxes
+ with (default) 'priorboxes' that have jaccard index > threshold parameter
+ (default threshold: 0.5).
+ 2) Produce localization target by 'encoding' variance into offsets of ground
+ truth boxes and their matched 'priorboxes'.
+ 3) Hard negative mining to filter the excessive number of negative examples
+ that comes with using a large number of default bounding boxes.
+ (default negative:positive ratio 3:1)
+ Objective Loss:
+ L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
+ Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
+ weighted by α which is set to 1 by cross val.
+ Args:
+ c: class confidences,
+ l: predicted boxes,
+ g: ground truth boxes
+ N: number of matched default boxes
+ See: https://arxiv.org/pdf/1512.02325.pdf for more details.
+ """
+
+ def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
+ super(MultiBoxLoss, self).__init__()
+ self.num_classes = num_classes
+ self.threshold = overlap_thresh
+ self.background_label = bkg_label
+ self.encode_target = encode_target
+ self.use_prior_for_matching = prior_for_matching
+ self.do_neg_mining = neg_mining
+ self.negpos_ratio = neg_pos
+ self.neg_overlap = neg_overlap
+ self.variance = [0.1, 0.2]
+
+ def forward(self, predictions, priors, targets):
+ """Multibox Loss
+ Args:
+ predictions (tuple): A tuple containing loc preds, conf preds,
+ and prior boxes from SSD net.
+ conf shape: torch.size(batch_size,num_priors,num_classes)
+ loc shape: torch.size(batch_size,num_priors,4)
+ priors shape: torch.size(num_priors,4)
+
+ ground_truth (tensor): Ground truth boxes and labels for a batch,
+ shape: [batch_size,num_objs,5] (last idx is the label).
+ """
+
+ loc_data, conf_data, landm_data = predictions
+ priors = priors
+ num = loc_data.size(0)
+ num_priors = (priors.size(0))
+
+ # match priors (default boxes) and ground truth boxes
+ loc_t = torch.Tensor(num, num_priors, 4)
+ landm_t = torch.Tensor(num, num_priors, 10)
+ conf_t = torch.LongTensor(num, num_priors)
+ for idx in range(num):
+ truths = targets[idx][:, :4].data
+ labels = targets[idx][:, -1].data
+ landms = targets[idx][:, 4:14].data
+ defaults = priors.data
+ match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
+ if GPU:
+ loc_t = loc_t.cuda()
+ conf_t = conf_t.cuda()
+ landm_t = landm_t.cuda()
+
+ zeros = torch.tensor(0).cuda()
+ # landm Loss (Smooth L1)
+ # Shape: [batch,num_priors,10]
+ pos1 = conf_t > zeros
+ num_pos_landm = pos1.long().sum(1, keepdim=True)
+ N1 = max(num_pos_landm.data.sum().float(), 1)
+ pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
+ landm_p = landm_data[pos_idx1].view(-1, 10)
+ landm_t = landm_t[pos_idx1].view(-1, 10)
+ loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')
+
+
+ pos = conf_t != zeros
+ conf_t[pos] = 1
+
+ # Localization Loss (Smooth L1)
+ # Shape: [batch,num_priors,4]
+ pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
+ loc_p = loc_data[pos_idx].view(-1, 4)
+ loc_t = loc_t[pos_idx].view(-1, 4)
+ loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
+
+ # Compute max conf across batch for hard negative mining
+ batch_conf = conf_data.view(-1, self.num_classes)
+ loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
+
+ # Hard Negative Mining
+ loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now
+ loss_c = loss_c.view(num, -1)
+ _, loss_idx = loss_c.sort(1, descending=True)
+ _, idx_rank = loss_idx.sort(1)
+ num_pos = pos.long().sum(1, keepdim=True)
+ num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
+ neg = idx_rank < num_neg.expand_as(idx_rank)
+
+ # Confidence Loss Including Positive and Negative Examples
+ pos_idx = pos.unsqueeze(2).expand_as(conf_data)
+ neg_idx = neg.unsqueeze(2).expand_as(conf_data)
+ conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes)
+ targets_weighted = conf_t[(pos+neg).gt(0)]
+ loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
+
+ # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
+ N = max(num_pos.data.sum().float(), 1)
+ loss_l /= N
+ loss_c /= N
+ loss_landm /= N1
+
+ return loss_l, loss_c, loss_landm
diff --git a/retinaface/models/__init__.py b/retinaface/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/retinaface/models/net.py b/retinaface/models/net.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d769cc2f0a3d3e85171d79ae3ea9396606b53d9
--- /dev/null
+++ b/retinaface/models/net.py
@@ -0,0 +1,137 @@
+import time
+import torch
+import torch.nn as nn
+import torchvision.models._utils as _utils
+import torchvision.models as models
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+def conv_bn(inp, oup, stride = 1):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU(inplace=True)
+ )
+
+def conv_bn_no_relu(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+
+def conv_bn1X1(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU(inplace=True)
+ )
+
+def conv_dw(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
+ nn.BatchNorm2d(inp),
+ nn.ReLU(inplace=True),
+
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU(inplace=True)
+ )
+
+class SSH(nn.Module):
+ def __init__(self, in_channel, out_channel):
+ super(SSH, self).__init__()
+ assert out_channel % 4 == 0
+ leaky = 0
+ if (out_channel <= 64):
+ leaky = 0.1
+ self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1)
+
+ self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1)
+ self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
+
+ self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1)
+ self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
+
+ def forward(self, input):
+ conv3X3 = self.conv3X3(input)
+
+ conv5X5_1 = self.conv5X5_1(input)
+ conv5X5 = self.conv5X5_2(conv5X5_1)
+
+ conv7X7_2 = self.conv7X7_2(conv5X5_1)
+ conv7X7 = self.conv7x7_3(conv7X7_2)
+
+ out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
+ out = F.relu(out)
+ return out
+
+class FPN(nn.Module):
+ def __init__(self,in_channels_list,out_channels):
+ super(FPN,self).__init__()
+ leaky = 0
+ if (out_channels <= 64):
+ leaky = 0.1
+ self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1)
+ self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1)
+ self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1)
+
+ self.merge1 = conv_bn(out_channels, out_channels)
+ self.merge2 = conv_bn(out_channels, out_channels)
+
+ def forward(self, input):
+ # names = list(input.keys())
+ input = list(input.values())
+
+ output1 = self.output1(input[0])
+ output2 = self.output2(input[1])
+ output3 = self.output3(input[2])
+
+ up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
+ output2 = output2 + up3
+ output2 = self.merge2(output2)
+
+ up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
+ output1 = output1 + up2
+ output1 = self.merge1(output1)
+
+ out = [output1, output2, output3]
+ return out
+
+
+
+class MobileNetV1(nn.Module):
+ def __init__(self):
+ super(MobileNetV1, self).__init__()
+ self.stage1 = nn.Sequential(
+ conv_bn(3, 8, 2), # 3
+ conv_dw(8, 16, 1), # 7
+ conv_dw(16, 32, 2), # 11
+ conv_dw(32, 32, 1), # 19
+ conv_dw(32, 64, 2), # 27
+ conv_dw(64, 64, 1), # 43
+ )
+ self.stage2 = nn.Sequential(
+ conv_dw(64, 128, 2), # 43 + 16 = 59
+ conv_dw(128, 128, 1), # 59 + 32 = 91
+ conv_dw(128, 128, 1), # 91 + 32 = 123
+ conv_dw(128, 128, 1), # 123 + 32 = 155
+ conv_dw(128, 128, 1), # 155 + 32 = 187
+ conv_dw(128, 128, 1), # 187 + 32 = 219
+ )
+ self.stage3 = nn.Sequential(
+ conv_dw(128, 256, 2), # 219 +3 2 = 241
+ conv_dw(256, 256, 1), # 241 + 64 = 301
+ )
+ self.avg = nn.AdaptiveAvgPool2d((1,1))
+ self.fc = nn.Linear(256, 1000)
+
+ def forward(self, x):
+ x = self.stage1(x)
+ x = self.stage2(x)
+ x = self.stage3(x)
+ x = self.avg(x)
+ # x = self.model(x)
+ x = x.view(-1, 256)
+ x = self.fc(x)
+ return x
+
diff --git a/retinaface/models/net_rfb.py b/retinaface/models/net_rfb.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b79e959c1133a123e4a343b27ce406c859ebf4b
--- /dev/null
+++ b/retinaface/models/net_rfb.py
@@ -0,0 +1,199 @@
+import torch
+import torch.nn as nn
+import torchvision.models.detection.backbone_utils as backbone_utils
+import torchvision.models._utils as _utils
+import torch.nn.functional as F
+from collections import OrderedDict
+
+class BasicConv(nn.Module):
+
+ def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True):
+ super(BasicConv, self).__init__()
+ self.out_channels = out_planes
+ if bn:
+ self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False)
+ self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
+ self.relu = nn.ReLU(inplace=True) if relu else None
+ else:
+ self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True)
+ self.bn = None
+ self.relu = nn.ReLU(inplace=True) if relu else None
+
+ def forward(self, x):
+ x = self.conv(x)
+ if self.bn is not None:
+ x = self.bn(x)
+ if self.relu is not None:
+ x = self.relu(x)
+ return x
+
+
+class BasicRFB(nn.Module):
+
+ def __init__(self, in_planes, out_planes, stride=1, scale=0.1, map_reduce=8, vision=1, groups=1):
+ super(BasicRFB, self).__init__()
+ self.scale = scale
+ self.out_channels = out_planes
+ inter_planes = in_planes // map_reduce
+
+ self.branch0 = nn.Sequential(
+ BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False),
+ BasicConv(inter_planes, 2 * inter_planes, kernel_size=(3, 3), stride=stride, padding=(1, 1), groups=groups),
+ BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 1, dilation=vision + 1, relu=False, groups=groups)
+ )
+ self.branch1 = nn.Sequential(
+ BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False),
+ BasicConv(inter_planes, 2 * inter_planes, kernel_size=(3, 3), stride=stride, padding=(1, 1), groups=groups),
+ BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 2, dilation=vision + 2, relu=False, groups=groups)
+ )
+ self.branch2 = nn.Sequential(
+ BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False),
+ BasicConv(inter_planes, (inter_planes // 2) * 3, kernel_size=3, stride=1, padding=1, groups=groups),
+ BasicConv((inter_planes // 2) * 3, 2 * inter_planes, kernel_size=3, stride=stride, padding=1, groups=groups),
+ BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 4, dilation=vision + 4, relu=False, groups=groups)
+ )
+
+ self.ConvLinear = BasicConv(6 * inter_planes, out_planes, kernel_size=1, stride=1, relu=False)
+ self.shortcut = BasicConv(in_planes, out_planes, kernel_size=1, stride=stride, relu=False)
+ self.relu = nn.ReLU(inplace=False)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+
+ out = torch.cat((x0, x1, x2), 1)
+ out = self.ConvLinear(out)
+ short = self.shortcut(x)
+ out = out * self.scale + short
+ out = self.relu(out)
+
+ return out
+
+
+
+def conv_bn(inp, oup, stride = 1):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU(inplace=True)
+ )
+
+def depth_conv2d(inp, oup, kernel=1, stride=1, pad=0):
+ return nn.Sequential(
+ nn.Conv2d(inp, inp, kernel_size = kernel, stride = stride, padding=pad, groups=inp),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inp, oup, kernel_size=1)
+ )
+
+def conv_dw(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
+ nn.BatchNorm2d(inp),
+ nn.ReLU(inplace=True),
+
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU(inplace=True)
+ )
+
+class RFB(nn.Module):
+ def __init__(self, cfg = None, phase = 'train'):
+ """
+ :param cfg: Network related settings.
+ :param phase: train or test.
+ """
+ super(RFB,self).__init__()
+ self.phase = phase
+ self.num_classes = 2
+
+ self.conv1 = conv_bn(3, 16, 2)
+ self.conv2 = conv_dw(16, 32, 1)
+ self.conv3 = conv_dw(32, 32, 2)
+ self.conv4 = conv_dw(32, 32, 1)
+ self.conv5 = conv_dw(32, 64, 2)
+ self.conv6 = conv_dw(64, 64, 1)
+ self.conv7 = conv_dw(64, 64, 1)
+ self.conv8 = BasicRFB(64, 64, stride=1, scale=1.0)
+
+ self.conv9 = conv_dw(64, 128, 2)
+ self.conv10 = conv_dw(128, 128, 1)
+ self.conv11 = conv_dw(128, 128, 1)
+
+ self.conv12 = conv_dw(128, 256, 2)
+ self.conv13 = conv_dw(256, 256, 1)
+
+ self.conv14 = nn.Sequential(
+ nn.Conv2d(in_channels=256, out_channels=64, kernel_size=1),
+ nn.ReLU(inplace=True),
+ depth_conv2d(64, 256, kernel=3, stride=2, pad=1),
+ nn.ReLU(inplace=True)
+ )
+ self.loc, self.conf, self.landm = self.multibox(self.num_classes);
+
+ def multibox(self, num_classes):
+ loc_layers = []
+ conf_layers = []
+ landm_layers = []
+ loc_layers += [depth_conv2d(64, 3 * 4, kernel=3, pad=1)]
+ conf_layers += [depth_conv2d(64, 3 * num_classes, kernel=3, pad=1)]
+ landm_layers += [depth_conv2d(64, 3 * 10, kernel=3, pad=1)]
+
+ loc_layers += [depth_conv2d(128, 2 * 4, kernel=3, pad=1)]
+ conf_layers += [depth_conv2d(128, 2 * num_classes, kernel=3, pad=1)]
+ landm_layers += [depth_conv2d(128, 2 * 10, kernel=3, pad=1)]
+
+ loc_layers += [depth_conv2d(256, 2 * 4, kernel=3, pad=1)]
+ conf_layers += [depth_conv2d(256, 2 * num_classes, kernel=3, pad=1)]
+ landm_layers += [depth_conv2d(256, 2 * 10, kernel=3, pad=1)]
+
+ loc_layers += [nn.Conv2d(256, 3 * 4, kernel_size=3, padding=1)]
+ conf_layers += [nn.Conv2d(256, 3 * num_classes, kernel_size=3, padding=1)]
+ landm_layers += [nn.Conv2d(256, 3 * 10, kernel_size=3, padding=1)]
+ return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers), nn.Sequential(*landm_layers)
+
+
+ def forward(self,inputs):
+ detections = list()
+ loc = list()
+ conf = list()
+ landm = list()
+
+ x1 = self.conv1(inputs)
+ x2 = self.conv2(x1)
+ x3 = self.conv3(x2)
+ x4 = self.conv4(x3)
+ x5 = self.conv5(x4)
+ x6 = self.conv6(x5)
+ x7 = self.conv7(x6)
+ x8 = self.conv8(x7)
+ detections.append(x8)
+
+ x9 = self.conv9(x8)
+ x10 = self.conv10(x9)
+ x11 = self.conv11(x10)
+ detections.append(x11)
+
+ x12 = self.conv12(x11)
+ x13 = self.conv13(x12)
+ detections.append(x13)
+
+ x14= self.conv14(x13)
+ detections.append(x14)
+
+ for (x, l, c, lam) in zip(detections, self.loc, self.conf, self.landm):
+ loc.append(l(x).permute(0, 2, 3, 1).contiguous())
+ conf.append(c(x).permute(0, 2, 3, 1).contiguous())
+ landm.append(lam(x).permute(0, 2, 3, 1).contiguous())
+
+ bbox_regressions = torch.cat([o.view(o.size(0), -1, 4) for o in loc], 1)
+ classifications = torch.cat([o.view(o.size(0), -1, 2) for o in conf], 1)
+ ldm_regressions = torch.cat([o.view(o.size(0), -1, 10) for o in landm], 1)
+
+
+
+ if self.phase == 'train':
+ output = (bbox_regressions, classifications, ldm_regressions)
+ else:
+ output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
+ return output
diff --git a/retinaface/models/net_slim.py b/retinaface/models/net_slim.py
new file mode 100644
index 0000000000000000000000000000000000000000..e74910e269df2519001c95d9e99ad577c390e08d
--- /dev/null
+++ b/retinaface/models/net_slim.py
@@ -0,0 +1,132 @@
+import torch
+import torch.nn as nn
+import torchvision.models.detection.backbone_utils as backbone_utils
+import torchvision.models._utils as _utils
+import torch.nn.functional as F
+from collections import OrderedDict
+
+def conv_bn(inp, oup, stride = 1):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU(inplace=True)
+ )
+
+def depth_conv2d(inp, oup, kernel=1, stride=1, pad=0):
+ return nn.Sequential(
+ nn.Conv2d(inp, inp, kernel_size = kernel, stride = stride, padding=pad, groups=inp),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inp, oup, kernel_size=1)
+ )
+
+def conv_dw(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
+ nn.BatchNorm2d(inp),
+ nn.ReLU(inplace=True),
+
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU(inplace=True)
+ )
+
+class Slim(nn.Module):
+ def __init__(self, cfg = None, phase = 'train'):
+ """
+ :param cfg: Network related settings.
+ :param phase: train or test.
+ """
+ super(Slim, self).__init__()
+ self.phase = phase
+ self.num_classes = 2
+
+ self.conv1 = conv_bn(3, 16, 2)
+ self.conv2 = conv_dw(16, 32, 1)
+ self.conv3 = conv_dw(32, 32, 2)
+ self.conv4 = conv_dw(32, 32, 1)
+ self.conv5 = conv_dw(32, 64, 2)
+ self.conv6 = conv_dw(64, 64, 1)
+ self.conv7 = conv_dw(64, 64, 1)
+ self.conv8 = conv_dw(64, 64, 1)
+
+ self.conv9 = conv_dw(64, 128, 2)
+ self.conv10 = conv_dw(128, 128, 1)
+ self.conv11 = conv_dw(128, 128, 1)
+
+ self.conv12 = conv_dw(128, 256, 2)
+ self.conv13 = conv_dw(256, 256, 1)
+
+ self.conv14 = nn.Sequential(
+ nn.Conv2d(in_channels=256, out_channels=64, kernel_size=1),
+ nn.ReLU(inplace=True),
+ depth_conv2d(64, 256, kernel=3, stride=2, pad=1),
+ nn.ReLU(inplace=True)
+ )
+ self.loc, self.conf, self.landm = self.multibox(self.num_classes);
+
+ def multibox(self, num_classes):
+ loc_layers = []
+ conf_layers = []
+ landm_layers = []
+ loc_layers += [depth_conv2d(64, 3 * 4, kernel=3, pad=1)]
+ conf_layers += [depth_conv2d(64, 3 * num_classes, kernel=3, pad=1)]
+ landm_layers += [depth_conv2d(64, 3 * 10, kernel=3, pad=1)]
+
+ loc_layers += [depth_conv2d(128, 2 * 4, kernel=3, pad=1)]
+ conf_layers += [depth_conv2d(128, 2 * num_classes, kernel=3, pad=1)]
+ landm_layers += [depth_conv2d(128, 2 * 10, kernel=3, pad=1)]
+
+ loc_layers += [depth_conv2d(256, 2 * 4, kernel=3, pad=1)]
+ conf_layers += [depth_conv2d(256, 2 * num_classes, kernel=3, pad=1)]
+ landm_layers += [depth_conv2d(256, 2 * 10, kernel=3, pad=1)]
+
+ loc_layers += [nn.Conv2d(256, 3 * 4, kernel_size=3, padding=1)]
+ conf_layers += [nn.Conv2d(256, 3 * num_classes, kernel_size=3, padding=1)]
+ landm_layers += [nn.Conv2d(256, 3 * 10, kernel_size=3, padding=1)]
+ return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers), nn.Sequential(*landm_layers)
+
+
+ def forward(self,inputs):
+ detections = list()
+ loc = list()
+ conf = list()
+ landm = list()
+
+ x1 = self.conv1(inputs)
+ x2 = self.conv2(x1)
+ x3 = self.conv3(x2)
+ x4 = self.conv4(x3)
+ x5 = self.conv5(x4)
+ x6 = self.conv6(x5)
+ x7 = self.conv7(x6)
+ x8 = self.conv8(x7)
+ detections.append(x8)
+
+ x9 = self.conv9(x8)
+ x10 = self.conv10(x9)
+ x11 = self.conv11(x10)
+ detections.append(x11)
+
+ x12 = self.conv12(x11)
+ x13 = self.conv13(x12)
+ detections.append(x13)
+
+ x14= self.conv14(x13)
+ detections.append(x14)
+
+ for (x, l, c, lam) in zip(detections, self.loc, self.conf, self.landm):
+ loc.append(l(x).permute(0, 2, 3, 1).contiguous())
+ conf.append(c(x).permute(0, 2, 3, 1).contiguous())
+ landm.append(lam(x).permute(0, 2, 3, 1).contiguous())
+
+ bbox_regressions = torch.cat([o.view(o.size(0), -1, 4) for o in loc], 1)
+ classifications = torch.cat([o.view(o.size(0), -1, 2) for o in conf], 1)
+ ldm_regressions = torch.cat([o.view(o.size(0), -1, 10) for o in landm], 1)
+
+
+
+ if self.phase == 'train':
+ output = (bbox_regressions, classifications, ldm_regressions)
+ else:
+ output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
+ return output
diff --git a/retinaface/models/retinaface.py b/retinaface/models/retinaface.py
new file mode 100644
index 0000000000000000000000000000000000000000..e229befc93d4552eb7c443ff6232b657b9ab325e
--- /dev/null
+++ b/retinaface/models/retinaface.py
@@ -0,0 +1,131 @@
+import torch
+import torch.nn as nn
+import torchvision.models.detection.backbone_utils as backbone_utils
+import torchvision.models._utils as _utils
+import torch.nn.functional as F
+from collections import OrderedDict
+
+from models.net import MobileNetV1 as MobileNetV1
+from models.net import FPN as FPN
+from models.net import SSH as SSH
+
+
+
+class ClassHead(nn.Module):
+ def __init__(self,inchannels=512,num_anchors=3):
+ super(ClassHead,self).__init__()
+ self.num_anchors = num_anchors
+ self.conv1x1 = nn.Conv2d(inchannels,self.num_anchors*3,kernel_size=(1,1),stride=1,padding=0)
+
+ def forward(self,x):
+ out = self.conv1x1(x)
+ return out
+ # out = out.permute(0,2,3,1).contiguous()
+ # return out.view(out.shape[0], -1, 3)
+
+class BboxHead(nn.Module):
+ def __init__(self,inchannels=512,num_anchors=3):
+ super(BboxHead,self).__init__()
+ self.conv1x1 = nn.Conv2d(inchannels,num_anchors*4,kernel_size=(1,1),stride=1,padding=0)
+
+ def forward(self,x):
+ out = self.conv1x1(x)
+ return out
+ # out = out.permute(0,2,3,1).contiguous()
+ # return out.view(out.shape[0], -1, 4)
+
+class LandmarkHead(nn.Module):
+ def __init__(self,inchannels=512,num_anchors=3):
+ super(LandmarkHead,self).__init__()
+ self.conv1x1 = nn.Conv2d(inchannels,num_anchors*10,kernel_size=(1,1),stride=1,padding=0)
+
+ def forward(self,x):
+ out = self.conv1x1(x)
+ return out
+ # out = out.permute(0,2,3,1).contiguous()
+ # return out.view(out.shape[0], -1, 10)
+
+class RetinaFace(nn.Module):
+ def __init__(self, cfg = None, phase = 'train'):
+ """
+ :param cfg: Network related settings.
+ :param phase: train or test.
+ """
+ super(RetinaFace,self).__init__()
+ self.phase = phase
+ backbone = None
+ if cfg['name'] == 'mobilenet0.25':
+ backbone = MobileNetV1()
+ if cfg['pretrain']:
+ checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu'))
+ from collections import OrderedDict
+ new_state_dict = OrderedDict()
+ for k, v in checkpoint['state_dict'].items():
+ name = k[7:] # remove module.
+ new_state_dict[name] = v
+ # load params
+ backbone.load_state_dict(new_state_dict)
+ elif cfg['name'] == 'Resnet50':
+ import torchvision.models as models
+ backbone = models.resnet50(pretrained=cfg['pretrain'])
+
+ self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
+ in_channels_stage2 = cfg['in_channel']
+ in_channels_list = [
+ in_channels_stage2 * 2,
+ in_channels_stage2 * 4,
+ in_channels_stage2 * 8,
+ ]
+ out_channels = cfg['out_channel']
+ self.fpn = FPN(in_channels_list,out_channels)
+ self.ssh1 = SSH(out_channels, out_channels)
+ self.ssh2 = SSH(out_channels, out_channels)
+ self.ssh3 = SSH(out_channels, out_channels)
+
+ self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
+ self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
+ self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
+
+ def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2):
+ classhead = nn.ModuleList()
+ for i in range(fpn_num):
+ classhead.append(ClassHead(inchannels,anchor_num))
+ return classhead
+
+ def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2):
+ bboxhead = nn.ModuleList()
+ for i in range(fpn_num):
+ bboxhead.append(BboxHead(inchannels,anchor_num))
+ return bboxhead
+
+ def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2):
+ landmarkhead = nn.ModuleList()
+ for i in range(fpn_num):
+ landmarkhead.append(LandmarkHead(inchannels,anchor_num))
+ return landmarkhead
+
+ def forward(self,inputs):
+ out = self.body(inputs)
+
+ # FPN
+ fpn = self.fpn(out)
+
+ # SSH
+ feature1 = self.ssh1(fpn[0])
+ feature2 = self.ssh2(fpn[1])
+ feature3 = self.ssh3(fpn[2])
+ features = [feature1, feature2, feature3]
+
+ # return features
+
+ bbox_regressions = [self.BboxHead[i](feature) for i, feature in enumerate(features)]
+ classifications = [self.ClassHead[i](feature) for i, feature in enumerate(features)]
+ ldm_regressions = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
+
+ if self.phase == 'train':
+ output = (bbox_regressions, classifications, ldm_regressions)
+ else:
+ # bbox_regressions = decode_batch(bbox_regressions, self.prior_data, self.variance)
+ # ldm_regressions = decode_landm_batch(ldm_regressions, self.prior_data, self.variance)
+ output = (bbox_regressions, classifications, ldm_regressions)
+ return output
\ No newline at end of file
diff --git a/retinaface/output.jpg b/retinaface/output.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e3638f1d15adc20cea13ed8393c7db5418706252
Binary files /dev/null and b/retinaface/output.jpg differ
diff --git a/retinaface/utils/__init__.py b/retinaface/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/retinaface/utils/box_utils.py b/retinaface/utils/box_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..56562f8dde8da2f39600f4fc8b646cf3b4991f27
--- /dev/null
+++ b/retinaface/utils/box_utils.py
@@ -0,0 +1,379 @@
+import torch
+import numpy as np
+
+
+def point_form(boxes):
+ """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
+ representation for comparison to point form ground truth data.
+ Args:
+ boxes: (tensor) center-size default boxes from priorbox layers.
+ Return:
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
+ """
+ return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin
+ boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax
+
+
+def center_size(boxes):
+ """ Convert prior_boxes to (cx, cy, w, h)
+ representation for comparison to center-size form ground truth data.
+ Args:
+ boxes: (tensor) point_form boxes
+ Return:
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
+ """
+ return torch.cat((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy
+ boxes[:, 2:] - boxes[:, :2], 1) # w, h
+
+
+def intersect(box_a, box_b):
+ """ We resize both tensors to [A,B,2] without new malloc:
+ [A,2] -> [A,1,2] -> [A,B,2]
+ [B,2] -> [1,B,2] -> [A,B,2]
+ Then we compute the area of intersect between box_a and box_b.
+ Args:
+ box_a: (tensor) bounding boxes, Shape: [A,4].
+ box_b: (tensor) bounding boxes, Shape: [B,4].
+ Return:
+ (tensor) intersection area, Shape: [A,B].
+ """
+ A = box_a.size(0)
+ B = box_b.size(0)
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
+ box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
+ box_b[:, :2].unsqueeze(0).expand(A, B, 2))
+ inter = torch.clamp((max_xy - min_xy), min=0)
+ return inter[:, :, 0] * inter[:, :, 1]
+
+
+def jaccard(box_a, box_b):
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
+ is simply the intersection over union of two boxes. Here we operate on
+ ground truth boxes and default boxes.
+ E.g.:
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
+ Args:
+ box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
+ box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
+ Return:
+ jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
+ """
+ inter = intersect(box_a, box_b)
+ area_a = ((box_a[:, 2]-box_a[:, 0]) *
+ (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
+ area_b = ((box_b[:, 2]-box_b[:, 0]) *
+ (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
+ union = area_a + area_b - inter
+ return inter / union # [A,B]
+
+
+def matrix_iou(a, b):
+ """
+ return iou of a and b, numpy version for data augenmentation
+ """
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+ return area_i / (area_a[:, np.newaxis] + area_b - area_i)
+
+
+def matrix_iof(a, b):
+ """
+ return iof of a and b, numpy version for data augenmentation
+ """
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ return area_i / np.maximum(area_a[:, np.newaxis], 1)
+
+
+def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
+ """Match each prior box with the ground truth box of the highest jaccard
+ overlap, encode the bounding boxes, then return the matched indices
+ corresponding to both confidence and location preds.
+ Args:
+ threshold: (float) The overlap threshold used when mathing boxes.
+ truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
+ priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
+ variances: (tensor) Variances corresponding to each prior coord,
+ Shape: [num_priors, 4].
+ labels: (tensor) All the class labels for the image, Shape: [num_obj].
+ landms: (tensor) Ground truth landms, Shape [num_obj, 10].
+ loc_t: (tensor) Tensor to be filled w/ endcoded location targets.
+ conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
+ landm_t: (tensor) Tensor to be filled w/ endcoded landm targets.
+ idx: (int) current batch index
+ Return:
+ The matched indices corresponding to 1)location 2)confidence 3)landm preds.
+ """
+ # jaccard index
+ overlaps = jaccard(
+ truths,
+ point_form(priors)
+ )
+ # (Bipartite Matching)
+ # [1,num_objects] best prior for each ground truth
+ best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
+
+ # ignore hard gt
+ valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
+ best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
+ if best_prior_idx_filter.shape[0] <= 0:
+ loc_t[idx] = 0
+ conf_t[idx] = 0
+ return
+
+ # [1,num_priors] best ground truth for each prior
+ best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
+ best_truth_idx.squeeze_(0)
+ best_truth_overlap.squeeze_(0)
+ best_prior_idx.squeeze_(1)
+ best_prior_idx_filter.squeeze_(1)
+ best_prior_overlap.squeeze_(1)
+ best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
+ # TODO refactor: index best_prior_idx with long tensor
+ # ensure every gt matches with its prior of max overlap
+ for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes
+ best_truth_idx[best_prior_idx[j]] = j
+ matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来
+ conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来
+ conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本
+ loc = encode(matches, priors, variances)
+
+ matches_landm = landms[best_truth_idx]
+ landm = encode_landm(matches_landm, priors, variances)
+ loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
+ conf_t[idx] = conf # [num_priors] top class label for each prior
+ landm_t[idx] = landm
+
+
+def encode(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 4].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded boxes (tensor), Shape: [num_priors, 4]
+ """
+
+ # dist b/t match center and prior's center
+ g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
+ # encode variance
+ g_cxcy /= (variances[0] * priors[:, 2:])
+ # match wh / prior wh
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
+ g_wh = torch.log(g_wh) / variances[1]
+ # return target for smooth_l1_loss
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
+
+def encode_landm(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 10].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded landm (tensor), Shape: [num_priors, 10]
+ """
+
+ # dist b/t match center and prior's center
+ matched = torch.reshape(matched, (matched.size(0), 5, 2))
+ priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
+ g_cxcy = matched[:, :, :2] - priors[:, :, :2]
+ # encode variance
+ g_cxcy /= (variances[0] * priors[:, :, 2:])
+ # g_cxcy /= priors[:, :, 2:]
+ g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
+ # return target for smooth_l1_loss
+ return g_cxcy
+
+
+# Adapted from https://github.com/Hakuyume/chainer-ssd
+def decode(loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ loc (tensor): location predictions for loc layers,
+ Shape: [num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+
+ boxes = torch.cat((
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
+ boxes[:, :2] -= boxes[:, 2:] / 2
+ boxes[:, 2:] += boxes[:, :2]
+ return boxes
+
+
+def decode_np(loc, priors, variances):
+ boxes = np.concatenate((
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, 2:] * np.exp(loc[:, 2:] * variances[1])), 1)
+ boxes[:, :2] -= boxes[:, 2:] / 2
+ boxes[:, 2:] += boxes[:, :2]
+ return boxes
+
+
+def decode_landm(pre, priors, variances):
+ """Decode landm from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ pre (tensor): landm predictions for loc layers,
+ Shape: [num_priors,10]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded landm predictions
+ """
+ landms = torch.cat((priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
+ ), dim=1)
+ return landms
+
+
+def decode_landm_np(pre, priors, variances):
+ landms = np.concatenate((priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
+ ), axis=1)
+ return landms
+
+
+def decode_batch(loc, priors, variances=[0.1, 0.2]):
+ priors = priors.expand(loc.shape[0], loc.shape[1], 4)
+ boxes = torch.cat((
+ priors[..., :2] + loc[..., :2] * variances[0] * priors[..., 2:],
+ priors[..., 2:] * torch.exp(loc[..., 2:] * variances[1])), 2)
+
+ # boxes = torch.cat((
+ # boxes[..., :2] - boxes[..., 2:] / 2,
+ # boxes[..., 2:]), 2
+ # )
+ # boxes = torch.cat((
+ # boxes[..., :2],
+ # boxes[..., 2:] + boxes[..., :2]), 2
+ # )
+ boxes[..., :2] -= boxes[..., 2:] / 2
+ boxes[..., 2:] += boxes[..., :2]
+ return boxes
+
+def decode_landm_batch(pre, priors, variances=[0.1, 0.2]):
+ priors = priors.expand(pre.shape[0], pre.shape[1], 4)
+ landms = torch.cat((priors[..., :2] + pre[..., :2] * variances[0] * priors[..., 2:],
+ priors[..., :2] + pre[..., 2:4] * variances[0] * priors[..., 2:],
+ priors[..., :2] + pre[..., 4:6] * variances[0] * priors[..., 2:],
+ priors[..., :2] + pre[..., 6:8] * variances[0] * priors[..., 2:],
+ priors[..., :2] + pre[..., 8:10] * variances[0] * priors[..., 2:],
+ ), dim=2)
+ return landms
+
+
+def log_sum_exp(x):
+ """Utility function for computing log_sum_exp while determining
+ This will be used to determine unaveraged confidence loss across
+ all examples in a batch.
+ Args:
+ x (Variable(tensor)): conf_preds from conf layers
+ """
+ x_max = x.data.max()
+ return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max
+
+
+# Original author: Francisco Massa:
+# https://github.com/fmassa/object-detection.torch
+# Ported to PyTorch by Max deGroot (02/01/2017)
+def nms(boxes, scores, overlap=0.5, top_k=200):
+ """Apply non-maximum suppression at test time to avoid detecting too many
+ overlapping bounding boxes for a given object.
+ Args:
+ boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
+ scores: (tensor) The class predscores for the img, Shape:[num_priors].
+ overlap: (float) The overlap thresh for suppressing unnecessary boxes.
+ top_k: (int) The Maximum number of box preds to consider.
+ Return:
+ The indices of the kept boxes with respect to num_priors.
+ """
+
+ keep = torch.Tensor(scores.size(0)).fill_(0).long()
+ if boxes.numel() == 0:
+ return keep
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+ area = torch.mul(x2 - x1, y2 - y1)
+ v, idx = scores.sort(0) # sort in ascending order
+ # I = I[v >= 0.01]
+ idx = idx[-top_k:] # indices of the top-k largest vals
+ xx1 = boxes.new()
+ yy1 = boxes.new()
+ xx2 = boxes.new()
+ yy2 = boxes.new()
+ w = boxes.new()
+ h = boxes.new()
+
+ # keep = torch.Tensor()
+ count = 0
+ while idx.numel() > 0:
+ i = idx[-1] # index of current largest val
+ # keep.append(i)
+ keep[count] = i
+ count += 1
+ if idx.size(0) == 1:
+ break
+ idx = idx[:-1] # remove kept element from view
+ # load bboxes of next highest vals
+ torch.index_select(x1, 0, idx, out=xx1)
+ torch.index_select(y1, 0, idx, out=yy1)
+ torch.index_select(x2, 0, idx, out=xx2)
+ torch.index_select(y2, 0, idx, out=yy2)
+ # store element-wise max with next highest score
+ xx1 = torch.clamp(xx1, min=x1[i])
+ yy1 = torch.clamp(yy1, min=y1[i])
+ xx2 = torch.clamp(xx2, max=x2[i])
+ yy2 = torch.clamp(yy2, max=y2[i])
+ w.resize_as_(xx2)
+ h.resize_as_(yy2)
+ w = xx2 - xx1
+ h = yy2 - yy1
+ # check sizes of xx1 and xx2.. after each iteration
+ w = torch.clamp(w, min=0.0)
+ h = torch.clamp(h, min=0.0)
+ inter = w*h
+ # IoU = i / (area(a) + area(b) - i)
+ rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
+ union = (rem_areas - inter) + area[i]
+ IoU = inter/union # store result in iou
+ # keep only elements with an IoU <= overlap
+ idx = idx[IoU.le(overlap)]
+ return keep, count
+
+
diff --git a/retinaface/utils/fps.py b/retinaface/utils/fps.py
new file mode 100644
index 0000000000000000000000000000000000000000..b48b37e4ccb8cd9883fdb8972cb123948642c901
--- /dev/null
+++ b/retinaface/utils/fps.py
@@ -0,0 +1,46 @@
+# import the necessary packages
+import datetime
+
+class FPS:
+ def __init__(self, nframes=20):
+ # store the start time, end time, and total number of frames
+ # that were examined between the start and end intervals
+ self._start = None
+ self._end = None
+ self._numFrames = 0
+
+ self._nframes = nframes
+ self._start_n = None
+ self.fps_n = 0.0
+
+ def start(self):
+ # start the timer
+ self._start = datetime.datetime.now()
+ return self
+
+ def stop(self):
+ # stop the timer
+ self._end = datetime.datetime.now()
+
+ def update(self):
+ # increment the total number of frames examined during the
+ # start and end intervals
+ if (self._numFrames) % self._nframes == 0:
+ self._start_n = datetime.datetime.now()
+
+ self._numFrames += 1
+
+ def elapsed(self):
+ # return the total number of seconds between the start and
+ # end interval
+ return (self._end - self._start).total_seconds()
+
+ def fps(self):
+ # compute the (approximate) frames per second
+ return self._numFrames / self.elapsed()
+
+ def get_fps_n(self):
+ # run after update
+ if (self._numFrames) % self._nframes == 0:
+ self.fps_n = self._nframes / ((datetime.datetime.now() - self._start_n).total_seconds())
+ return self.fps_n
\ No newline at end of file
diff --git a/retinaface/utils/infer_utils.py b/retinaface/utils/infer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f15c78ac61afb989cc1f0b25fdb1ddb97738098
--- /dev/null
+++ b/retinaface/utils/infer_utils.py
@@ -0,0 +1,148 @@
+import numpy as np
+import cv2
+# import tensorflow as tf
+# from tensorflow_serving.apis.predict_pb2 import PredictRequest
+# import grpc
+# from tensorflow_serving.apis import prediction_service_pb2_grpc
+# import time
+# from tensorflow.keras.models import model_from_json
+from imutils.paths import list_images
+import os
+import pickle
+import torch
+from skimage import transform
+import cv2
+from pathlib import Path
+from imutils.video import VideoStream, FileVideoStream
+
+
+def check_keys(model, pretrained_state_dict):
+ ckpt_keys = set(pretrained_state_dict.keys())
+ model_keys = set(model.state_dict().keys())
+ used_pretrained_keys = model_keys & ckpt_keys
+ unused_pretrained_keys = ckpt_keys - model_keys
+ missing_keys = model_keys - ckpt_keys
+ print('Missing keys:{}'.format(len(missing_keys)))
+ print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
+ print('Used keys:{}'.format(len(used_pretrained_keys)))
+ assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
+ return True
+
+
+def remove_prefix(state_dict, prefix):
+ ''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
+ print('remove prefix \'{}\''.format(prefix))
+ f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
+ return {f(key): value for key, value in state_dict.items()}
+
+
+def load_model(model, pretrained_path, load_to_cpu):
+ print('Loading pretrained model from {}'.format(pretrained_path))
+ if load_to_cpu:
+ pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
+ else:
+ device = torch.cuda.current_device()
+ pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
+ if "state_dict" in pretrained_dict.keys():
+ pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
+ else:
+ pretrained_dict = remove_prefix(pretrained_dict, 'module.')
+ check_keys(model, pretrained_dict)
+ model.load_state_dict(pretrained_dict, strict=False)
+ return model
+
+
+def judge_side_face(facial_landmarks):
+ wide_dist = np.linalg.norm(facial_landmarks[0] - facial_landmarks[1])
+ high_dist = np.linalg.norm(facial_landmarks[0] - facial_landmarks[3])
+ dist_rate = high_dist / wide_dist
+
+ # cal std
+ vec_A = facial_landmarks[0] - facial_landmarks[2]
+ vec_B = facial_landmarks[1] - facial_landmarks[2]
+ vec_C = facial_landmarks[3] - facial_landmarks[2]
+ vec_D = facial_landmarks[4] - facial_landmarks[2]
+ dist_A = np.linalg.norm(vec_A)
+ dist_B = np.linalg.norm(vec_B)
+ dist_C = np.linalg.norm(vec_C)
+ dist_D = np.linalg.norm(vec_D)
+
+ # cal rate
+ high_rate = dist_A / dist_C
+ width_rate = dist_C / dist_D
+ high_ratio_variance = np.fabs(high_rate - 1.1) # smaller is better
+ width_ratio_variance = np.fabs(width_rate - 1)
+
+ if dist_rate < 1.3 and width_ratio_variance < 0.8:
+ return True, dist_rate, high_ratio_variance
+ return False, dist_rate, high_ratio_variance
+
+
+def align_face(cv_img, dst):
+ """align face theo widerface
+
+ Arguments:
+ cv_img {arr} -- Ảnh gốc
+ dst {arr}} -- landmark 5 điểm theo mtcnn
+
+ Returns:
+ arr -- Ảnh face đã align
+ """
+ face_img = np.zeros((112,112), dtype=np.uint8)
+ # Matrix standard lanmark same wider dataset
+ src = np.array([
+ [38.2946, 51.6963],
+ [73.5318, 51.5014],
+ [56.0252, 71.7366],
+ [41.5493, 92.3655],
+ [70.7299, 92.2041] ], dtype=np.float32)
+
+ tform = transform.SimilarityTransform()
+ tform.estimate(dst, src)
+ M = tform.params[0:2,:]
+ face_img = cv2.warpAffine(cv_img, M, (112,112), borderValue=0.0)
+ return face_img
+
+
+class LoadImages:
+ def __init__(self, dir):
+ self.image_paths = list(Path(dir).glob("*.jpg"))
+ self.index = -1
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ self.index += 1
+ if self.index >= len(self.image_paths):
+ self.index = -1
+ raise StopIteration
+ else:
+ image_path = str(self.image_paths[self.index])
+ image = cv2.imread(str(image_path))
+ return image
+
+ def close(self):
+ pass
+
+
+class LoadStream:
+ def __init__(self, path):
+ self.isfile = True
+ if "rtsp" in path:
+ self.stream = VideoStream(path).start()
+ self.isfile = False
+ else:
+ self.stream = FileVideoStream(path).start()
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ image = self.stream.read()
+ if self.isfile and image is None:
+ raise StopIteration
+ return image
+
+ def close(self):
+ self.stream.stop()
diff --git a/retinaface/utils/nms/__init__.py b/retinaface/utils/nms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/retinaface/utils/nms/py_cpu_nms.py b/retinaface/utils/nms/py_cpu_nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..54e7b25fef72b518df6dcf8d6fb78b986796c6e3
--- /dev/null
+++ b/retinaface/utils/nms/py_cpu_nms.py
@@ -0,0 +1,38 @@
+# --------------------------------------------------------
+# Fast R-CNN
+# Copyright (c) 2015 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ross Girshick
+# --------------------------------------------------------
+
+import numpy as np
+
+def py_cpu_nms(dets, thresh):
+ """Pure Python NMS baseline."""
+ x1 = dets[:, 0]
+ y1 = dets[:, 1]
+ x2 = dets[:, 2]
+ y2 = dets[:, 3]
+ scores = dets[:, 4]
+
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+ inds = np.where(ovr <= thresh)[0]
+ order = order[inds + 1]
+
+ return keep
diff --git a/retinaface/utils/timer.py b/retinaface/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4b3b8098a5ad41f8d18d42b6b2fedb694aa5508
--- /dev/null
+++ b/retinaface/utils/timer.py
@@ -0,0 +1,40 @@
+# --------------------------------------------------------
+# Fast R-CNN
+# Copyright (c) 2015 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ross Girshick
+# --------------------------------------------------------
+
+import time
+
+
+class Timer(object):
+ """A simple timer."""
+ def __init__(self):
+ self.total_time = 0.
+ self.calls = 0
+ self.start_time = 0.
+ self.diff = 0.
+ self.average_time = 0.
+
+ def tic(self):
+ # using time.time instead of time.clock because time time.clock
+ # does not normalize for multithreading
+ self.start_time = time.time()
+
+ def toc(self, average=True):
+ self.diff = time.time() - self.start_time
+ self.total_time += self.diff
+ self.calls += 1
+ self.average_time = self.total_time / self.calls
+ if average:
+ return self.average_time
+ else:
+ return self.diff
+
+ def clear(self):
+ self.total_time = 0.
+ self.calls = 0
+ self.start_time = 0.
+ self.diff = 0.
+ self.average_time = 0.
diff --git a/retinaface/weights/RBF_Final.pth b/retinaface/weights/RBF_Final.pth
new file mode 100644
index 0000000000000000000000000000000000000000..5f5e03c1d4b777defe0a6314075df27d12c30bea
--- /dev/null
+++ b/retinaface/weights/RBF_Final.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b22c45855bff5c4eb84e9973e2ee8a980639aa194ec66f33a87ce712f9d44c28
+size 1504452
diff --git a/retinaface/weights/mobilenet0.25_Final.pth b/retinaface/weights/mobilenet0.25_Final.pth
new file mode 100644
index 0000000000000000000000000000000000000000..36079ee388ef1adfa7a99b06142228d4773dd55a
--- /dev/null
+++ b/retinaface/weights/mobilenet0.25_Final.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed6fa7d45a5179f03cb5ce82c73b0c6711372968e53782e88395acee88d319e1
+size 1789735
diff --git a/retinaface/weights/mobilenet0.25_epoch_842.pth b/retinaface/weights/mobilenet0.25_epoch_842.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f4f4b66775704b99eebd2c055a0aa351f70b3c9e
--- /dev/null
+++ b/retinaface/weights/mobilenet0.25_epoch_842.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:79208aec7020085085f364cd0a833f08b4083e635c2f71bea80dfe3523c4439e
+size 1841787
diff --git a/retinaface/weights/slim_Final.pth b/retinaface/weights/slim_Final.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f0e23bab25f54a561d58b28f71f0adc0147bf029
--- /dev/null
+++ b/retinaface/weights/slim_Final.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:76cdb587b2cf5f91b3836b66a652a0f7a91891758cd4fb0b07e4b7527af1c651
+size 1427150
diff --git a/turn.py b/turn.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd7e8a56988cd8962a7c3591a5ca8c559be7d149
--- /dev/null
+++ b/turn.py
@@ -0,0 +1,34 @@
+# Copied from streamlit-webrtc/sample_utils/turn.py
+import logging
+import os
+
+import streamlit as st
+from twilio.rest import Client
+
+logger = logging.getLogger(__name__)
+
+
+@st.cache_data # type: ignore
+def get_ice_servers():
+ """Use Twilio's TURN server because Streamlit Community Cloud has changed
+ its infrastructure and WebRTC connection cannot be established without TURN server now. # noqa: E501
+ We considered Open Relay Project (https://www.metered.ca/tools/openrelay/) too,
+ but it is not stable and hardly works as some people reported like https://github.com/aiortc/aiortc/issues/832#issuecomment-1482420656 # noqa: E501
+ See https://github.com/whitphx/streamlit-webrtc/issues/1213
+ """
+
+ # Ref: https://www.twilio.com/docs/stun-turn/api
+ try:
+ account_sid = os.environ["TWILIO_ACCOUNT_SID"]
+ auth_token = os.environ["TWILIO_AUTH_TOKEN"]
+ except KeyError:
+ logger.warning(
+ "Twilio credentials are not set. Fallback to a free STUN server from Google." # noqa: E501
+ )
+ return [{"urls": ["stun:stun.l.google.com:19302"]}]
+
+ client = Client(account_sid, auth_token)
+
+ token = client.tokens.create()
+
+ return token.ice_servers