diff --git a/TPSMM/LICENSE b/TPSMM/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..cde95a5110a80285e68865acd7ffda00faebed8b --- /dev/null +++ b/TPSMM/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 yoyo-nb + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/TPSMM/README.md b/TPSMM/README.md new file mode 100644 index 0000000000000000000000000000000000000000..812a4d99249e970aab5321c9d542c0cc0975c544 --- /dev/null +++ b/TPSMM/README.md @@ -0,0 +1,98 @@ +# [CVPR2022] Thin-Plate Spline Motion Model for Image Animation + +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE) +![stars](https://img.shields.io/github/stars/yoyo-nb/Thin-Plate-Spline-Motion-Model.svg?style=flat) +![GitHub repo size](https://img.shields.io/github/repo-size/yoyo-nb/Thin-Plate-Spline-Motion-Model.svg) + +Source code of the CVPR'2022 paper "Thin-Plate Spline Motion Model for Image Animation" + +[**Paper**](https://arxiv.org/abs/2203.14367) **|** [**Supp**](https://cloud.tsinghua.edu.cn/f/f7b8573bb5b04583949f/?dl=1) + +### Example animation + +![vox](assets/vox.gif) +![ted](assets/ted.gif) + +**PS**: The paper trains the model for 100 epochs for a fair comparison. You can use more data and train for more epochs to get better performance. + + +### Web demo for animation +- Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/CVPR/Image-Animation-using-Thin-Plate-Spline-Motion-Model) +- Try the web demo for animation here: [![Replicate](https://replicate.com/yoyo-nb/thin-plate-spline-motion-model/badge)](https://replicate.com/yoyo-nb/thin-plate-spline-motion-model) +- Google Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DREfdpnaBhqISg0fuQlAAIwyGVn1loH_?usp=sharing) + +### Pre-trained models +- [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/30ab8765da364fefa101/) +- [Google Drive](https://drive.google.com/drive/folders/1pNDo1ODQIb5HVObRtCmubqJikmR7VVLT?usp=sharing) + +### Installation + +We support ```python3```.(Recommended version is Python 3.9). +To install the dependencies run: +```bash +pip install -r requirements.txt +``` + + +### YAML configs + +There are several configuration files one for each `dataset` in the `config` folder named as ```config/dataset_name.yaml```. + +See description of the parameters in the ```config/taichi-256.yaml```. + +### Datasets + +1) **MGif**. Follow [Monkey-Net](https://github.com/AliaksandrSiarohin/monkey-net). + +2) **TaiChiHD** and **VoxCeleb**. Follow instructions from [video-preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing). + +3) **TED-talks**. Follow instructions from [MRAA](https://github.com/snap-research/articulated-animation). + + +### Training +To train a model on specific dataset run: +``` +CUDA_VISIBLE_DEVICES=0,1 python run.py --config config/dataset_name.yaml --device_ids 0,1 +``` +A log folder named after the timestamp will be created. Checkpoints, loss values, reconstruction results will be saved to this folder. + + +#### Training AVD network +To train a model on specific dataset run: +``` +CUDA_VISIBLE_DEVICES=0 python run.py --mode train_avd --checkpoint '{checkpoint_folder}/checkpoint.pth.tar' --config config/dataset_name.yaml +``` +Checkpoints, loss values, reconstruction results will be saved to `{checkpoint_folder}`. + + + +### Evaluation on video reconstruction + +To evaluate the reconstruction performance run: +``` +CUDA_VISIBLE_DEVICES=0 python run.py --mode reconstruction --config config/dataset_name.yaml --checkpoint '{checkpoint_folder}/checkpoint.pth.tar' +``` +The `reconstruction` subfolder will be created in `{checkpoint_folder}`. +The generated video will be stored to this folder, also generated videos will be stored in ```png``` subfolder in loss-less '.png' format for evaluation. +To compute metrics, follow instructions from [pose-evaluation](https://github.com/AliaksandrSiarohin/pose-evaluation). + + +### Image animation demo +- notebook: `demo.ipynb`, edit the config cell and run for image animation. +- python: +```bash +CUDA_VISIBLE_DEVICES=0 python demo.py --config config/vox-256.yaml --checkpoint checkpoints/vox.pth.tar --source_image ./source.jpg --driving_video ./driving.mp4 +``` + +# Acknowledgments +The main code is based upon [FOMM](https://github.com/AliaksandrSiarohin/first-order-model) and [MRAA](https://github.com/snap-research/articulated-animation) + +Thanks for the excellent works! + +And Thanks to: + +- [@chenxwh](https://github.com/chenxwh): Add Web Demo & Docker environment [![Replicate](https://replicate.com/yoyo-nb/thin-plate-spline-motion-model/badge)](https://replicate.com/yoyo-nb/thin-plate-spline-motion-model) + +- [@TalkUHulk](https://github.com/TalkUHulk): The C++/Python demo is provided in [Image-Animation-Turbo-Boost](https://github.com/TalkUHulk/Image-Animation-Turbo-Boost) + +- [@AK391](https://github.com/AK391): Add huggingface web demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/CVPR/Image-Animation-using-Thin-Plate-Spline-Motion-Model) \ No newline at end of file diff --git a/TPSMM/assets/source.png b/TPSMM/assets/source.png new file mode 100644 index 0000000000000000000000000000000000000000..241ebbd038565e5a78cb5c5e738f6a8820b18c47 Binary files /dev/null and b/TPSMM/assets/source.png differ diff --git a/TPSMM/assets/source1.png b/TPSMM/assets/source1.png new file mode 100644 index 0000000000000000000000000000000000000000..c402a3391ecfba923ef842a8fab9097b31289992 Binary files /dev/null and b/TPSMM/assets/source1.png differ diff --git a/TPSMM/assets/source2.jpg b/TPSMM/assets/source2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0f41c2eab8b498fb40fe81d5e2558d29b0cbf365 Binary files /dev/null and b/TPSMM/assets/source2.jpg differ diff --git a/TPSMM/augmentation.py b/TPSMM/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..df77004a1b7093c0992c970ed0a337b073ddfe86 --- /dev/null +++ b/TPSMM/augmentation.py @@ -0,0 +1,344 @@ +""" +Code from https://github.com/hassony2/torch_videovision +""" + +import numbers + +import random +import numpy as np +import PIL + +from skimage.transform import resize, rotate +import torchvision + +import warnings + +from skimage import img_as_ubyte, img_as_float + + +def crop_clip(clip, min_h, min_w, h, w): + if isinstance(clip[0], np.ndarray): + cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] + + elif isinstance(clip[0], PIL.Image.Image): + cropped = [ + img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return cropped + + +def pad_clip(clip, h, w): + im_h, im_w = clip[0].shape[:2] + pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2) + pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2) + + return np.pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge') + + +def resize_clip(clip, size, interpolation='bilinear'): + if isinstance(clip[0], np.ndarray): + if isinstance(size, numbers.Number): + im_h, im_w, im_c = clip[0].shape + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[1], size[0] + + scaled = [ + resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True, + mode='constant', anti_aliasing=True) for img in clip + ] + elif isinstance(clip[0], PIL.Image.Image): + if isinstance(size, numbers.Number): + im_w, im_h = clip[0].size + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[1], size[0] + if interpolation == 'bilinear': + pil_inter = PIL.Image.NEAREST + else: + pil_inter = PIL.Image.BILINEAR + scaled = [img.resize(size, pil_inter) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return scaled + + +def get_resize_sizes(im_h, im_w, size): + if im_w < im_h: + ow = size + oh = int(size * im_h / im_w) + else: + oh = size + ow = int(size * im_w / im_h) + return oh, ow + + +class RandomFlip(object): + def __init__(self, time_flip=False, horizontal_flip=False): + self.time_flip = time_flip + self.horizontal_flip = horizontal_flip + + def __call__(self, clip): + if random.random() < 0.5 and self.time_flip: + return clip[::-1] + if random.random() < 0.5 and self.horizontal_flip: + return [np.fliplr(img) for img in clip] + + return clip + + +class RandomResize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'): + self.ratio = ratio + self.interpolation = interpolation + + def __call__(self, clip): + scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) + + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + + new_w = int(im_w * scaling_factor) + new_h = int(im_h * scaling_factor) + new_size = (new_w, new_h) + resized = resize_clip( + clip, new_size, interpolation=self.interpolation) + + return resized + + +class RandomCrop(object): + """Extract random crop at the same location for a list of videos + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of videos to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of videos + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + + clip = pad_clip(clip, h, w) + im_h, im_w = clip.shape[1:3] + x1 = 0 if h == im_h else random.randint(0, im_w - w) + y1 = 0 if w == im_w else random.randint(0, im_h - h) + cropped = crop_clip(clip, y1, x1, h, w) + + return cropped + + +class RandomRotation(object): + """Rotate entire clip randomly by a random angle within + given bounds + Args: + degrees (sequence or int): Range of degrees to select from + If degrees is a number instead of sequence like (min, max), + the range of degrees, will be (-degrees, +degrees). + """ + + def __init__(self, degrees): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError('If degrees is a single number,' + 'must be positive') + degrees = (-degrees, degrees) + else: + if len(degrees) != 2: + raise ValueError('If degrees is a sequence,' + 'it must be of len 2.') + + self.degrees = degrees + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of videos to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of videos + """ + angle = random.uniform(self.degrees[0], self.degrees[1]) + if isinstance(clip[0], np.ndarray): + rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + rotated = [img.rotate(angle) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + + return rotated + + +class ColorJitter(object): + """Randomly change the brightness, contrast and saturation and hue of the clip + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. + """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + def get_params(self, brightness, contrast, saturation, hue): + if brightness > 0: + brightness_factor = random.uniform( + max(0, 1 - brightness), 1 + brightness) + else: + brightness_factor = None + + if contrast > 0: + contrast_factor = random.uniform( + max(0, 1 - contrast), 1 + contrast) + else: + contrast_factor = None + + if saturation > 0: + saturation_factor = random.uniform( + max(0, 1 - saturation), 1 + saturation) + else: + saturation_factor = None + + if hue > 0: + hue_factor = random.uniform(-hue, hue) + else: + hue_factor = None + return brightness_factor, contrast_factor, saturation_factor, hue_factor + + def __call__(self, clip): + """ + Args: + clip (list): list of PIL.Image + Returns: + list PIL.Image : list of transformed PIL.Image + """ + if isinstance(clip[0], np.ndarray): + brightness, contrast, saturation, hue = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + + # Create img transform function sequence + img_transforms = [] + if brightness is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) + if saturation is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) + if hue is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) + if contrast is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) + random.shuffle(img_transforms) + img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array, + img_as_float] + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + jittered_clip = [] + for img in clip: + jittered_img = img + for func in img_transforms: + jittered_img = func(jittered_img) + jittered_clip.append(jittered_img.astype('float32')) + elif isinstance(clip[0], PIL.Image.Image): + brightness, contrast, saturation, hue = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + + # Create img transform function sequence + img_transforms = [] + if brightness is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) + if saturation is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) + if hue is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) + if contrast is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) + random.shuffle(img_transforms) + + # Apply to all videos + jittered_clip = [] + for img in clip: + for func in img_transforms: + jittered_img = func(img) + jittered_clip.append(jittered_img) + + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return jittered_clip + + +class AllAugmentationTransform: + def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None): + self.transforms = [] + + if flip_param is not None: + self.transforms.append(RandomFlip(**flip_param)) + + if rotation_param is not None: + self.transforms.append(RandomRotation(**rotation_param)) + + if resize_param is not None: + self.transforms.append(RandomResize(**resize_param)) + + if crop_param is not None: + self.transforms.append(RandomCrop(**crop_param)) + + if jitter_param is not None: + self.transforms.append(ColorJitter(**jitter_param)) + + def __call__(self, clip): + for t in self.transforms: + clip = t(clip) + return clip diff --git a/TPSMM/cog.yaml b/TPSMM/cog.yaml new file mode 100644 index 0000000000000000000000000000000000000000..203a577ae7a37e77e9a349946cfa4e4fb9638590 --- /dev/null +++ b/TPSMM/cog.yaml @@ -0,0 +1,40 @@ +build: + cuda: "11.0" + gpu: true + python_version: "3.8" + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + - "ninja-build" + python_packages: + - "ipython==7.21.0" + - "torch==1.10.1" + - "torchvision==0.11.2" + - "cffi==1.14.6" + - "cycler==0.10.0" + - "decorator==5.1.0" + - "face-alignment==1.3.5" + - "imageio==2.9.0" + - "imageio-ffmpeg==0.4.5" + - "kiwisolver==1.3.2" + - "matplotlib==3.4.3" + - "networkx==2.6.3" + - "numpy==1.20.3" + - "pandas==1.3.3" + - "Pillow==8.3.2" + - "pycparser==2.20" + - "pyparsing==2.4.7" + - "python-dateutil==2.8.2" + - "pytz==2021.1" + - "PyWavelets==1.1.1" + - "PyYAML==5.4.1" + - "scikit-image==0.18.3" + - "scikit-learn==1.0" + - "scipy==1.7.1" + - "six==1.16.0" + - "tqdm==4.62.3" + - "cmake==3.21.3" + run: + - pip install dlib + +predict: "predict.py:Predictor" diff --git a/TPSMM/config/mgif-256.yaml b/TPSMM/config/mgif-256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4cd2011114f0b8ec69a76b4a210e134b48dea2cc --- /dev/null +++ b/TPSMM/config/mgif-256.yaml @@ -0,0 +1,75 @@ +dataset_params: + root_dir: ../moving-gif + frame_shape: null + id_sampling: False + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + crop_param: + size: [256, 256] + resize_param: + ratio: [0.9, 1.1] + jitter_param: + hue: 0.5 + +model_params: + common_params: + num_tps: 10 + num_channels: 3 + bg: False + multi_mask: True + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 3 + dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + scale_factor: 0.25 + avd_network_params: + id_bottle_size: 128 + pose_bottle_size: 128 + + +train_params: + num_epochs: 100 + num_repeats: 50 + epoch_milestones: [70, 90] + lr_generator: 2.0e-4 + batch_size: 28 + scales: [1, 0.5, 0.25, 0.125] + dataloader_workers: 12 + checkpoint_freq: 50 + dropout_epoch: 35 + dropout_maxp: 0.5 + dropout_startp: 0.2 + dropout_inc_epoch: 10 + bg_start: 0 + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 10 + warp_loss: 10 + bg: 10 + +train_avd_params: + num_epochs: 100 + num_repeats: 50 + batch_size: 256 + dataloader_workers: 24 + checkpoint_freq: 10 + epoch_milestones: [70, 90] + lr: 1.0e-3 + lambda_shift: 1 + lambda_affine: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' \ No newline at end of file diff --git a/TPSMM/config/taichi-256.yaml b/TPSMM/config/taichi-256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2eb1d3a74634b5d1fc78cc5d6d368c7528d45750 --- /dev/null +++ b/TPSMM/config/taichi-256.yaml @@ -0,0 +1,134 @@ +# Dataset parameters +# Each dataset should contain 2 folders train and test +# Each video can be represented as: +# - an image of concatenated frames +# - '.mp4' or '.gif' +# - folder with all frames from a specific video +# In case of Taichi. Same (youtube) video can be splitted in many parts (chunks). Each part has a following +# format (id)#other#info.mp4. For example '12335#adsbf.mp4' has an id 12335. In case of TaiChi id stands for youtube +# video id. +dataset_params: + # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames. + root_dir: ../taichi + # Image shape, needed for staked .png format. + frame_shape: null + # In case of TaiChi single video can be splitted in many chunks, or the maybe several videos for single person. + # In this case epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False) + # If the name of the video '12335#adsbf.mp4' the id is assumed to be 12335 + id_sampling: True + # Augmentation parameters see augmentation.py for all posible augmentations + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + +# Defines model architecture +model_params: + common_params: + # Number of TPS transformation + num_tps: 10 + # Number of channels per image + num_channels: 3 + # Whether to estimate affine background transformation + bg: True + # Whether to estimate the multi-resolution occlusion masks + multi_mask: True + generator_params: + # Number of features mutliplier + block_expansion: 64 + # Maximum allowed number of features + max_features: 512 + # Number of downsampling blocks and Upsampling blocks. + num_down_blocks: 3 + dense_motion_params: + # Number of features mutliplier + block_expansion: 64 + # Maximum allowed number of features + max_features: 1024 + # Number of block in Unet. + num_blocks: 5 + # Optical flow is predicted on smaller images for better performance, + # scale_factor=0.25 means that 256x256 image will be resized to 64x64 + scale_factor: 0.25 + avd_network_params: + # Bottleneck for identity branch + id_bottle_size: 128 + # Bottleneck for pose branch + pose_bottle_size: 128 + +# Parameters of training +train_params: + # Number of training epochs + num_epochs: 100 + # For better i/o performance when number of videos is small number of epochs can be multiplied by this number. + # Thus effectivlly with num_repeats=100 each epoch is 100 times larger. + num_repeats: 150 + # Drop learning rate by 10 times after this epochs + epoch_milestones: [70, 90] + # Initial learing rate for all modules + lr_generator: 2.0e-4 + batch_size: 28 + # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256, + # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32. + scales: [1, 0.5, 0.25, 0.125] + # Dataset preprocessing cpu workers + dataloader_workers: 12 + # Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs. + checkpoint_freq: 50 + # Parameters of dropout + # The first dropout_epoch training uses dropout operation + dropout_epoch: 35 + # The probability P will linearly increase from dropout_startp to dropout_maxp in dropout_inc_epoch epochs + dropout_maxp: 0.7 + dropout_startp: 0.0 + dropout_inc_epoch: 10 + # Estimate affine background transformation from the bg_start epoch. + bg_start: 0 + # Parameters of random TPS transformation for equivariance loss + transform_params: + # Sigma for affine part + sigma_affine: 0.05 + # Sigma for deformation part + sigma_tps: 0.005 + # Number of point in the deformation grid + points_tps: 5 + loss_weights: + # Weights for perceptual loss. + perceptual: [10, 10, 10, 10, 10] + # Weights for value equivariance. + equivariance_value: 10 + # Weights for warp loss. + warp_loss: 10 + # Weights for bg loss. + bg: 10 + +# Parameters of training (animation-via-disentanglement) +train_avd_params: + # Number of training epochs, visualization is produced after each epoch. + num_epochs: 100 + # For better i/o performance when number of videos is small number of epochs can be multiplied by this number. + # Thus effectively with num_repeats=100 each epoch is 100 times larger. + num_repeats: 150 + # Batch size. + batch_size: 256 + # Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs. + checkpoint_freq: 10 + # Dataset preprocessing cpu workers + dataloader_workers: 24 + # Drop learning rate 10 times after this epochs + epoch_milestones: [70, 90] + # Initial learning rate + lr: 1.0e-3 + # Weights for equivariance loss. + lambda_shift: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' \ No newline at end of file diff --git a/TPSMM/config/ted-384.yaml b/TPSMM/config/ted-384.yaml new file mode 100644 index 0000000000000000000000000000000000000000..007b9126823229c70459f51b173773f66f9f2be5 --- /dev/null +++ b/TPSMM/config/ted-384.yaml @@ -0,0 +1,73 @@ +dataset_params: + root_dir: ../TED384-v2 + frame_shape: null + id_sampling: True + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + +model_params: + common_params: + num_tps: 10 + num_channels: 3 + bg: True + multi_mask: True + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 3 + dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + scale_factor: 0.25 + avd_network_params: + id_bottle_size: 128 + pose_bottle_size: 128 + + +train_params: + num_epochs: 100 + num_repeats: 150 + epoch_milestones: [70, 90] + lr_generator: 2.0e-4 + batch_size: 12 + scales: [1, 0.5, 0.25, 0.125] + dataloader_workers: 6 + checkpoint_freq: 50 + dropout_epoch: 35 + dropout_maxp: 0.5 + dropout_startp: 0.0 + dropout_inc_epoch: 10 + bg_start: 0 + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 10 + warp_loss: 10 + bg: 10 + +train_avd_params: + num_epochs: 30 + num_repeats: 500 + batch_size: 256 + dataloader_workers: 24 + checkpoint_freq: 10 + epoch_milestones: [20, 25] + lr: 1.0e-3 + lambda_shift: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' diff --git a/TPSMM/config/vox-256.yaml b/TPSMM/config/vox-256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..658a8163e92574f2b52cc2bbc6886fb1603d7c55 --- /dev/null +++ b/TPSMM/config/vox-256.yaml @@ -0,0 +1,74 @@ +dataset_params: + root_dir: ../vox + frame_shape: null + id_sampling: True + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + + +model_params: + common_params: + num_tps: 10 + num_channels: 3 + bg: True + multi_mask: True + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 3 + dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + scale_factor: 0.25 + avd_network_params: + id_bottle_size: 128 + pose_bottle_size: 128 + + +train_params: + num_epochs: 100 + num_repeats: 75 + epoch_milestones: [70, 90] + lr_generator: 2.0e-4 + batch_size: 28 + scales: [1, 0.5, 0.25, 0.125] + dataloader_workers: 12 + checkpoint_freq: 50 + dropout_epoch: 35 + dropout_maxp: 0.3 + dropout_startp: 0.1 + dropout_inc_epoch: 10 + bg_start: 10 + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 10 + warp_loss: 10 + bg: 10 + +train_avd_params: + num_epochs: 200 + num_repeats: 300 + batch_size: 256 + dataloader_workers: 24 + checkpoint_freq: 50 + epoch_milestones: [140, 180] + lr: 1.0e-3 + lambda_shift: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' \ No newline at end of file diff --git a/TPSMM/demo.ipynb b/TPSMM/demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..8bfccd9bfa013a06a2301373684ed82ead07246f --- /dev/null +++ b/TPSMM/demo.ipynb @@ -0,0 +1,5113 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Config**" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "# edit the config\n", + "device = torch.device('cuda:0')\n", + "dataset_name = 'vox' # ['vox', 'taichi', 'ted', 'mgif']\n", + "source_image_path = './assets/source.png'\n", + "driving_video_path = './assets/driving.mp4'\n", + "output_video_path = './generated.mp4'\n", + "config_path = 'config/vox-256.yaml'\n", + "checkpoint_path = 'checkpoints/vox.pth.tar'\n", + "predict_mode = 'relative' # ['standard', 'relative', 'avd']\n", + "find_best_frame = False # when use the relative mode to animate a face, use 'find_best_frame=True' can get better quality result\n", + "\n", + "pixel = 256 # for vox, taichi and mgif, the resolution is 256*256\n", + "if(dataset_name == 'ted'): # for ted, the resolution is 384*384\n", + " pixel = 384\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Read image and video**" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 453 + }, + "id": "Oxi6-riLOgnm", + "outputId": "d38a8850-9eb1-4de4-9bf2-24cbd847ca1f" + }, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import imageio\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.animation as animation\n", + "from skimage.transform import resize\n", + "from IPython.display import HTML\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "source_image = imageio.imread(source_image_path)\n", + "reader = imageio.get_reader(driving_video_path)\n", + "\n", + "\n", + "source_image = resize(source_image, (pixel, pixel))[..., :3]\n", + "\n", + "fps = reader.get_meta_data()['fps']\n", + "driving_video = []\n", + "try:\n", + " for im in reader:\n", + " driving_video.append(im)\n", + "except RuntimeError:\n", + " pass\n", + "reader.close()\n", + "\n", + "driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]\n", + "\n", + "def display(source, driving, generated=None):\n", + " fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))\n", + "\n", + " ims = []\n", + " for i in range(len(driving)):\n", + " cols = [source]\n", + " cols.append(driving[i])\n", + " if generated is not None:\n", + " cols.append(generated[i])\n", + " im = plt.imshow(np.concatenate(cols, axis=1), animated=True)\n", + " plt.axis('off')\n", + " ims.append([im])\n", + "\n", + " ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)\n", + " plt.close()\n", + " return ani\n", + " \n", + "\n", + "HTML(display(source_image, driving_video).to_html5_video())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xjM7ubVfWrwT" + }, + "source": [ + "**Create a model and load checkpoints**" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "3FQiXqQPWt5B" + }, + "outputs": [], + "source": [ + "from demo import load_checkpoints\n", + "inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = config_path, checkpoint_path = checkpoint_path, device = device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fdFdasHEj3t7" + }, + "source": [ + "**Perform image animation**" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 471 + }, + "id": "SB12II11kF4c", + "outputId": "9e2274aa-fd55-4eed-cb50-bec72fcfb8b9" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 169/169 [00:10<00:00, 15.69it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from demo import make_animation\n", + "from skimage import img_as_ubyte\n", + "\n", + "if predict_mode=='relative' and find_best_frame:\n", + " from demo import find_best_frame as _find\n", + " i = _find(source_image, driving_video, device.type=='cpu')\n", + " print (\"Best frame: \" + str(i))\n", + " driving_forward = driving_video[i:]\n", + " driving_backward = driving_video[:(i+1)][::-1]\n", + " predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n", + " predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n", + " predictions = predictions_backward[::-1] + predictions_forward[1:]\n", + "else:\n", + " predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n", + "\n", + "#save resulting video\n", + "imageio.mimsave(output_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)\n", + "\n", + "HTML(display(source_image, driving_video, predictions).to_html5_video())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "include_colab_link": true, + "name": "first-order-model-demo.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/TPSMM/demo.py b/TPSMM/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..011092e02b9a9221444f244a9e4a3fedcffa4cbf --- /dev/null +++ b/TPSMM/demo.py @@ -0,0 +1,180 @@ +import matplotlib +matplotlib.use('Agg') +import sys +import yaml +from argparse import ArgumentParser +from tqdm import tqdm +from scipy.spatial import ConvexHull +import numpy as np +import imageio +from skimage.transform import resize +from skimage import img_as_ubyte +import torch +from modules.inpainting_network import InpaintingNetwork +from modules.keypoint_detector import KPDetector +from modules.dense_motion import DenseMotionNetwork +from modules.avd_network import AVDNetwork + +if sys.version_info[0] < 3: + raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9") + +def relative_kp(kp_source, kp_driving, kp_driving_initial): + + source_area = ConvexHull(kp_source['fg_kp'][0].data.cpu().numpy()).volume + driving_area = ConvexHull(kp_driving_initial['fg_kp'][0].data.cpu().numpy()).volume + adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) + + kp_new = {k: v for k, v in kp_driving.items()} + + kp_value_diff = (kp_driving['fg_kp'] - kp_driving_initial['fg_kp']) + kp_value_diff *= adapt_movement_scale + kp_new['fg_kp'] = kp_value_diff + kp_source['fg_kp'] + + return kp_new + +def load_checkpoints(config_path, checkpoint_path, device): + with open(config_path) as f: + config = yaml.full_load(f) + + inpainting = InpaintingNetwork(**config['model_params']['generator_params'], + **config['model_params']['common_params']) + kp_detector = KPDetector(**config['model_params']['common_params']) + dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'], + **config['model_params']['dense_motion_params']) + avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'], + **config['model_params']['avd_network_params']) + kp_detector.to(device) + dense_motion_network.to(device) + inpainting.to(device) + avd_network.to(device) + + checkpoint = torch.load(checkpoint_path, map_location=device) + + inpainting.load_state_dict(checkpoint['inpainting_network']) + kp_detector.load_state_dict(checkpoint['kp_detector']) + dense_motion_network.load_state_dict(checkpoint['dense_motion_network']) + if 'avd_network' in checkpoint: + avd_network.load_state_dict(checkpoint['avd_network']) + + inpainting.eval() + kp_detector.eval() + dense_motion_network.eval() + avd_network.eval() + + return inpainting, kp_detector, dense_motion_network, avd_network + + +def make_animation(source_image, driving_video, inpainting_network, kp_detector, dense_motion_network, avd_network, device, mode = 'relative'): + assert mode in ['standard', 'relative', 'avd'] + with torch.no_grad(): + predictions = [] + source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2) + source = source.to(device) + driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device) + kp_source = kp_detector(source) + kp_driving_initial = kp_detector(driving[:, :, 0]) + + for frame_idx in tqdm(range(driving.shape[2])): + driving_frame = driving[:, :, frame_idx] + driving_frame = driving_frame.to(device) + kp_driving = kp_detector(driving_frame) + if mode == 'standard': + kp_norm = kp_driving + elif mode=='relative': + kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving, + kp_driving_initial=kp_driving_initial) + elif mode == 'avd': + kp_norm = avd_network(kp_source, kp_driving) + dense_motion = dense_motion_network(source_image=source, kp_driving=kp_norm, + kp_source=kp_source, bg_param = None, + dropout_flag = False) + out = inpainting_network(source, dense_motion) + + predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) + return predictions + + +def find_best_frame(source, driving, cpu): + import face_alignment + + def normalize_kp(kp): + kp = kp - kp.mean(axis=0, keepdims=True) + area = ConvexHull(kp[:, :2]).volume + area = np.sqrt(area) + kp[:, :2] = kp[:, :2] / area + return kp + + fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True, + device= 'cpu' if cpu else 'cuda') + kp_source = fa.get_landmarks(255 * source)[0] + kp_source = normalize_kp(kp_source) + norm = float('inf') + frame_num = 0 + for i, image in tqdm(enumerate(driving)): + try: + kp_driving = fa.get_landmarks(255 * image)[0] + kp_driving = normalize_kp(kp_driving) + new_norm = (np.abs(kp_source - kp_driving) ** 2).sum() + if new_norm < norm: + norm = new_norm + frame_num = i + except: + pass + return frame_num + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--config", required=True, help="path to config") + parser.add_argument("--checkpoint", default='checkpoints/vox.pth.tar', help="path to checkpoint to restore") + + parser.add_argument("--source_image", default='./assets/source.png', help="path to source image") + parser.add_argument("--driving_video", default='./assets/driving.mp4', help="path to driving video") + parser.add_argument("--result_video", default='./result.mp4', help="path to output") + + parser.add_argument("--img_shape", default="256,256", type=lambda x: list(map(int, x.split(','))), + help='Shape of image, that the model was trained on.') + + parser.add_argument("--mode", default='relative', choices=['standard', 'relative', 'avd'], help="Animate mode: ['standard', 'relative', 'avd'], when use the relative mode to animate a face, use '--find_best_frame' can get better quality result") + + parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true", + help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)") + + parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.") + + opt = parser.parse_args() + + source_image = imageio.imread(opt.source_image) + reader = imageio.get_reader(opt.driving_video) + fps = reader.get_meta_data()['fps'] + driving_video = [] + try: + for im in reader: + driving_video.append(im) + except RuntimeError: + pass + reader.close() + + if opt.cpu: + device = torch.device('cpu') + else: + device = torch.device('cuda') + + source_image = resize(source_image, opt.img_shape)[..., :3] + driving_video = [resize(frame, opt.img_shape)[..., :3] for frame in driving_video] + inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = opt.config, checkpoint_path = opt.checkpoint, device = device) + + if opt.find_best_frame: + i = find_best_frame(source_image, driving_video, opt.cpu) + print ("Best frame: " + str(i)) + driving_forward = driving_video[i:] + driving_backward = driving_video[:(i+1)][::-1] + predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = opt.mode) + predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = opt.mode) + predictions = predictions_backward[::-1] + predictions_forward[1:] + else: + predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = opt.mode) + + predictions =[img_as_ubyte(frame) for frame in predictions] + imageio.mimsave(opt.result_video, predictions, fps=fps) + diff --git a/TPSMM/frames_dataset.py b/TPSMM/frames_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2328c6a0af028518fe5412ebc75b2b8a019d77c8 --- /dev/null +++ b/TPSMM/frames_dataset.py @@ -0,0 +1,173 @@ +import os +from skimage import io, img_as_float32 +from skimage.color import gray2rgb +from sklearn.model_selection import train_test_split +from imageio import mimread +from skimage.transform import resize +import numpy as np +from torch.utils.data import Dataset +from augmentation import AllAugmentationTransform +import glob +from functools import partial + + +def read_video(name, frame_shape): + """ + Read video which can be: + - an image of concatenated frames + - '.mp4' and'.gif' + - folder with videos + """ + + if os.path.isdir(name): + frames = sorted(os.listdir(name)) + num_frames = len(frames) + video_array = np.array( + [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)]) + elif name.lower().endswith('.png') or name.lower().endswith('.jpg'): + image = io.imread(name) + + if len(image.shape) == 2 or image.shape[2] == 1: + image = gray2rgb(image) + + if image.shape[2] == 4: + image = image[..., :3] + + image = img_as_float32(image) + + video_array = np.moveaxis(image, 1, 0) + + video_array = video_array.reshape((-1,) + frame_shape) + video_array = np.moveaxis(video_array, 1, 2) + elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'): + video = mimread(name) + if len(video[0].shape) == 2: + video = [gray2rgb(frame) for frame in video] + if frame_shape is not None: + video = np.array([resize(frame, frame_shape) for frame in video]) + video = np.array(video) + if video.shape[-1] == 4: + video = video[..., :3] + video_array = img_as_float32(video) + else: + raise Exception("Unknown file extensions %s" % name) + + return video_array + + +class FramesDataset(Dataset): + """ + Dataset of videos, each video can be represented as: + - an image of concatenated frames + - '.mp4' or '.gif' + - folder with all frames + """ + + def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True, + random_seed=0, pairs_list=None, augmentation_params=None): + self.root_dir = root_dir + self.videos = os.listdir(root_dir) + self.frame_shape = frame_shape + print(self.frame_shape) + self.pairs_list = pairs_list + self.id_sampling = id_sampling + + if os.path.exists(os.path.join(root_dir, 'train')): + assert os.path.exists(os.path.join(root_dir, 'test')) + print("Use predefined train-test split.") + if id_sampling: + train_videos = {os.path.basename(video).split('#')[0] for video in + os.listdir(os.path.join(root_dir, 'train'))} + train_videos = list(train_videos) + else: + train_videos = os.listdir(os.path.join(root_dir, 'train')) + test_videos = os.listdir(os.path.join(root_dir, 'test')) + self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test') + else: + print("Use random train-test split.") + train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2) + + if is_train: + self.videos = train_videos + else: + self.videos = test_videos + + self.is_train = is_train + + if self.is_train: + self.transform = AllAugmentationTransform(**augmentation_params) + else: + self.transform = None + + def __len__(self): + return len(self.videos) + + def __getitem__(self, idx): + + if self.is_train and self.id_sampling: + name = self.videos[idx] + path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4'))) + else: + name = self.videos[idx] + path = os.path.join(self.root_dir, name) + + video_name = os.path.basename(path) + if self.is_train and os.path.isdir(path): + + frames = os.listdir(path) + num_frames = len(frames) + frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) + + if self.frame_shape is not None: + resize_fn = partial(resize, output_shape=self.frame_shape) + else: + resize_fn = img_as_float32 + + if type(frames[0]) is bytes: + video_array = [resize_fn(io.imread(os.path.join(path, frames[idx].decode('utf-8')))) for idx in + frame_idx] + else: + video_array = [resize_fn(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx] + else: + + video_array = read_video(path, frame_shape=self.frame_shape) + + num_frames = len(video_array) + frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range( + num_frames) + video_array = video_array[frame_idx] + + + if self.transform is not None: + video_array = self.transform(video_array) + + out = {} + if self.is_train: + source = np.array(video_array[0], dtype='float32') + driving = np.array(video_array[1], dtype='float32') + + out['driving'] = driving.transpose((2, 0, 1)) + out['source'] = source.transpose((2, 0, 1)) + else: + video = np.array(video_array, dtype='float32') + out['video'] = video.transpose((3, 0, 1, 2)) + + out['name'] = video_name + return out + + +class DatasetRepeater(Dataset): + """ + Pass several times over the same dataset for better i/o performance + """ + + def __init__(self, dataset, num_repeats=100): + self.dataset = dataset + self.num_repeats = num_repeats + + def __len__(self): + return self.num_repeats * self.dataset.__len__() + + def __getitem__(self, idx): + return self.dataset[idx % self.dataset.__len__()] + diff --git a/TPSMM/logger.py b/TPSMM/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..825fbe7b20e023db239ead5e4adc06368f78f76d --- /dev/null +++ b/TPSMM/logger.py @@ -0,0 +1,212 @@ +import numpy as np +import torch +import torch.nn.functional as F +import imageio + +import os +from skimage.draw import circle + +import matplotlib.pyplot as plt +import collections + + +class Logger: + def __init__(self, log_dir, checkpoint_freq=50, visualizer_params=None, zfill_num=8, log_file_name='log.txt'): + + self.loss_list = [] + self.cpk_dir = log_dir + self.visualizations_dir = os.path.join(log_dir, 'train-vis') + if not os.path.exists(self.visualizations_dir): + os.makedirs(self.visualizations_dir) + self.log_file = open(os.path.join(log_dir, log_file_name), 'a') + self.zfill_num = zfill_num + self.visualizer = Visualizer(**visualizer_params) + self.checkpoint_freq = checkpoint_freq + self.epoch = 0 + self.best_loss = float('inf') + self.names = None + + def log_scores(self, loss_names): + loss_mean = np.array(self.loss_list).mean(axis=0) + + loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)]) + loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string + + print(loss_string, file=self.log_file) + self.loss_list = [] + self.log_file.flush() + + def visualize_rec(self, inp, out): + image = self.visualizer.visualize(inp['driving'], inp['source'], out) + imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image) + + def save_cpk(self, emergent=False): + cpk = {k: v.state_dict() for k, v in self.models.items()} + cpk['epoch'] = self.epoch + cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num)) + if not (os.path.exists(cpk_path) and emergent): + torch.save(cpk, cpk_path) + + @staticmethod + def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network =None, kp_detector=None, + bg_predictor=None, avd_network=None, optimizer=None, optimizer_bg_predictor=None, + optimizer_avd=None): + checkpoint = torch.load(checkpoint_path) + if inpainting_network is not None: + inpainting_network.load_state_dict(checkpoint['inpainting_network']) + if kp_detector is not None: + kp_detector.load_state_dict(checkpoint['kp_detector']) + if bg_predictor is not None and 'bg_predictor' in checkpoint: + bg_predictor.load_state_dict(checkpoint['bg_predictor']) + if dense_motion_network is not None: + dense_motion_network.load_state_dict(checkpoint['dense_motion_network']) + if avd_network is not None: + if 'avd_network' in checkpoint: + avd_network.load_state_dict(checkpoint['avd_network']) + if optimizer_bg_predictor is not None and 'optimizer_bg_predictor' in checkpoint: + optimizer_bg_predictor.load_state_dict(checkpoint['optimizer_bg_predictor']) + if optimizer is not None and 'optimizer' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + if optimizer_avd is not None: + if 'optimizer_avd' in checkpoint: + optimizer_avd.load_state_dict(checkpoint['optimizer_avd']) + epoch = -1 + if 'epoch' in checkpoint: + epoch = checkpoint['epoch'] + return epoch + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + if 'models' in self.__dict__: + self.save_cpk() + self.log_file.close() + + def log_iter(self, losses): + losses = collections.OrderedDict(losses.items()) + self.names = list(losses.keys()) + self.loss_list.append(list(losses.values())) + + def log_epoch(self, epoch, models, inp, out): + self.epoch = epoch + self.models = models + if (self.epoch + 1) % self.checkpoint_freq == 0: + self.save_cpk() + self.log_scores(self.names) + self.visualize_rec(inp, out) + + +class Visualizer: + def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'): + self.kp_size = kp_size + self.draw_border = draw_border + self.colormap = plt.get_cmap(colormap) + + def draw_image_with_kp(self, image, kp_array): + image = np.copy(image) + spatial_size = np.array(image.shape[:2][::-1])[np.newaxis] + kp_array = spatial_size * (kp_array + 1) / 2 + num_kp = kp_array.shape[0] + for kp_ind, kp in enumerate(kp_array): + rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2]) + image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3] + return image + + def create_image_column_with_kp(self, images, kp): + image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)]) + return self.create_image_column(image_array) + + def create_image_column(self, images): + if self.draw_border: + images = np.copy(images) + images[:, :, [0, -1]] = (1, 1, 1) + images[:, :, [0, -1]] = (1, 1, 1) + return np.concatenate(list(images), axis=0) + + def create_image_grid(self, *args): + out = [] + for arg in args: + if type(arg) == tuple: + out.append(self.create_image_column_with_kp(arg[0], arg[1])) + else: + out.append(self.create_image_column(arg)) + return np.concatenate(out, axis=1) + + def visualize(self, driving, source, out): + images = [] + + # Source image with keypoints + source = source.data.cpu() + kp_source = out['kp_source']['fg_kp'].data.cpu().numpy() + source = np.transpose(source, [0, 2, 3, 1]) + images.append((source, kp_source)) + + # Equivariance visualization + if 'transformed_frame' in out: + transformed = out['transformed_frame'].data.cpu().numpy() + transformed = np.transpose(transformed, [0, 2, 3, 1]) + transformed_kp = out['transformed_kp']['fg_kp'].data.cpu().numpy() + images.append((transformed, transformed_kp)) + + # Driving image with keypoints + kp_driving = out['kp_driving']['fg_kp'].data.cpu().numpy() + driving = driving.data.cpu().numpy() + driving = np.transpose(driving, [0, 2, 3, 1]) + images.append((driving, kp_driving)) + + # Deformed image + if 'deformed' in out: + deformed = out['deformed'].data.cpu().numpy() + deformed = np.transpose(deformed, [0, 2, 3, 1]) + images.append(deformed) + + # Result with and without keypoints + prediction = out['prediction'].data.cpu().numpy() + prediction = np.transpose(prediction, [0, 2, 3, 1]) + if 'kp_norm' in out: + kp_norm = out['kp_norm']['fg_kp'].data.cpu().numpy() + images.append((prediction, kp_norm)) + images.append(prediction) + + + ## Occlusion map + if 'occlusion_map' in out: + for i in range(len(out['occlusion_map'])): + occlusion_map = out['occlusion_map'][i].data.cpu().repeat(1, 3, 1, 1) + occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy() + occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1]) + images.append(occlusion_map) + + # Deformed images according to each individual transform + if 'deformed_source' in out: + full_mask = [] + for i in range(out['deformed_source'].shape[1]): + image = out['deformed_source'][:, i].data.cpu() + # import ipdb;ipdb.set_trace() + image = F.interpolate(image, size=source.shape[1:3]) + mask = out['contribution_maps'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1) + mask = F.interpolate(mask, size=source.shape[1:3]) + image = np.transpose(image.numpy(), (0, 2, 3, 1)) + mask = np.transpose(mask.numpy(), (0, 2, 3, 1)) + + if i != 0: + color = np.array(self.colormap((i - 1) / (out['deformed_source'].shape[1] - 1)))[:3] + else: + color = np.array((0, 0, 0)) + + color = color.reshape((1, 1, 1, 3)) + + images.append(image) + if i != 0: + images.append(mask * color) + else: + images.append(mask) + + full_mask.append(mask * color) + + images.append(sum(full_mask)) + + image = self.create_image_grid(*images) + image = (255 * image).astype(np.uint8) + return image diff --git a/TPSMM/modules/avd_network.py b/TPSMM/modules/avd_network.py new file mode 100644 index 0000000000000000000000000000000000000000..e62937ebc7d00a09f0e10ab9abda038cdaeaaf54 --- /dev/null +++ b/TPSMM/modules/avd_network.py @@ -0,0 +1,65 @@ + +import torch +from torch import nn + + +class AVDNetwork(nn.Module): + """ + Animation via Disentanglement network + """ + + def __init__(self, num_tps, id_bottle_size=64, pose_bottle_size=64): + super(AVDNetwork, self).__init__() + input_size = 5*2 * num_tps + self.num_tps = num_tps + + self.id_encoder = nn.Sequential( + nn.Linear(input_size, 256), + nn.BatchNorm1d(256), + nn.ReLU(inplace=True), + nn.Linear(256, 512), + nn.BatchNorm1d(512), + nn.ReLU(inplace=True), + nn.Linear(512, 1024), + nn.BatchNorm1d(1024), + nn.ReLU(inplace=True), + nn.Linear(1024, id_bottle_size) + ) + + self.pose_encoder = nn.Sequential( + nn.Linear(input_size, 256), + nn.BatchNorm1d(256), + nn.ReLU(inplace=True), + nn.Linear(256, 512), + nn.BatchNorm1d(512), + nn.ReLU(inplace=True), + nn.Linear(512, 1024), + nn.BatchNorm1d(1024), + nn.ReLU(inplace=True), + nn.Linear(1024, pose_bottle_size) + ) + + self.decoder = nn.Sequential( + nn.Linear(pose_bottle_size + id_bottle_size, 1024), + nn.BatchNorm1d(1024), + nn.ReLU(), + nn.Linear(1024, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, 256), + nn.BatchNorm1d(256), + nn.ReLU(), + nn.Linear(256, input_size) + ) + + def forward(self, kp_source, kp_random): + + bs = kp_source['fg_kp'].shape[0] + + pose_emb = self.pose_encoder(kp_random['fg_kp'].view(bs, -1)) + id_emb = self.id_encoder(kp_source['fg_kp'].view(bs, -1)) + + rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1)) + + rec = {'fg_kp': rec.view(bs, self.num_tps*5, -1)} + return rec diff --git a/TPSMM/modules/bg_motion_predictor.py b/TPSMM/modules/bg_motion_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..875f4bc5a153545bfa116e0cf42807ce18e01bc2 --- /dev/null +++ b/TPSMM/modules/bg_motion_predictor.py @@ -0,0 +1,24 @@ +from torch import nn +import torch +from torchvision import models + +class BGMotionPredictor(nn.Module): + """ + Module for background estimation, return single transformation, parametrized as 3x3 matrix. The third row is [0 0 1] + """ + + def __init__(self): + super(BGMotionPredictor, self).__init__() + self.bg_encoder = models.resnet18(pretrained=False) + self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + num_features = self.bg_encoder.fc.in_features + self.bg_encoder.fc = nn.Linear(num_features, 6) + self.bg_encoder.fc.weight.data.zero_() + self.bg_encoder.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) + + def forward(self, source_image, driving_image): + bs = source_image.shape[0] + out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type()) + prediction = self.bg_encoder(torch.cat([source_image, driving_image], dim=1)) + out[:, :2, :] = prediction.view(bs, 2, 3) + return out diff --git a/TPSMM/modules/dense_motion.py b/TPSMM/modules/dense_motion.py new file mode 100644 index 0000000000000000000000000000000000000000..ced509a9cc9d435c523b496367862fa0f5974f59 --- /dev/null +++ b/TPSMM/modules/dense_motion.py @@ -0,0 +1,164 @@ +from torch import nn +import torch.nn.functional as F +import torch +from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian +from modules.util import to_homogeneous, from_homogeneous, UpBlock2d, TPS +import math + +class DenseMotionNetwork(nn.Module): + """ + Module that estimating an optical flow and multi-resolution occlusion masks + from K TPS transformations and an affine transformation. + """ + + def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_channels, + scale_factor=0.25, bg = False, multi_mask = True, kp_variance=0.01): + super(DenseMotionNetwork, self).__init__() + + if scale_factor != 1: + self.down = AntiAliasInterpolation2d(num_channels, scale_factor) + self.scale_factor = scale_factor + self.multi_mask = multi_mask + + self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_channels * (num_tps+1) + num_tps*5+1), + max_features=max_features, num_blocks=num_blocks) + + hourglass_output_size = self.hourglass.out_channels + self.maps = nn.Conv2d(hourglass_output_size[-1], num_tps + 1, kernel_size=(7, 7), padding=(3, 3)) + + if multi_mask: + up = [] + self.up_nums = int(math.log(1/scale_factor, 2)) + self.occlusion_num = 4 + + channel = [hourglass_output_size[-1]//(2**i) for i in range(self.up_nums)] + for i in range(self.up_nums): + up.append(UpBlock2d(channel[i], channel[i]//2, kernel_size=3, padding=1)) + self.up = nn.ModuleList(up) + + channel = [hourglass_output_size[-i-1] for i in range(self.occlusion_num-self.up_nums)[::-1]] + for i in range(self.up_nums): + channel.append(hourglass_output_size[-1]//(2**(i+1))) + occlusion = [] + + for i in range(self.occlusion_num): + occlusion.append(nn.Conv2d(channel[i], 1, kernel_size=(7, 7), padding=(3, 3))) + self.occlusion = nn.ModuleList(occlusion) + else: + occlusion = [nn.Conv2d(hourglass_output_size[-1], 1, kernel_size=(7, 7), padding=(3, 3))] + self.occlusion = nn.ModuleList(occlusion) + + self.num_tps = num_tps + self.bg = bg + self.kp_variance = kp_variance + + + def create_heatmap_representations(self, source_image, kp_driving, kp_source): + + spatial_size = source_image.shape[2:] + gaussian_driving = kp2gaussian(kp_driving['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance) + gaussian_source = kp2gaussian(kp_source['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance) + heatmap = gaussian_driving - gaussian_source + + zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()).to(heatmap.device) + heatmap = torch.cat([zeros, heatmap], dim=1) + + return heatmap + + def create_transformations(self, source_image, kp_driving, kp_source, bg_param): + # K TPS transformaions + bs, _, h, w = source_image.shape + kp_1 = kp_driving['fg_kp'] + kp_2 = kp_source['fg_kp'] + kp_1 = kp_1.view(bs, -1, 5, 2) + kp_2 = kp_2.view(bs, -1, 5, 2) + trans = TPS(mode = 'kp', bs = bs, kp_1 = kp_1, kp_2 = kp_2) + driving_to_source = trans.transform_frame(source_image) + + identity_grid = make_coordinate_grid((h, w), type=kp_1.type()).to(kp_1.device) + identity_grid = identity_grid.view(1, 1, h, w, 2) + identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1) + + # affine background transformation + if not (bg_param is None): + identity_grid = to_homogeneous(identity_grid) + identity_grid = torch.matmul(bg_param.view(bs, 1, 1, 1, 3, 3), identity_grid.unsqueeze(-1)).squeeze(-1) + identity_grid = from_homogeneous(identity_grid) + + transformations = torch.cat([identity_grid, driving_to_source], dim=1) + return transformations + + def create_deformed_source_image(self, source_image, transformations): + + bs, _, h, w = source_image.shape + source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_tps + 1, 1, 1, 1, 1) + source_repeat = source_repeat.view(bs * (self.num_tps + 1), -1, h, w) + transformations = transformations.view((bs * (self.num_tps + 1), h, w, -1)) + deformed = F.grid_sample(source_repeat, transformations, align_corners=True) + deformed = deformed.view((bs, self.num_tps+1, -1, h, w)) + return deformed + + def dropout_softmax(self, X, P): + ''' + Dropout for TPS transformations. Eq(7) and Eq(8) in the paper. + ''' + drop = (torch.rand(X.shape[0],X.shape[1]) < (1-P)).type(X.type()).to(X.device) + drop[..., 0] = 1 + drop = drop.repeat(X.shape[2],X.shape[3],1,1).permute(2,3,0,1) + + maxx = X.max(1).values.unsqueeze_(1) + X = X - maxx + X_exp = X.exp() + X[:,1:,...] /= (1-P) + mask_bool =(drop == 0) + X_exp = X_exp.masked_fill(mask_bool, 0) + partition = X_exp.sum(dim=1, keepdim=True) + 1e-6 + return X_exp / partition + + def forward(self, source_image, kp_driving, kp_source, bg_param = None, dropout_flag=False, dropout_p = 0): + if self.scale_factor != 1: + source_image = self.down(source_image) + + bs, _, h, w = source_image.shape + + out_dict = dict() + heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source) + transformations = self.create_transformations(source_image, kp_driving, kp_source, bg_param) + deformed_source = self.create_deformed_source_image(source_image, transformations) + out_dict['deformed_source'] = deformed_source + # out_dict['transformations'] = transformations + deformed_source = deformed_source.view(bs,-1,h,w) + input = torch.cat([heatmap_representation, deformed_source], dim=1) + input = input.view(bs, -1, h, w) + + prediction = self.hourglass(input, mode = 1) + + contribution_maps = self.maps(prediction[-1]) + if(dropout_flag): + contribution_maps = self.dropout_softmax(contribution_maps, dropout_p) + else: + contribution_maps = F.softmax(contribution_maps, dim=1) + out_dict['contribution_maps'] = contribution_maps + + # Combine the K+1 transformations + # Eq(6) in the paper + contribution_maps = contribution_maps.unsqueeze(2) + transformations = transformations.permute(0, 1, 4, 2, 3) + deformation = (transformations * contribution_maps).sum(dim=1) + deformation = deformation.permute(0, 2, 3, 1) + + out_dict['deformation'] = deformation # Optical Flow + + occlusion_map = [] + if self.multi_mask: + for i in range(self.occlusion_num-self.up_nums): + occlusion_map.append(torch.sigmoid(self.occlusion[i](prediction[self.up_nums-self.occlusion_num+i]))) + prediction = prediction[-1] + for i in range(self.up_nums): + prediction = self.up[i](prediction) + occlusion_map.append(torch.sigmoid(self.occlusion[i+self.occlusion_num-self.up_nums](prediction))) + else: + occlusion_map.append(torch.sigmoid(self.occlusion[0](prediction[-1]))) + + out_dict['occlusion_map'] = occlusion_map # Multi-resolution Occlusion Masks + return out_dict diff --git a/TPSMM/modules/inpainting_network.py b/TPSMM/modules/inpainting_network.py new file mode 100644 index 0000000000000000000000000000000000000000..6b873bd21e2868b1959be8b92dee8b361ecbd6d8 --- /dev/null +++ b/TPSMM/modules/inpainting_network.py @@ -0,0 +1,127 @@ +import torch +from torch import nn +import torch.nn.functional as F +from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d +from modules.dense_motion import DenseMotionNetwork + + +class InpaintingNetwork(nn.Module): + """ + Inpaint the missing regions and reconstruct the Driving image. + """ + def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, multi_mask = True, **kwargs): + super(InpaintingNetwork, self).__init__() + + self.num_down_blocks = num_down_blocks + self.multi_mask = multi_mask + self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) + + down_blocks = [] + up_blocks = [] + resblock = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + decoder_in_feature = out_features * 2 + if i==num_down_blocks-1: + decoder_in_feature = out_features + up_blocks.append(UpBlock2d(decoder_in_feature, in_features, kernel_size=(3, 3), padding=(1, 1))) + resblock.append(ResBlock2d(decoder_in_feature, kernel_size=(3, 3), padding=(1, 1))) + resblock.append(ResBlock2d(decoder_in_feature, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + self.up_blocks = nn.ModuleList(up_blocks[::-1]) + self.resblock = nn.ModuleList(resblock[::-1]) + + self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) + self.num_channels = num_channels + + def deform_input(self, inp, deformation): + _, h_old, w_old, _ = deformation.shape + _, _, h, w = inp.shape + if h_old != h or w_old != w: + deformation = deformation.permute(0, 3, 1, 2) + deformation = F.interpolate(deformation, size=(h, w), mode='bilinear', align_corners=True) + deformation = deformation.permute(0, 2, 3, 1) + return F.grid_sample(inp, deformation,align_corners=True) + + def occlude_input(self, inp, occlusion_map): + if not self.multi_mask: + if inp.shape[2] != occlusion_map.shape[2] or inp.shape[3] != occlusion_map.shape[3]: + occlusion_map = F.interpolate(occlusion_map, size=inp.shape[2:], mode='bilinear',align_corners=True) + out = inp * occlusion_map + return out + + def forward(self, source_image, dense_motion): + out = self.first(source_image) + encoder_map = [out] + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + encoder_map.append(out) + + output_dict = {} + output_dict['contribution_maps'] = dense_motion['contribution_maps'] + output_dict['deformed_source'] = dense_motion['deformed_source'] + + occlusion_map = dense_motion['occlusion_map'] + output_dict['occlusion_map'] = occlusion_map + + deformation = dense_motion['deformation'] + out_ij = self.deform_input(out.detach(), deformation) + out = self.deform_input(out, deformation) + + out_ij = self.occlude_input(out_ij, occlusion_map[0].detach()) + out = self.occlude_input(out, occlusion_map[0]) + + warped_encoder_maps = [] + warped_encoder_maps.append(out_ij) + + for i in range(self.num_down_blocks): + + out = self.resblock[2*i](out) + out = self.resblock[2*i+1](out) + out = self.up_blocks[i](out) + + encode_i = encoder_map[-(i+2)] + encode_ij = self.deform_input(encode_i.detach(), deformation) + encode_i = self.deform_input(encode_i, deformation) + + occlusion_ind = 0 + if self.multi_mask: + occlusion_ind = i+1 + encode_ij = self.occlude_input(encode_ij, occlusion_map[occlusion_ind].detach()) + encode_i = self.occlude_input(encode_i, occlusion_map[occlusion_ind]) + warped_encoder_maps.append(encode_ij) + + if(i==self.num_down_blocks-1): + break + + out = torch.cat([out, encode_i], 1) + + deformed_source = self.deform_input(source_image, deformation) + output_dict["deformed"] = deformed_source + output_dict["warped_encoder_maps"] = warped_encoder_maps + + occlusion_last = occlusion_map[-1] + if not self.multi_mask: + occlusion_last = F.interpolate(occlusion_last, size=out.shape[2:], mode='bilinear',align_corners=True) + + out = out * (1 - occlusion_last) + encode_i + out = self.final(out) + out = torch.sigmoid(out) + out = out * (1 - occlusion_last) + deformed_source * occlusion_last + output_dict["prediction"] = out + + return output_dict + + def get_encode(self, driver_image, occlusion_map): + out = self.first(driver_image) + encoder_map = [] + encoder_map.append(self.occlude_input(out.detach(), occlusion_map[-1].detach())) + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out.detach()) + out_mask = self.occlude_input(out.detach(), occlusion_map[2-i].detach()) + encoder_map.append(out_mask.detach()) + + return encoder_map + diff --git a/TPSMM/modules/keypoint_detector.py b/TPSMM/modules/keypoint_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..a39a19458c75449c65d3e7810974eededb9d2d67 --- /dev/null +++ b/TPSMM/modules/keypoint_detector.py @@ -0,0 +1,27 @@ +from torch import nn +import torch +from torchvision import models + +class KPDetector(nn.Module): + """ + Predict K*5 keypoints. + """ + + def __init__(self, num_tps, **kwargs): + super(KPDetector, self).__init__() + self.num_tps = num_tps + + self.fg_encoder = models.resnet18(pretrained=False) + num_features = self.fg_encoder.fc.in_features + self.fg_encoder.fc = nn.Linear(num_features, num_tps*5*2) + + + def forward(self, image): + + fg_kp = self.fg_encoder(image) + bs, _, = fg_kp.shape + fg_kp = torch.sigmoid(fg_kp) + fg_kp = fg_kp * 2 - 1 + out = {'fg_kp': fg_kp.view(bs, self.num_tps*5, -1)} + + return out diff --git a/TPSMM/modules/model.py b/TPSMM/modules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8df242b3f33e7c5e01d5e47628a81eb94a3f1964 --- /dev/null +++ b/TPSMM/modules/model.py @@ -0,0 +1,182 @@ +from torch import nn +import torch +import torch.nn.functional as F +from modules.util import AntiAliasInterpolation2d, TPS +from torchvision import models +import numpy as np + + +class Vgg19(torch.nn.Module): + """ + Vgg19 network for perceptual loss. See Sec 3.3. + """ + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + + self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), + requires_grad=False) + self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), + requires_grad=False) + + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + X = (X - self.mean) / self.std + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + + +class ImagePyramide(torch.nn.Module): + """ + Create image pyramide for computing pyramide perceptual loss. See Sec 3.3 + """ + def __init__(self, scales, num_channels): + super(ImagePyramide, self).__init__() + downs = {} + for scale in scales: + downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale) + self.downs = nn.ModuleDict(downs) + + def forward(self, x): + out_dict = {} + for scale, down_module in self.downs.items(): + out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x) + return out_dict + + +def detach_kp(kp): + return {key: value.detach() for key, value in kp.items()} + + +class GeneratorFullModel(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, *kwargs): + super(GeneratorFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.inpainting_network = inpainting_network + self.dense_motion_network = dense_motion_network + + self.bg_predictor = None + if bg_predictor: + self.bg_predictor = bg_predictor + self.bg_start = train_params['bg_start'] + + self.train_params = train_params + self.scales = train_params['scales'] + + self.pyramid = ImagePyramide(self.scales, inpainting_network.num_channels) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + self.dropout_epoch = train_params['dropout_epoch'] + self.dropout_maxp = train_params['dropout_maxp'] + self.dropout_inc_epoch = train_params['dropout_inc_epoch'] + self.dropout_startp =train_params['dropout_startp'] + + if sum(self.loss_weights['perceptual']) != 0: + self.vgg = Vgg19() + if torch.cuda.is_available(): + self.vgg = self.vgg.cuda() + + + def forward(self, x, epoch): + kp_source = self.kp_extractor(x['source']) + kp_driving = self.kp_extractor(x['driving']) + bg_param = None + if self.bg_predictor: + if(epoch>=self.bg_start): + bg_param = self.bg_predictor(x['source'], x['driving']) + + if(epoch>=self.dropout_epoch): + dropout_flag = False + dropout_p = 0 + else: + # dropout_p will linearly increase from dropout_startp to dropout_maxp + dropout_flag = True + dropout_p = min(epoch/self.dropout_inc_epoch * self.dropout_maxp + self.dropout_startp, self.dropout_maxp) + + dense_motion = self.dense_motion_network(source_image=x['source'], kp_driving=kp_driving, + kp_source=kp_source, bg_param = bg_param, + dropout_flag = dropout_flag, dropout_p = dropout_p) + generated = self.inpainting_network(x['source'], dense_motion) + generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) + + loss_values = {} + + pyramide_real = self.pyramid(x['driving']) + pyramide_generated = self.pyramid(generated['prediction']) + + # reconstruction loss + if sum(self.loss_weights['perceptual']) != 0: + value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_values['perceptual'] = value_total + + # equivariance loss + if self.loss_weights['equivariance_value'] != 0: + transform_random = TPS(mode = 'random', bs = x['driving'].shape[0], **self.train_params['transform_params']) + transform_grid = transform_random.transform_frame(x['driving']) + transformed_frame = F.grid_sample(x['driving'], transform_grid, padding_mode="reflection",align_corners=True) + transformed_kp = self.kp_extractor(transformed_frame) + + generated['transformed_frame'] = transformed_frame + generated['transformed_kp'] = transformed_kp + + warped = transform_random.warp_coordinates(transformed_kp['fg_kp']) + kp_d = kp_driving['fg_kp'] + value = torch.abs(kp_d - warped).mean() + loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value + + # warp loss + if self.loss_weights['warp_loss'] != 0: + occlusion_map = generated['occlusion_map'] + encode_map = self.inpainting_network.get_encode(x['driving'], occlusion_map) + decode_map = generated['warped_encoder_maps'] + value = 0 + for i in range(len(encode_map)): + value += torch.abs(encode_map[i]-decode_map[-i-1]).mean() + + loss_values['warp_loss'] = self.loss_weights['warp_loss'] * value + + # bg loss + if self.bg_predictor and epoch >= self.bg_start and self.loss_weights['bg'] != 0: + bg_param_reverse = self.bg_predictor(x['driving'], x['source']) + value = torch.matmul(bg_param, bg_param_reverse) + eye = torch.eye(3).view(1, 1, 3, 3).type(value.type()) + value = torch.abs(eye - value).mean() + loss_values['bg'] = self.loss_weights['bg'] * value + + return loss_values, generated diff --git a/TPSMM/modules/util.py b/TPSMM/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..0a869916831d6804624009282196f6b2391dc280 --- /dev/null +++ b/TPSMM/modules/util.py @@ -0,0 +1,349 @@ +from torch import nn +import torch.nn.functional as F +import torch + + +class TPS: + ''' + TPS transformation, mode 'kp' for Eq(2) in the paper, mode 'random' for equivariance loss. + ''' + def __init__(self, mode, bs, **kwargs): + self.bs = bs + self.mode = mode + if mode == 'random': + noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3])) + self.theta = noise + torch.eye(2, 3).view(1, 2, 3) + self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type()) + self.control_points = self.control_points.unsqueeze(0) + self.control_params = torch.normal(mean=0, + std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2])) + elif mode == 'kp': + kp_1 = kwargs["kp_1"] + kp_2 = kwargs["kp_2"] + device = kp_1.device + kp_type = kp_1.type() + self.gs = kp_1.shape[1] + n = kp_1.shape[2] + K = torch.norm(kp_1[:,:,:, None]-kp_1[:,:, None, :], dim=4, p=2) + K = K**2 + K = K * torch.log(K+1e-9) + + one1 = torch.ones(self.bs, kp_1.shape[1], kp_1.shape[2], 1).to(device).type(kp_type) + kp_1p = torch.cat([kp_1,one1], 3) + + zero = torch.zeros(self.bs, kp_1.shape[1], 3, 3).to(device).type(kp_type) + P = torch.cat([kp_1p, zero],2) + L = torch.cat([K,kp_1p.permute(0,1,3,2)],2) + L = torch.cat([L,P],3) + + zero = torch.zeros(self.bs, kp_1.shape[1], 3, 2).to(device).type(kp_type) + Y = torch.cat([kp_2, zero], 2) + one = torch.eye(L.shape[2]).expand(L.shape).to(device).type(kp_type)*0.01 + L = L + one + + param = torch.matmul(torch.inverse(L),Y) + self.theta = param[:,:,n:,:].permute(0,1,3,2) + + self.control_points = kp_1 + self.control_params = param[:,:,:n,:] + else: + raise Exception("Error TPS mode") + + def transform_frame(self, frame): + grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0).to(frame.device) + grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) + shape = [self.bs, frame.shape[2], frame.shape[3], 2] + if self.mode == 'kp': + shape.insert(1, self.gs) + grid = self.warp_coordinates(grid).view(*shape) + return grid + + def warp_coordinates(self, coordinates): + theta = self.theta.type(coordinates.type()).to(coordinates.device) + control_points = self.control_points.type(coordinates.type()).to(coordinates.device) + control_params = self.control_params.type(coordinates.type()).to(coordinates.device) + + if self.mode == 'kp': + transformed = torch.matmul(theta[:, :, :, :2], coordinates.permute(0, 2, 1)) + theta[:, :, :, 2:] + + distances = coordinates.view(coordinates.shape[0], 1, 1, -1, 2) - control_points.view(self.bs, control_points.shape[1], -1, 1, 2) + + distances = distances ** 2 + result = distances.sum(-1) + result = result * torch.log(result + 1e-9) + result = torch.matmul(result.permute(0, 1, 3, 2), control_params) + transformed = transformed.permute(0, 1, 3, 2) + result + + elif self.mode == 'random': + theta = theta.unsqueeze(1) + transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:] + transformed = transformed.squeeze(-1) + ances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) + distances = ances ** 2 + + result = distances.sum(-1) + result = result * torch.log(result + 1e-9) + result = result * control_params + result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) + transformed = transformed + result + else: + raise Exception("Error TPS mode") + + return transformed + + +def kp2gaussian(kp, spatial_size, kp_variance): + """ + Transform a keypoint into gaussian like representation + """ + + coordinate_grid = make_coordinate_grid(spatial_size, kp.type()).to(kp.device) + number_of_leading_dimensions = len(kp.shape) - 1 + shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape + coordinate_grid = coordinate_grid.view(*shape) + repeats = kp.shape[:number_of_leading_dimensions] + (1, 1, 1) + coordinate_grid = coordinate_grid.repeat(*repeats) + + # Preprocess kp shape + shape = kp.shape[:number_of_leading_dimensions] + (1, 1, 2) + kp = kp.view(*shape) + + mean_sub = (coordinate_grid - kp) + + out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) + + return out + + +def make_coordinate_grid(spatial_size, type): + """ + Create a meshgrid [-1,1] x [-1,1] of given spatial_size. + """ + h, w = spatial_size + x = torch.arange(w).type(type) + y = torch.arange(h).type(type) + + x = (2 * (x / (w - 1)) - 1) + y = (2 * (y / (h - 1)) - 1) + + yy = y.view(-1, 1).repeat(1, w) + xx = x.view(1, -1).repeat(h, 1) + + meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) + + return meshed + + +class ResBlock2d(nn.Module): + """ + Res block, preserve spatial resolution. + """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock2d, self).__init__() + self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.norm1 = nn.InstanceNorm2d(in_features, affine=True) + self.norm2 = nn.InstanceNorm2d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class UpBlock2d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock2d, self).__init__() + + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = nn.InstanceNorm2d(out_features, affine=True) + + def forward(self, x): + out = F.interpolate(x, scale_factor=2) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class DownBlock2d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = nn.InstanceNorm2d(out_features, affine=True) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class SameBlock2d(nn.Module): + """ + Simple block, preserve spatial resolution. + """ + + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, + kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = nn.InstanceNorm2d(out_features, affine=True) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + return out + + +class Encoder(nn.Module): + """ + Hourglass Encoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Encoder, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + kernel_size=3, padding=1)) + self.down_blocks = nn.ModuleList(down_blocks) + + def forward(self, x): + outs = [x] + #print('encoder:' ,outs[-1].shape) + for down_block in self.down_blocks: + outs.append(down_block(outs[-1])) + #print('encoder:' ,outs[-1].shape) + return outs + + +class Decoder(nn.Module): + """ + Hourglass Decoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Decoder, self).__init__() + + up_blocks = [] + self.out_channels = [] + for i in range(num_blocks)[::-1]: + in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) + self.out_channels.append(in_filters) + out_filters = min(max_features, block_expansion * (2 ** i)) + up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1)) + + self.up_blocks = nn.ModuleList(up_blocks) + self.out_channels.append(block_expansion + in_features) + # self.out_filters = block_expansion + in_features + + def forward(self, x, mode = 0): + out = x.pop() + outs = [] + for up_block in self.up_blocks: + out = up_block(out) + skip = x.pop() + out = torch.cat([out, skip], dim=1) + outs.append(out) + if(mode == 0): + return out + else: + return outs + + +class Hourglass(nn.Module): + """ + Hourglass architecture. + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Hourglass, self).__init__() + self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) + self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) + self.out_channels = self.decoder.out_channels + # self.out_filters = self.decoder.out_filters + + def forward(self, x, mode = 0): + return self.decoder(self.encoder(x), mode) + + +class AntiAliasInterpolation2d(nn.Module): + """ + Band-limited downsampling, for better preservation of the input signal. + """ + def __init__(self, channels, scale): + super(AntiAliasInterpolation2d, self).__init__() + sigma = (1 / scale - 1) / 2 + kernel_size = 2 * round(sigma * 4) + 1 + self.ka = kernel_size // 2 + self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka + + kernel_size = [kernel_size, kernel_size] + sigma = [sigma, sigma] + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + self.scale = scale + + def forward(self, input): + if self.scale == 1.0: + return input + + out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) + out = F.conv2d(out, weight=self.weight, groups=self.groups) + out = F.interpolate(out, scale_factor=(self.scale, self.scale)) + + return out + + +def to_homogeneous(coordinates): + ones_shape = list(coordinates.shape) + ones_shape[-1] = 1 + ones = torch.ones(ones_shape).type(coordinates.type()) + + return torch.cat([coordinates, ones], dim=-1) + +def from_homogeneous(coordinates): + return coordinates[..., :2] / coordinates[..., 2:3] \ No newline at end of file diff --git a/TPSMM/pkgs/tpsmm.py b/TPSMM/pkgs/tpsmm.py new file mode 100644 index 0000000000000000000000000000000000000000..b5616ef0dc31a43b9229be3f58de73c1c48a8e31 --- /dev/null +++ b/TPSMM/pkgs/tpsmm.py @@ -0,0 +1,80 @@ +import os +import sys +import torch +import cv2 +import numpy as np +from skimage import img_as_ubyte +from skimage.transform import resize +pwd = os.path.dirname(os.path.realpath(__file__)) +sys.path.insert(1, os.path.join(pwd, "..")) + +from demo import relative_kp, load_checkpoints + + +class TPSMM: + def __init__(self): + self.device = torch.device("cuda") + self.inpainting, self.kp_detector, self.dense_motion_network, self.avd_network = load_checkpoints( + config_path=os.path.join(pwd, "../config/vox-256.yaml"), + checkpoint_path=os.path.join(pwd, "../pretrained/vox.pth.tar"), + device=self.device + ) + self.kp_driving_initial = None + + def process_source(self, src_img): + with torch.no_grad(): + src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) + src_img = resize(src_img, (256, 256)) + source_tensor = torch.tensor(src_img[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(self.device) + kp_source = self.kp_detector(source_tensor) + + return source_tensor, kp_source + + + def gen_image(self, driving_img, source_tensor, kp_source): + with torch.no_grad(): + driving_img = cv2.cvtColor(driving_img, cv2.COLOR_BGR2RGB) + driving_img = resize(driving_img, (256, 256)) + driving_frame = torch.tensor(driving_img[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(self.device) + + kp_driving = self.kp_detector(driving_frame) + if self.kp_driving_initial is None: + self.kp_driving_initial = kp_driving + kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving, + kp_driving_initial=self.kp_driving_initial) + dense_motion = self.dense_motion_network(source_image=source_tensor, + kp_driving=kp_norm, + kp_source=kp_source, bg_param=None, + dropout_flag=False) + out = self.inpainting(source_tensor, dense_motion) + out = np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0] + out = img_as_ubyte(out) + out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR) + + return out + + +if __name__ == "__main__": + tpsmm = TPSMM() + source_image = cv2.imread(os.path.join(pwd, "../assets/source1.png")) + cap = cv2.VideoCapture("/research/GAN/git/CVPR2022-DaGAN/assets/video1.mp4") + + source_tensor, kp_source = tpsmm.process_source(source_image) + + while True: + ret, frame = cap.read() + if frame is None: + break + + output = tpsmm.gen_image(frame, source_tensor, kp_source) + cv2.imshow("output", output) + key = cv2.waitKey(1) & 0xFF + if key == ord("q"): + break + # cv2.imwrite("./tmp.jpg", output) + + cv2.destroyAllWindows() + + + + \ No newline at end of file diff --git a/TPSMM/predict.py b/TPSMM/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..d615236a2d41a4f18bcba969f74e89cd3df05f50 --- /dev/null +++ b/TPSMM/predict.py @@ -0,0 +1,125 @@ +import os +import sys +sys.path.insert(0, "stylegan-encoder") +import tempfile +import warnings +import imageio +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from skimage.transform import resize +from skimage import img_as_ubyte +import torch +import torchvision.transforms as transforms +import dlib +from cog import BasePredictor, Path, Input + +from demo import load_checkpoints +from demo import make_animation +from ffhq_dataset.face_alignment import image_align +from ffhq_dataset.landmarks_detector import LandmarksDetector + + +warnings.filterwarnings("ignore") + + +PREDICTOR = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") +LANDMARKS_DETECTOR = LandmarksDetector("shape_predictor_68_face_landmarks.dat") + + +class Predictor(BasePredictor): + def setup(self): + + self.device = torch.device("cuda:0") + datasets = ["vox", "taichi", "ted", "mgif"] + ( + self.inpainting, + self.kp_detector, + self.dense_motion_network, + self.avd_network, + ) = ({}, {}, {}, {}) + for d in datasets: + ( + self.inpainting[d], + self.kp_detector[d], + self.dense_motion_network[d], + self.avd_network[d], + ) = load_checkpoints( + config_path=f"config/{d}-384.yaml" + if d == "ted" + else f"config/{d}-256.yaml", + checkpoint_path=f"checkpoints/{d}.pth.tar", + device=self.device, + ) + + def predict( + self, + source_image: Path = Input( + description="Input source image.", + ), + driving_video: Path = Input( + description="Choose a micromotion.", + ), + dataset_name: str = Input( + choices=["vox", "taichi", "ted", "mgif"], + default="vox", + description="Choose a dataset.", + ), + ) -> Path: + + predict_mode = "relative" # ['standard', 'relative', 'avd'] + # find_best_frame = False + + pixel = 384 if dataset_name == "ted" else 256 + + if dataset_name == "vox": + # first run face alignment + align_image(str(source_image), 'aligned.png') + source_image = imageio.imread('aligned.png') + else: + source_image = imageio.imread(str(source_image)) + reader = imageio.get_reader(str(driving_video)) + fps = reader.get_meta_data()["fps"] + source_image = resize(source_image, (pixel, pixel))[..., :3] + + driving_video = [] + try: + for im in reader: + driving_video.append(im) + except RuntimeError: + pass + reader.close() + + driving_video = [ + resize(frame, (pixel, pixel))[..., :3] for frame in driving_video + ] + + inpainting, kp_detector, dense_motion_network, avd_network = ( + self.inpainting[dataset_name], + self.kp_detector[dataset_name], + self.dense_motion_network[dataset_name], + self.avd_network[dataset_name], + ) + + predictions = make_animation( + source_image, + driving_video, + inpainting, + kp_detector, + dense_motion_network, + avd_network, + device="cuda:0", + mode=predict_mode, + ) + + # save resulting video + out_path = Path(tempfile.mkdtemp()) / "output.mp4" + imageio.mimsave( + str(out_path), [img_as_ubyte(frame) for frame in predictions], fps=fps + ) + return out_path + + +def align_image(raw_img_path, aligned_face_path): + for i, face_landmarks in enumerate(LANDMARKS_DETECTOR.get_landmarks(raw_img_path), start=1): + image_align(raw_img_path, aligned_face_path, face_landmarks) diff --git a/TPSMM/pretrained/vox.pth.tar b/TPSMM/pretrained/vox.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..5cb58244930878bdde245b9fce6eda59ac5fbe9b --- /dev/null +++ b/TPSMM/pretrained/vox.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52ad8c848e2a1d91b621de96fea83faf57ce3b8c1c06424e317f4df1d3998204 +size 350993469 diff --git a/TPSMM/reconstruction.py b/TPSMM/reconstruction.py new file mode 100644 index 0000000000000000000000000000000000000000..40d4cf466339aa87935b3d488f759a066d753a4e --- /dev/null +++ b/TPSMM/reconstruction.py @@ -0,0 +1,69 @@ +import os +from tqdm import tqdm +import torch +from torch.utils.data import DataLoader +from logger import Logger, Visualizer +import numpy as np +import imageio + + +def reconstruction(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset): + png_dir = os.path.join(log_dir, 'reconstruction/png') + log_dir = os.path.join(log_dir, 'reconstruction') + + if checkpoint is not None: + Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector, + bg_predictor=bg_predictor, dense_motion_network=dense_motion_network) + else: + raise AttributeError("Checkpoint should be specified for mode='reconstruction'.") + dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) + + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + if not os.path.exists(png_dir): + os.makedirs(png_dir) + + loss_list = [] + + inpainting_network.eval() + kp_detector.eval() + dense_motion_network.eval() + if bg_predictor: + bg_predictor.eval() + + for it, x in tqdm(enumerate(dataloader)): + with torch.no_grad(): + predictions = [] + visualizations = [] + if torch.cuda.is_available(): + x['video'] = x['video'].cuda() + kp_source = kp_detector(x['video'][:, :, 0]) + for frame_idx in range(x['video'].shape[2]): + source = x['video'][:, :, 0] + driving = x['video'][:, :, frame_idx] + kp_driving = kp_detector(driving) + bg_params = None + if bg_predictor: + bg_params = bg_predictor(source, driving) + + dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving, + kp_source=kp_source, bg_param = bg_params, + dropout_flag = False) + out = inpainting_network(source, dense_motion) + out['kp_source'] = kp_source + out['kp_driving'] = kp_driving + + predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) + + visualization = Visualizer(**config['visualizer_params']).visualize(source=source, + driving=driving, out=out) + visualizations.append(visualization) + loss = torch.abs(out['prediction'] - driving).mean().cpu().numpy() + + loss_list.append(loss) + # print(np.mean(loss_list)) + predictions = np.concatenate(predictions, axis=1) + imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8)) + + print("Reconstruction loss: %s" % np.mean(loss_list)) diff --git a/TPSMM/requirements.txt b/TPSMM/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..20beb09852a9ae07b1d6cae6567b2b61dcf33add --- /dev/null +++ b/TPSMM/requirements.txt @@ -0,0 +1,25 @@ +cffi==1.14.6 +cycler==0.10.0 +decorator==5.1.0 +face-alignment==1.3.5 +imageio==2.9.0 +imageio-ffmpeg==0.4.5 +kiwisolver==1.3.2 +matplotlib==3.4.3 +networkx==2.6.3 +numpy==1.20.3 +pandas==1.3.3 +Pillow==8.3.2 +pycparser==2.20 +pyparsing==2.4.7 +python-dateutil==2.8.2 +pytz==2021.1 +PyWavelets==1.1.1 +PyYAML==5.4.1 +scikit-image==0.18.3 +scikit-learn==1.0 +scipy==1.7.1 +six==1.16.0 +torch==1.10.0+cu113 +torchvision==0.11.0+cu113 +tqdm==4.62.3 \ No newline at end of file diff --git a/TPSMM/run.py b/TPSMM/run.py new file mode 100644 index 0000000000000000000000000000000000000000..6120213fe79c670212b2fc79e0ddb105fb178c45 --- /dev/null +++ b/TPSMM/run.py @@ -0,0 +1,89 @@ +import matplotlib +matplotlib.use('Agg') + +import os, sys +import yaml +from argparse import ArgumentParser +from time import gmtime, strftime +from shutil import copy +from frames_dataset import FramesDataset + +from modules.inpainting_network import InpaintingNetwork +from modules.keypoint_detector import KPDetector +from modules.bg_motion_predictor import BGMotionPredictor +from modules.dense_motion import DenseMotionNetwork +from modules.avd_network import AVDNetwork +import torch +from train import train +from train_avd import train_avd +from reconstruction import reconstruction +import os + + +if __name__ == "__main__": + + if sys.version_info[0] < 3: + raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9") + + parser = ArgumentParser() + parser.add_argument("--config", default="config/vox-256.yaml", help="path to config") + parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"]) + parser.add_argument("--log_dir", default='log', help="path to log into") + parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore") + parser.add_argument("--device_ids", default="0,1", type=lambda x: list(map(int, x.split(','))), + help="Names of the devices comma separated.") + + opt = parser.parse_args() + with open(opt.config) as f: + config = yaml.load(f) + + if opt.checkpoint is not None: + log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1]) + else: + log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]) + log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime()) + + inpainting = InpaintingNetwork(**config['model_params']['generator_params'], + **config['model_params']['common_params']) + + if torch.cuda.is_available(): + cuda_device = torch.device('cuda:'+str(opt.device_ids[0])) + inpainting.to(cuda_device) + + kp_detector = KPDetector(**config['model_params']['common_params']) + dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'], + **config['model_params']['dense_motion_params']) + + if torch.cuda.is_available(): + kp_detector.to(opt.device_ids[0]) + dense_motion_network.to(opt.device_ids[0]) + + bg_predictor = None + if (config['model_params']['common_params']['bg']): + bg_predictor = BGMotionPredictor() + if torch.cuda.is_available(): + bg_predictor.to(opt.device_ids[0]) + + avd_network = None + if opt.mode == "train_avd": + avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'], + **config['model_params']['avd_network_params']) + if torch.cuda.is_available(): + avd_network.to(opt.device_ids[0]) + + dataset = FramesDataset(is_train=(opt.mode.startswith('train')), **config['dataset_params']) + + if not os.path.exists(log_dir): + os.makedirs(log_dir) + if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))): + copy(opt.config, log_dir) + + if opt.mode == 'train': + print("Training...") + train(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset) + elif opt.mode == 'train_avd': + print("Training Animation via Disentaglement...") + train_avd(config, inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network, opt.checkpoint, log_dir, dataset) + elif opt.mode == 'reconstruction': + print("Reconstruction...") + reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset) diff --git a/TPSMM/tmp.jpg b/TPSMM/tmp.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cd953a5825c6a18bd2b1d6d6ace77536ee0cc894 Binary files /dev/null and b/TPSMM/tmp.jpg differ diff --git a/TPSMM/tmp.py b/TPSMM/tmp.py new file mode 100644 index 0000000000000000000000000000000000000000..6c91bbf096307c9b0ba4168ed65b79af70adc0bc --- /dev/null +++ b/TPSMM/tmp.py @@ -0,0 +1,14 @@ +import cv2 + + +cap = cv2.VideoCapture("/research/GAN/git/CVPR2022-DaGAN/assets/video1.mp4") +while True: + ret, frame = cap.read() + if frame is None: + break + cv2.imshow("output", frame) + key = cv2.waitKey(1) & 0xff + if key == ord("q"): + break + +cv2.destroyAllWindows() \ No newline at end of file diff --git a/TPSMM/train.py b/TPSMM/train.py new file mode 100644 index 0000000000000000000000000000000000000000..06ce3be20bc4fcbc5395c596b042c1bf2bdad8b8 --- /dev/null +++ b/TPSMM/train.py @@ -0,0 +1,94 @@ +from tqdm import trange +import torch +from torch.utils.data import DataLoader +from logger import Logger +from modules.model import GeneratorFullModel +from torch.optim.lr_scheduler import MultiStepLR +from torch.nn.utils import clip_grad_norm_ +from frames_dataset import DatasetRepeater +import math + +def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset): + train_params = config['train_params'] + optimizer = torch.optim.Adam( + [{'params': list(inpainting_network.parameters()) + + list(dense_motion_network.parameters()) + + list(kp_detector.parameters()), 'initial_lr': train_params['lr_generator']}],lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4) + + optimizer_bg_predictor = None + if bg_predictor: + optimizer_bg_predictor = torch.optim.Adam( + [{'params':bg_predictor.parameters(),'initial_lr': train_params['lr_generator']}], + lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4) + + if checkpoint is not None: + start_epoch = Logger.load_cpk( + checkpoint, inpainting_network = inpainting_network, dense_motion_network = dense_motion_network, + kp_detector = kp_detector, bg_predictor = bg_predictor, + optimizer = optimizer, optimizer_bg_predictor = optimizer_bg_predictor) + print('load success:', start_epoch) + start_epoch += 1 + else: + start_epoch = 0 + + scheduler_optimizer = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1, + last_epoch=start_epoch - 1) + if bg_predictor: + scheduler_bg_predictor = MultiStepLR(optimizer_bg_predictor, train_params['epoch_milestones'], + gamma=0.1, last_epoch=start_epoch - 1) + + if 'num_repeats' in train_params or train_params['num_repeats'] != 1: + dataset = DatasetRepeater(dataset, train_params['num_repeats']) + dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, + num_workers=train_params['dataloader_workers'], drop_last=True) + + generator_full = GeneratorFullModel(kp_detector, bg_predictor, dense_motion_network, inpainting_network, train_params) + + if torch.cuda.is_available(): + generator_full = torch.nn.DataParallel(generator_full).cuda() + + bg_start = train_params['bg_start'] + + with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], + checkpoint_freq=train_params['checkpoint_freq']) as logger: + for epoch in trange(start_epoch, train_params['num_epochs']): + for x in dataloader: + if(torch.cuda.is_available()): + x['driving'] = x['driving'].cuda() + x['source'] = x['source'].cuda() + + losses_generator, generated = generator_full(x, epoch) + loss_values = [val.mean() for val in losses_generator.values()] + loss = sum(loss_values) + loss.backward() + + clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type = math.inf) + clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type = math.inf) + if bg_predictor and epoch>=bg_start: + clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type = math.inf) + + optimizer.step() + optimizer.zero_grad() + if bg_predictor and epoch>=bg_start: + optimizer_bg_predictor.step() + optimizer_bg_predictor.zero_grad() + + losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} + logger.log_iter(losses=losses) + + scheduler_optimizer.step() + if bg_predictor: + scheduler_bg_predictor.step() + + model_save = { + 'inpainting_network': inpainting_network, + 'dense_motion_network': dense_motion_network, + 'kp_detector': kp_detector, + 'optimizer': optimizer, + } + if bg_predictor and epoch>=bg_start: + model_save['bg_predictor'] = bg_predictor + model_save['optimizer_bg_predictor'] = optimizer_bg_predictor + + logger.log_epoch(epoch, model_save, inp=x, out=generated) + diff --git a/TPSMM/train_avd.py b/TPSMM/train_avd.py new file mode 100644 index 0000000000000000000000000000000000000000..bc6794c322e01d980acc2f3b3aab2b192576900d --- /dev/null +++ b/TPSMM/train_avd.py @@ -0,0 +1,91 @@ +from tqdm import trange +import torch +from torch.utils.data import DataLoader +from logger import Logger +from torch.optim.lr_scheduler import MultiStepLR +from frames_dataset import DatasetRepeater + + +def random_scale(kp_params, scale): + theta = torch.rand(kp_params['fg_kp'].shape[0], 2) * (2 * scale) + (1 - scale) + theta = torch.diag_embed(theta).unsqueeze(1).type(kp_params['fg_kp'].type()) + new_kp_params = {'fg_kp': torch.matmul(theta, kp_params['fg_kp'].unsqueeze(-1)).squeeze(-1)} + return new_kp_params + + +def train_avd(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, + avd_network, checkpoint, log_dir, dataset): + train_params = config['train_avd_params'] + + optimizer = torch.optim.Adam(avd_network.parameters(), lr=train_params['lr'], betas=(0.5, 0.999)) + + if checkpoint is not None: + Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector, + bg_predictor=bg_predictor, avd_network=avd_network, + dense_motion_network= dense_motion_network,optimizer_avd=optimizer) + start_epoch = 0 + else: + raise AttributeError("Checkpoint should be specified for mode='train_avd'.") + + scheduler = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1) + + if 'num_repeats' in train_params or train_params['num_repeats'] != 1: + dataset = DatasetRepeater(dataset, train_params['num_repeats']) + + dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, + num_workers=train_params['dataloader_workers'], drop_last=True) + + with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], + checkpoint_freq=train_params['checkpoint_freq']) as logger: + for epoch in trange(start_epoch, train_params['num_epochs']): + avd_network.train() + for x in dataloader: + with torch.no_grad(): + kp_source = kp_detector(x['source'].cuda()) + kp_driving_gt = kp_detector(x['driving'].cuda()) + kp_driving_random = random_scale(kp_driving_gt, scale=train_params['random_scale']) + rec = avd_network(kp_source, kp_driving_random) + + reconstruction_kp = train_params['lambda_shift'] * \ + torch.abs(kp_driving_gt['fg_kp'] - rec['fg_kp']).mean() + + loss_dict = {'rec_kp': reconstruction_kp} + loss = reconstruction_kp + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + losses = {key: value.mean().detach().data.cpu().numpy() for key, value in loss_dict.items()} + logger.log_iter(losses=losses) + + # Visualization + avd_network.eval() + with torch.no_grad(): + source = x['source'][:6].cuda() + driving = torch.cat([x['driving'][[0, 1]].cuda(), source[[2, 3, 2, 1]]], dim=0) + kp_source = kp_detector(source) + kp_driving = kp_detector(driving) + + out = avd_network(kp_source, kp_driving) + kp_driving = out + dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving, + kp_source=kp_source) + generated = inpainting_network(source, dense_motion) + + generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) + + scheduler.step(epoch) + model_save = { + 'inpainting_network': inpainting_network, + 'dense_motion_network': dense_motion_network, + 'kp_detector': kp_detector, + 'avd_network': avd_network, + 'optimizer_avd': optimizer + } + if bg_predictor : + model_save['bg_predictor'] = bg_predictor + + logger.log_epoch(epoch, model_save, + inp={'source': source, 'driving': driving}, + out=generated) diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..74ab3f624f38c957738523167a9d414322e98f1b --- /dev/null +++ b/app.py @@ -0,0 +1,122 @@ +import av +import sys +import numpy as np +import cv2 +import streamlit as st +from PIL import Image +from streamlit_webrtc import WebRtcMode, webrtc_streamer + +sys.path.insert(1, "./retinaface") +sys.path.insert(1, "./TPSMM/pkgs") +from tpsmm import TPSMM +from detect import Detect +from turn import get_ice_servers + + +def parse_roi_box_from_bbox(bbox, shape): + img_h, img_w = shape[:2] + left, top, right, bottom = bbox[:4] + old_size = (right - left + bottom - top) / 2 + center_x = right - (right - left) / 2.0 + center_y = bottom - (bottom - top) / 2.0 + old_size * 0.14 + + size = int(min((old_size * 2.0) / 2, center_x, img_w-center_x, center_y, img_h-center_y) * 2.0) + + roi_box = [0] * 4 + roi_box[0] = center_x - size / 2 + roi_box[1] = center_y - size / 2 + roi_box[2] = roi_box[0] + size + roi_box[3] = roi_box[1] + size + + return roi_box + +cache_key = "retinaface" +if cache_key in st.session_state: + detector = st.session_state[cache_key] +else: + detector = Detect("./retinaface/weights/mobilenet0.25_epoch_842.pth", net_inshape=(486, 864)) + st.session_state[cache_key] = detector + +cache_key = "tpsmm" +if cache_key in st.session_state: + generator = st.session_state[cache_key] +else: + generator = TPSMM() + st.session_state[cache_key] = generator + + +@st.cache_resource # type: ignore +def get_images(): + images = [ + cv2.imread("assets/0.jpg"), + cv2.imread("assets/1.jpg"), + cv2.imread("assets/2.jpg"), + cv2.imread("assets/3.jpg"), + ] + item_list = [str(i) for i in range(len(images))] + images = [generator.process_source(src_img) for src_img in images] + + return dict(zip(item_list, images)) +images = get_images() +user_option = st.selectbox("Choose an item", list(images.keys())) + +uploaded_file = st.file_uploader("Or upload your file here...", type=['png', 'jpeg', 'jpg']) +@st.cache_resource +def process_file(uploaded_file): + img = Image.open(uploaded_file) + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + dets = detector(img) + for i, b in enumerate(dets): + bbox = parse_roi_box_from_bbox(b[:4], img.shape) + bbox = [int(i) for i in bbox] + + face_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy() + # cv2.imwrite("./tmp.jpg", face_img) + return generator.process_source(face_img) + + return None +if uploaded_file is not None: + uploaded_file = process_file(uploaded_file) + +def callback(frame: av.VideoFrame) -> av.VideoFrame: + img = frame.to_ndarray(format="bgr24") + + try: + dets = detector(img) + output = None + for i, b in enumerate(dets): + text = "{:.4f}".format(b[4]) + b = b.astype(np.int32) + cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2) + bbox = parse_roi_box_from_bbox(b[:4], img.shape) + bbox = [int(i) for i in bbox] + cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255, 0, 0), 2) + + face_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy() + if uploaded_file is None: + source_tensor, kp_source = images[user_option] + else: + source_tensor, kp_source = uploaded_file + output = generator.gen_image(face_img, source_tensor, kp_source) + + landm = b[5:15] + landm = landm.reshape((5, 2)) + cv2.circle(img, tuple(landm[0]), 1, (0, 0, 255), 2) + cv2.circle(img, tuple(landm[1]), 1, (0, 255, 255), 2) + cv2.circle(img, tuple(landm[2]), 1, (255, 0, 255), 2) + cv2.circle(img, tuple(landm[3]), 1, (0, 255, 0), 2) + cv2.circle(img, tuple(landm[4]), 1, (255, 0, 0), 2) + + if output is not None: + img[:256, :256] = output + except Exception as e: + print(e) + + return av.VideoFrame.from_ndarray(img, format="bgr24") + +webrtc_streamer( + key="sample", + rtc_configuration={"iceServers": get_ice_servers()}, + video_frame_callback=callback, + media_stream_constraints={"video": True, "audio": False}, +) \ No newline at end of file diff --git a/assets/0.jpg b/assets/0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..852cb6d7ab304bcac7f18e1f0fb7d1a229e7f502 Binary files /dev/null and b/assets/0.jpg differ diff --git a/assets/1.jpg b/assets/1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..20a6e64c8703e54b943fb64a48f63ec90cbb9df1 Binary files /dev/null and b/assets/1.jpg differ diff --git a/assets/2.jpg b/assets/2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9132972dceff4c0d9fd42421b07e3c4d36a94c97 Binary files /dev/null and b/assets/2.jpg differ diff --git a/assets/3.jpg b/assets/3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8636aeea387fe8c32f375d21ab12ad02753c707a Binary files /dev/null and b/assets/3.jpg differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c6537c213c1503c83a083c54b142c4eb0ea3ce5d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +streamlit-webrtc +twilio +altair<5 +numpy==1.23.1 +opencv-python==4.8.0.74 +imutils +scikit-image==0.21.0 +matplotlib==3.7.1 +pyaml==23.5.9 +tqdm +torch +torchvision \ No newline at end of file diff --git a/retinaface/change_batch_onnx.py b/retinaface/change_batch_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..68cfc02981ea068eaf034ac32c4cb20e9214f5d4 --- /dev/null +++ b/retinaface/change_batch_onnx.py @@ -0,0 +1,43 @@ +import onnx + + +model = onnx.load('weights/faceDetector_243_432_b1_sim.onnx') + +# # for fixed batchsize +# model.graph.input[0].type.tensor_type.shape.dim[0].dim_value = 32 +# model.graph.output[0].type.tensor_type.shape.dim[0].dim_value = 32 +# model.graph.output[1].type.tensor_type.shape.dim[0].dim_value = 32 +# model.graph.output[2].type.tensor_type.shape.dim[0].dim_value = 32 +# onnx.save(model, 'faceDetector_640_b32.onnx') + + +model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'batch' # for dynamic batchsize +model.graph.output[0].type.tensor_type.shape.dim[0].dim_param = 'batch' +model.graph.output[1].type.tensor_type.shape.dim[0].dim_param = 'batch' +model.graph.output[2].type.tensor_type.shape.dim[0].dim_param = 'batch' +model.graph.output[3].type.tensor_type.shape.dim[0].dim_param = 'batch' +model.graph.output[4].type.tensor_type.shape.dim[0].dim_param = 'batch' +model.graph.output[5].type.tensor_type.shape.dim[0].dim_param = 'batch' +model.graph.output[6].type.tensor_type.shape.dim[0].dim_param = 'batch' +model.graph.output[7].type.tensor_type.shape.dim[0].dim_param = 'batch' +model.graph.output[8].type.tensor_type.shape.dim[0].dim_param = 'batch' +onnx.save(model, 'weights/faceDetector_243_432_batch_sim.onnx') + + +#################################################### +# SHow model onnx + +# import onnxruntime as rt +# ort_session = rt.InferenceSession("faceDetector_180_320_batch_sim.onnx") +# print("====INPUT====") +# for i in ort_session.get_inputs(): +# print("Name: {}, Shape: {}, Dtype: {}".format(i.name, i.shape, i.type)) +# print("====OUTPUT====") +# for i in ort_session.get_outputs(): +# print("Name: {}, Shape: {}, Dtype: {}".format(i.name, i.shape, i.type)) + +# import numpy as np +# input_name = ort_session.get_inputs()[0].name +# img = np.random.randn(4, 3, 180, 320).astype(np.float32) +# data = ort_session.run(None, {input_name: img}) +# print("Done") \ No newline at end of file diff --git a/retinaface/convert_to_onnx.py b/retinaface/convert_to_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..da98249ca3412ad2124efea2bc909c6586152038 --- /dev/null +++ b/retinaface/convert_to_onnx.py @@ -0,0 +1,135 @@ +# python convert_to_onnx.py --network mobile0.25 --trained_model weights/mobilenet0.25_Final.pth +from __future__ import print_function +import os +import argparse +import torch +import torch.backends.cudnn as cudnn +import numpy as np +from data import cfg_mnet, cfg_slim, cfg_rfb +from layers.functions.prior_box import PriorBox +from utils.nms.py_cpu_nms import py_cpu_nms +import cv2 +from models.retinaface import RetinaFace +from models.net_slim import Slim +from models.net_rfb import RFB +from utils.box_utils import decode, decode_landm +from utils.timer import Timer + + +parser = argparse.ArgumentParser(description='Test') +parser.add_argument('-m', '--trained_model', default='./weights/RBF_Final.pth', + type=str, help='Trained state_dict file path to open') +parser.add_argument('--network', default='RFB', help='Backbone network mobile0.25 or slim or RFB') +parser.add_argument('--long_side', default=320, help='when origin_size is false, long_side is scaled size(320 or 640 for long side)') +parser.add_argument('--cpu', action="store_true", help='Use cpu inference') + +args = parser.parse_args() + + +def check_keys(model, pretrained_state_dict): + ckpt_keys = set(pretrained_state_dict.keys()) + model_keys = set(model.state_dict().keys()) + used_pretrained_keys = model_keys & ckpt_keys + unused_pretrained_keys = ckpt_keys - model_keys + missing_keys = model_keys - ckpt_keys + print('Missing keys:{}'.format(len(missing_keys))) + print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys))) + print('Used keys:{}'.format(len(used_pretrained_keys))) + assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' + return True + + +def remove_prefix(state_dict, prefix): + ''' Old style model is stored with all names of parameters sharing common prefix 'module.' ''' + print('remove prefix \'{}\''.format(prefix)) + f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x + return {f(key): value for key, value in state_dict.items()} + + +def load_model(model, pretrained_path, load_to_cpu): + print('Loading pretrained model from {}'.format(pretrained_path)) + if load_to_cpu: + pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage) + else: + device = torch.cuda.current_device() + pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device)) + if "state_dict" in pretrained_dict.keys(): + pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.') + else: + pretrained_dict = remove_prefix(pretrained_dict, 'module.') + check_keys(model, pretrained_dict) + model.load_state_dict(pretrained_dict, strict=False) + return model + + +if __name__ == '__main__': + torch.set_grad_enabled(False) + + cfg = None + net = None + # long_side = int(args.long_side) + net_inshape = (243, 432) + device = torch.device("cpu" if args.cpu else "cuda") + print(device) + if args.network == "mobile0.25": + cfg = cfg_mnet + # net_inshape = (long_side, long_side) # h, w + priorbox = PriorBox(cfg, image_size=net_inshape) + priors = priorbox.forward() + prior_data = priors.to(device) + net = RetinaFace(cfg=cfg, phase='test') + elif args.network == "slim": + cfg = cfg_slim + net = Slim(cfg = cfg, phase = 'test') + elif args.network == "RFB": + cfg = cfg_rfb + net = RFB(cfg = cfg, phase = 'test') + else: + print("Don't support network!") + exit(0) + + # load weight + net = load_model(net, args.trained_model, args.cpu) + net.eval() + print('Finished loading model!') + print(net) + net = net.to(device) + + ##################export############### + output_onnx = f'weights/faceDetector_{net_inshape[0]}_{net_inshape[1]}_b1.onnx' + print("==> Exporting model to ONNX format at '{}'".format(output_onnx)) + input_names = ['input_1'] + output_names = ['box_1', 'box_2', 'box_3'] + + # import torch.onnx.symbolic_opset9 as onnx_symbolic + # def upsample_nearest2d(g, input, output_size, *args): + # # Currently, TRT 5.1/6.0/7.0 ONNX Parser does not support all ONNX ops + # # needed to support dynamic upsampling ONNX forumlation + # # Here we hardcode scale=2 as a temporary workaround + # scales = g.op("Constant", value_t=torch.tensor([1., 1., 2., 2.])) + # return g.op("Resize", input, scales, mode_s="nearest") + + + # onnx_symbolic.upsample_nearest2d = upsample_nearest2d + + # import io + # onnx_bytes = io.BytesIO() + # zero_input = torch.zeros([1, 3, net_inshape[0], net_inshape[1]]).cuda() + # dynamic_axes = {input_names[0]: {0:'batch'}} + # for _, name in enumerate(output_names): + # dynamic_axes[name] = dynamic_axes[input_names[0]] + # extra_args = {'opset_version': 10, 'verbose': False, + # 'input_names': input_names, 'output_names': output_names, + # 'dynamic_axes': dynamic_axes} + # torch.onnx.export(net, zero_input, onnx_bytes, **extra_args) + # with open(output_onnx, 'wb') as out: + # out.write(onnx_bytes.getvalue()) + + inputs = torch.randn(1, 3, net_inshape[0], net_inshape[1]).to(device) + torch_out = torch.onnx._export(net, inputs, output_onnx, export_params=True, verbose=False, opset_version=9, + input_names=input_names, output_names=output_names) + ################end############### + + + + diff --git a/retinaface/data/__init__.py b/retinaface/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea50ebaf88d64e75f4960bc99b14f138a343e575 --- /dev/null +++ b/retinaface/data/__init__.py @@ -0,0 +1,3 @@ +from .wider_face import WiderFaceDetection, detection_collate +from .data_augment import * +from .config import * diff --git a/retinaface/data/config.py b/retinaface/data/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8d22bdfc205e07b7e88590526c72421d3f54d0ec --- /dev/null +++ b/retinaface/data/config.py @@ -0,0 +1,55 @@ +# config.py +cfg_mnet = { + 'name': 'mobilenet0.25', + 'min_sizes': [[10, 20], [32, 64], [128, 256]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 32, + 'ngpu': 1, + 'epoch': 250, + 'decay1': 190, + 'decay2': 220, + 'image_size': 300, + "net_inshape": (320, 320), # h, w + 'pretrain': False, + 'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3}, + 'in_channel': 32, + 'out_channel': 64 +} + +cfg_slim = { + 'name': 'slim', + 'min_sizes': [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]], + 'steps': [8, 16, 32, 64], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 32, + 'ngpu': 1, + 'epoch': 250, + 'decay1': 190, + 'decay2': 220, + 'image_size': 300 +} + +cfg_rfb = { + 'name': 'RFB', + 'min_sizes': [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]], + 'steps': [8, 16, 32, 64], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 32, + 'ngpu': 1, + 'epoch': 250, + 'decay1': 190, + 'decay2': 220, + 'image_size': 300 +} + + diff --git a/retinaface/data/data_augment.py b/retinaface/data/data_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..33ec89d8382b758d0ca26520d422c6bd730c4cb2 --- /dev/null +++ b/retinaface/data/data_augment.py @@ -0,0 +1,235 @@ +import cv2 +import numpy as np +import random +from utils.box_utils import matrix_iof + + +def _crop(image, boxes, labels, landm, img_dim): + height, width, _ = image.shape + pad_image_flag = True + + for _ in range(250): + if random.uniform(0, 1) <= 0.2: + scale = 1.0 + else: + scale = random.uniform(0.3, 1.0) + # PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0] + # scale = random.choice(PRE_SCALES) + short_side = min(width, height) + w = int(scale * short_side) + h = w + + if width == w: + l = 0 + else: + l = random.randrange(width - w) + if height == h: + t = 0 + else: + t = random.randrange(height - h) + roi = np.array((l, t, l + w, t + h)) + + value = matrix_iof(boxes, roi[np.newaxis]) + flag = (value >= 1) + if not flag.any(): + continue + + centers = (boxes[:, :2] + boxes[:, 2:]) / 2 + mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1) + boxes_t = boxes[mask_a].copy() + labels_t = labels[mask_a].copy() + landms_t = landm[mask_a].copy() + landms_t = landms_t.reshape([-1, 5, 2]) + + if boxes_t.shape[0] == 0: + continue + + image_t = image[roi[1]:roi[3], roi[0]:roi[2]] + + boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2]) + boxes_t[:, :2] -= roi[:2] + boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:]) + boxes_t[:, 2:] -= roi[:2] + + # landm + landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2] + landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0])) + landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2]) + landms_t = landms_t.reshape([-1, 10]) + + + # make sure that the cropped image contains at least one face > 16 pixel at training image scale + b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim + b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim + mask_b = np.minimum(b_w_t, b_h_t) > 5 + boxes_t = boxes_t[mask_b] + labels_t = labels_t[mask_b] + landms_t = landms_t[mask_b] + + if boxes_t.shape[0] == 0: + continue + + pad_image_flag = False + + return image_t, boxes_t, labels_t, landms_t, pad_image_flag + return image, boxes, labels, landm, pad_image_flag + + +def _distort(image): + + def _convert(image, alpha=1, beta=0): + tmp = image.astype(float) * alpha + beta + tmp[tmp < 0] = 0 + tmp[tmp > 255] = 255 + image[:] = tmp + + image = image.copy() + + if random.randrange(2): + + #brightness distortion + if random.randrange(2): + _convert(image, beta=random.uniform(-32, 32)) + + #contrast distortion + if random.randrange(2): + _convert(image, alpha=random.uniform(0.5, 1.5)) + + image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + + #saturation distortion + if random.randrange(2): + _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) + + #hue distortion + if random.randrange(2): + tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) + tmp %= 180 + image[:, :, 0] = tmp + + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) + + else: + + #brightness distortion + if random.randrange(2): + _convert(image, beta=random.uniform(-32, 32)) + + image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + + #saturation distortion + if random.randrange(2): + _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) + + #hue distortion + if random.randrange(2): + tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) + tmp %= 180 + image[:, :, 0] = tmp + + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) + + #contrast distortion + if random.randrange(2): + _convert(image, alpha=random.uniform(0.5, 1.5)) + + return image + + +def _expand(image, boxes, fill, p): + if random.randrange(2): + return image, boxes + + height, width, depth = image.shape + + scale = random.uniform(1, p) + w = int(scale * width) + h = int(scale * height) + + left = random.randint(0, w - width) + top = random.randint(0, h - height) + + boxes_t = boxes.copy() + boxes_t[:, :2] += (left, top) + boxes_t[:, 2:] += (left, top) + expand_image = np.empty( + (h, w, depth), + dtype=image.dtype) + expand_image[:, :] = fill + expand_image[top:top + height, left:left + width] = image + image = expand_image + + return image, boxes_t + + +def _mirror(image, boxes, landms): + _, width, _ = image.shape + if random.randrange(2): + image = image[:, ::-1] + boxes = boxes.copy() + boxes[:, 0::2] = width - boxes[:, 2::-2] + + # landm + landms = landms.copy() + landms = landms.reshape([-1, 5, 2]) + landms[:, :, 0] = width - landms[:, :, 0] + tmp = landms[:, 1, :].copy() + landms[:, 1, :] = landms[:, 0, :] + landms[:, 0, :] = tmp + tmp1 = landms[:, 4, :].copy() + landms[:, 4, :] = landms[:, 3, :] + landms[:, 3, :] = tmp1 + landms = landms.reshape([-1, 10]) + + return image, boxes, landms + + +def _pad_to_square(image, rgb_mean, pad_image_flag): + if not pad_image_flag: + return image + height, width, _ = image.shape + long_side = max(width, height) + image_t = np.empty((long_side, long_side, 3), dtype=image.dtype) + image_t[:, :] = rgb_mean + image_t[0:0 + height, 0:0 + width] = image + return image_t + + +def _resize_subtract_mean(image, insize, rgb_mean): + interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4] + interp_method = interp_methods[random.randrange(5)] + image = cv2.resize(image, (insize, insize), interpolation=interp_method) + image = image.astype(np.float32) + image -= rgb_mean + return image.transpose(2, 0, 1) + + +class preproc(object): + + def __init__(self, img_dim, rgb_means): + self.img_dim = img_dim + self.rgb_means = rgb_means + + def __call__(self, image, targets): + assert targets.shape[0] > 0, "this image does not have gt" + + boxes = targets[:, :4].copy() + labels = targets[:, -1].copy() + landm = targets[:, 4:-1].copy() + + image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim) + image_t = _distort(image_t) + image_t = _pad_to_square(image_t,self.rgb_means, pad_image_flag) + image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t) + height, width, _ = image_t.shape + image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means) + boxes_t[:, 0::2] /= width + boxes_t[:, 1::2] /= height + + landm_t[:, 0::2] /= width + landm_t[:, 1::2] /= height + + labels_t = np.expand_dims(labels_t, 1) + targets_t = np.hstack((boxes_t, landm_t, labels_t)) + + return image_t, targets_t diff --git a/retinaface/data/wider_face.py b/retinaface/data/wider_face.py new file mode 100644 index 0000000000000000000000000000000000000000..22f56efdc221bd4162d22884669ba44a3d4de5cd --- /dev/null +++ b/retinaface/data/wider_face.py @@ -0,0 +1,101 @@ +import os +import os.path +import sys +import torch +import torch.utils.data as data +import cv2 +import numpy as np + +class WiderFaceDetection(data.Dataset): + def __init__(self, txt_path, preproc=None): + self.preproc = preproc + self.imgs_path = [] + self.words = [] + f = open(txt_path,'r') + lines = f.readlines() + isFirst = True + labels = [] + for line in lines: + line = line.rstrip() + if line.startswith('#'): + if isFirst is True: + isFirst = False + else: + labels_copy = labels.copy() + self.words.append(labels_copy) + labels.clear() + path = line[2:] + path = txt_path.replace('label.txt','images/') + path + self.imgs_path.append(path) + else: + line = line.split(' ') + label = [float(x) for x in line] + labels.append(label) + + self.words.append(labels) + + def __len__(self): + return len(self.imgs_path) + + def __getitem__(self, index): + img = cv2.imread(self.imgs_path[index]) + height, width, _ = img.shape + + labels = self.words[index] + annotations = np.zeros((0, 15)) + if len(labels) == 0: + return annotations + for idx, label in enumerate(labels): + annotation = np.zeros((1, 15)) + # bbox + annotation[0, 0] = label[0] # x1 + annotation[0, 1] = label[1] # y1 + annotation[0, 2] = label[0] + label[2] # x2 + annotation[0, 3] = label[1] + label[3] # y2 + + # landmarks + annotation[0, 4] = label[4] # l0_x + annotation[0, 5] = label[5] # l0_y + annotation[0, 6] = label[7] # l1_x + annotation[0, 7] = label[8] # l1_y + annotation[0, 8] = label[10] # l2_x + annotation[0, 9] = label[11] # l2_y + annotation[0, 10] = label[13] # l3_x + annotation[0, 11] = label[14] # l3_y + annotation[0, 12] = label[16] # l4_x + annotation[0, 13] = label[17] # l4_y + if (annotation[0, 4]<0): + annotation[0, 14] = -1 + else: + annotation[0, 14] = 1 + + annotations = np.append(annotations, annotation, axis=0) + target = np.array(annotations) + if self.preproc is not None: + img, target = self.preproc(img, target) + + return torch.from_numpy(img), target + +def detection_collate(batch): + """Custom collate fn for dealing with batches of images that have a different + number of associated object annotations (bounding boxes). + + Arguments: + batch: (tuple) A tuple of tensor images and lists of annotations + + Return: + A tuple containing: + 1) (tensor) batch of images stacked on their 0 dim + 2) (list of tensors) annotations for a given image are stacked on 0 dim + """ + targets = [] + imgs = [] + for _, sample in enumerate(batch): + for _, tup in enumerate(sample): + if torch.is_tensor(tup): + imgs.append(tup) + elif isinstance(tup, type(np.empty(0))): + annos = torch.from_numpy(tup).float() + targets.append(annos) + + return (torch.stack(imgs, 0), targets) diff --git a/retinaface/detect.py b/retinaface/detect.py new file mode 100644 index 0000000000000000000000000000000000000000..dc622e4deef728d1153254bfa6cf0d4ef50352aa --- /dev/null +++ b/retinaface/detect.py @@ -0,0 +1,152 @@ +import torch +import torch.nn.functional as F +import cv2 +import numpy as np +import timeit + +import imutils +from utils.infer_utils import load_model +from data import cfg_mnet as cfg +from models.retinaface import RetinaFace +from layers.functions.prior_box import PriorBox +from utils.box_utils import decode, decode_landm +from utils.nms.py_cpu_nms import py_cpu_nms +from utils.infer_utils import align_face +torch.set_grad_enabled(False) + + +class Detect: + def __init__(self, weight_path, net_inshape=(180, 320)): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.net_inshape = net_inshape + im_height, im_width = net_inshape + self.box_scale = np.array([im_width, im_height] * 2) + self.lmk_scale = np.array([im_width, im_height] * 5) + + priorbox = PriorBox(cfg, image_size=net_inshape) + priors = priorbox.forward() + self.prior_data = priors.to(self.device) + self.net = RetinaFace(cfg=cfg, phase='test') + self.net = load_model(self.net, weight_path, False) + self.net.eval() + self.net = self.net.to(self.device) + + + def _preprocess(self, image): + rgb_mean = (104, 117, 123) # bgr order + h, w = image.shape[:2] + dx = int(self.net_inshape[1] * h / self.net_inshape[0] - w) + dy = 0 + if dx < 0: + dx = 0 + dy = int(self.net_inshape[0] * w / self.net_inshape[1] - h) + img = cv2.copyMakeBorder(image, 0, dy, 0, dx, borderType=cv2.BORDER_CONSTANT, value=rgb_mean) + img = cv2.copyMakeBorder(img, 0, img.shape[0], 0, img.shape[1], borderType=cv2.BORDER_CONSTANT, value=rgb_mean) + + h, w = img.shape[:2] + resize = float(self.net_inshape[1]) / float(w) + img = cv2.resize(img, self.net_inshape[::-1]) + img = np.float32(img) + img -= rgb_mean + + return img, resize + + + def __call__(self, img, verbose=False): + ''' + bgr image + ''' + t0 = timeit.default_timer() + img, resize = self._preprocess(img) + img = img.transpose(2, 0, 1) + img = torch.from_numpy(img).unsqueeze(0) + img = img.to(self.device) + + t1 = timeit.default_timer() + loc, conf, landms = self.net(img) + loc = [i.permute(0,2,3,1).contiguous().view(i.shape[0], -1, 4) for i in loc] + loc = torch.cat(loc, dim=1) + conf = [i.permute(0,2,3,1).contiguous().view(i.shape[0], -1, 3) for i in conf] + conf = torch.cat(conf, dim=1) + landms = [i.permute(0,2,3,1).contiguous().view(i.shape[0], -1, 10) for i in landms] + landms = torch.cat(landms, dim=1) + conf = F.softmax(conf, dim=-1) + + t2 = timeit.default_timer() + conf = conf[0] + scores = conf.squeeze(0).detach().cpu().numpy()[:, 1:] + scores = np.amax(scores, axis=1) + + boxes = decode(loc[0], self.prior_data, cfg['variance']) # loc[0] + boxes = boxes.detach().cpu().numpy() + boxes = boxes * self.box_scale / resize + + landms = decode_landm(landms[0], self.prior_data, cfg['variance']) + landms = landms.detach().cpu().numpy() + landms = landms * self.lmk_scale / resize + + # ignore low scores + inds = np.where(scores > 0.02)[0] + boxes = boxes[inds] + landms = landms[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1][:5000] + boxes = boxes[order] + landms = landms[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(dets, 0.4) + # keep = nms(dets, args.nms_threshold,force_cpu=args.cpu) + dets = dets[keep, :] + landms = landms[keep] + + # keep top-K faster NMS + dets = dets[:750, :] + landms = landms[:750, :] + + dets = np.concatenate((dets, landms), axis=1) + dets = dets[dets[:, 4] > 0.5] + dets = dets[np.argsort(dets, axis=0)[:, 0]] + + t3 = timeit.default_timer() + if verbose: + print(t1 - t0, t2 - t1, t3 - t2) + + return dets # (n, 15), box=0-3, cls=4, lmk=5-10 + + +if __name__ == "__main__": + net_inshape = (486, 864) # h, w + model = Detect("/mnt/nvme0n1p2/ExternalHardrive/research/object_detection/face/Face-Detector-1MB-with-landmark-clear/weights/mobilenet0.25_epoch_842.pth", net_inshape=net_inshape) + image_path = "/mnt/nvme0n1p2/datasets/face/dyno/mytelpay230626/mytelpay230626_raw/data_2nd/၁၀ကတန(နိုင်)၀၀၁၀၀၁/mytel_ekyc_1m2_65160cc1802b6183d87fca091cab4c2faa93a9b1614106b5911ca778_front_image.jpg" + img = cv2.imread(image_path) + + dets = model(img) + for i, b in enumerate(dets): + text = "{:.4f}".format(b[4]) + b = b.astype(np.int32) + landm = b[5:15] + landm = landm.reshape((5, 2)) + + alighed_face = align_face(img, landm.copy()) + # cv2.imshow(str(i), alighed_face) + + # landms + landm = landm.astype(np.int32) + cv2.circle(img, tuple(landm[0]), 1, (0, 0, 255), 2) + cv2.circle(img, tuple(landm[1]), 1, (0, 255, 255), 2) + cv2.circle(img, tuple(landm[2]), 1, (255, 0, 255), 2) + cv2.circle(img, tuple(landm[3]), 1, (0, 255, 0), 2) + cv2.circle(img, tuple(landm[4]), 1, (255, 0, 0), 2) + + cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2) + cx = b[0] + cy = b[1] + 20 + cv2.putText(img, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 1.3, (255, 255, 255)) + + cv2.imwrite("./output.jpg", img) + \ No newline at end of file diff --git a/retinaface/detect_video_raw.py b/retinaface/detect_video_raw.py new file mode 100644 index 0000000000000000000000000000000000000000..231a7932b639655d2b38ed3f8f814e249d993c07 --- /dev/null +++ b/retinaface/detect_video_raw.py @@ -0,0 +1,66 @@ +import cv2 +import numpy as np +import imutils + +from utils.fps import FPS +from utils.infer_utils import LoadStream, align_face +from detect import Detect + + +net_inshape = (486, 864) # h, w +model = Detect("/mnt/nvme0n1p2/ExternalHardrive/research/object_detection/face/Face-Detector-1MB-with-landmark-clear/weights/mobilenet0.25_epoch_842.pth", net_inshape=net_inshape) +# dataloader = LoadStream("rtsp://admin:meditech123@192.168.100.90:555/") +dataloader = LoadStream("../30Shine_1.mp4") +fps = FPS().start() + +for frame in dataloader: + # frame = imutils.resize(frame, width=640) + frame = frame.copy() + frame_raw = frame.copy() + + + dets = model(frame) + for i, b in enumerate(dets): + text = "{:.4f}".format(b[4]) + b = b.astype(np.int32) + landm = b[5:15] + landm = landm.reshape((5, 2)) + + alighed_face = align_face(frame, landm.copy()) + # cv2.imshow(str(i), alighed_face) + + # landms + landm = landm.astype(np.int32) + cv2.circle(frame, tuple(landm[0]), 1, (0, 0, 255), 2) + cv2.circle(frame, tuple(landm[1]), 1, (0, 255, 255), 2) + cv2.circle(frame, tuple(landm[2]), 1, (255, 0, 255), 2) + cv2.circle(frame, tuple(landm[3]), 1, (0, 255, 0), 2) + cv2.circle(frame, tuple(landm[4]), 1, (255, 0, 0), 2) + + cv2.rectangle(frame, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2) + cx = b[0] + cy = b[1] + 20 + cv2.putText(frame, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 1.3, (255, 255, 255)) + + fps.update() + text_fps = "FPS: {:.3f}".format(fps.get_fps_n()) + cv2.putText(frame, text_fps, (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2) + cv2.imshow("frame", imutils.resize(frame, width=1700)) + key = cv2.waitKey(1) & 0xff + if key == ord("q"): + break + elif key == ord("c"): + while True: + cv2.imshow("frame", imutils.resize(frame, width=1700)) + key = cv2.waitKey(1) & 0xff + if key == ord("q"): + break + # cv2.imwrite(f"{i}.jpg", alighed_face) + # i += 1 + # # break + +print(text_fps) +cv2.destroyAllWindows() +fps.stop() +print("Total FPS: {}".format(fps.fps())) +dataloader.close() \ No newline at end of file diff --git a/retinaface/layers/__init__.py b/retinaface/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53a3f4b5160995d93bc7911e808b3045d74362c9 --- /dev/null +++ b/retinaface/layers/__init__.py @@ -0,0 +1,2 @@ +from .functions import * +from .modules import * diff --git a/retinaface/layers/functions/prior_box.py b/retinaface/layers/functions/prior_box.py new file mode 100644 index 0000000000000000000000000000000000000000..683156a298d6f9440be64c611988b04aea97ec49 --- /dev/null +++ b/retinaface/layers/functions/prior_box.py @@ -0,0 +1,33 @@ +import torch +from itertools import product as product +import numpy as np +from math import ceil + + +class PriorBox(object): + def __init__(self, cfg, image_size=None, phase='train'): + super(PriorBox, self).__init__() + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps] + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] + dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output diff --git a/retinaface/layers/modules/__init__.py b/retinaface/layers/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf24bddbf283f233d0b93fc074a2bac2f5c044a9 --- /dev/null +++ b/retinaface/layers/modules/__init__.py @@ -0,0 +1,3 @@ +from .multibox_loss import MultiBoxLoss + +__all__ = ['MultiBoxLoss'] diff --git a/retinaface/layers/modules/multibox_loss.py b/retinaface/layers/modules/multibox_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..096620480eba59e9d893c1940899f7e3d6736cae --- /dev/null +++ b/retinaface/layers/modules/multibox_loss.py @@ -0,0 +1,125 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from utils.box_utils import match, log_sum_exp +from data import cfg_mnet +GPU = cfg_mnet['gpu_train'] + +class MultiBoxLoss(nn.Module): + """SSD Weighted Loss Function + Compute Targets: + 1) Produce Confidence Target Indices by matching ground truth boxes + with (default) 'priorboxes' that have jaccard index > threshold parameter + (default threshold: 0.5). + 2) Produce localization target by 'encoding' variance into offsets of ground + truth boxes and their matched 'priorboxes'. + 3) Hard negative mining to filter the excessive number of negative examples + that comes with using a large number of default bounding boxes. + (default negative:positive ratio 3:1) + Objective Loss: + L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N + Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss + weighted by α which is set to 1 by cross val. + Args: + c: class confidences, + l: predicted boxes, + g: ground truth boxes + N: number of matched default boxes + See: https://arxiv.org/pdf/1512.02325.pdf for more details. + """ + + def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target): + super(MultiBoxLoss, self).__init__() + self.num_classes = num_classes + self.threshold = overlap_thresh + self.background_label = bkg_label + self.encode_target = encode_target + self.use_prior_for_matching = prior_for_matching + self.do_neg_mining = neg_mining + self.negpos_ratio = neg_pos + self.neg_overlap = neg_overlap + self.variance = [0.1, 0.2] + + def forward(self, predictions, priors, targets): + """Multibox Loss + Args: + predictions (tuple): A tuple containing loc preds, conf preds, + and prior boxes from SSD net. + conf shape: torch.size(batch_size,num_priors,num_classes) + loc shape: torch.size(batch_size,num_priors,4) + priors shape: torch.size(num_priors,4) + + ground_truth (tensor): Ground truth boxes and labels for a batch, + shape: [batch_size,num_objs,5] (last idx is the label). + """ + + loc_data, conf_data, landm_data = predictions + priors = priors + num = loc_data.size(0) + num_priors = (priors.size(0)) + + # match priors (default boxes) and ground truth boxes + loc_t = torch.Tensor(num, num_priors, 4) + landm_t = torch.Tensor(num, num_priors, 10) + conf_t = torch.LongTensor(num, num_priors) + for idx in range(num): + truths = targets[idx][:, :4].data + labels = targets[idx][:, -1].data + landms = targets[idx][:, 4:14].data + defaults = priors.data + match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx) + if GPU: + loc_t = loc_t.cuda() + conf_t = conf_t.cuda() + landm_t = landm_t.cuda() + + zeros = torch.tensor(0).cuda() + # landm Loss (Smooth L1) + # Shape: [batch,num_priors,10] + pos1 = conf_t > zeros + num_pos_landm = pos1.long().sum(1, keepdim=True) + N1 = max(num_pos_landm.data.sum().float(), 1) + pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data) + landm_p = landm_data[pos_idx1].view(-1, 10) + landm_t = landm_t[pos_idx1].view(-1, 10) + loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum') + + + pos = conf_t != zeros + conf_t[pos] = 1 + + # Localization Loss (Smooth L1) + # Shape: [batch,num_priors,4] + pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) + loc_p = loc_data[pos_idx].view(-1, 4) + loc_t = loc_t[pos_idx].view(-1, 4) + loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum') + + # Compute max conf across batch for hard negative mining + batch_conf = conf_data.view(-1, self.num_classes) + loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) + + # Hard Negative Mining + loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now + loss_c = loss_c.view(num, -1) + _, loss_idx = loss_c.sort(1, descending=True) + _, idx_rank = loss_idx.sort(1) + num_pos = pos.long().sum(1, keepdim=True) + num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) + neg = idx_rank < num_neg.expand_as(idx_rank) + + # Confidence Loss Including Positive and Negative Examples + pos_idx = pos.unsqueeze(2).expand_as(conf_data) + neg_idx = neg.unsqueeze(2).expand_as(conf_data) + conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes) + targets_weighted = conf_t[(pos+neg).gt(0)] + loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum') + + # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N + N = max(num_pos.data.sum().float(), 1) + loss_l /= N + loss_c /= N + loss_landm /= N1 + + return loss_l, loss_c, loss_landm diff --git a/retinaface/models/__init__.py b/retinaface/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/retinaface/models/net.py b/retinaface/models/net.py new file mode 100644 index 0000000000000000000000000000000000000000..7d769cc2f0a3d3e85171d79ae3ea9396606b53d9 --- /dev/null +++ b/retinaface/models/net.py @@ -0,0 +1,137 @@ +import time +import torch +import torch.nn as nn +import torchvision.models._utils as _utils +import torchvision.models as models +import torch.nn.functional as F +from torch.autograd import Variable + +def conv_bn(inp, oup, stride = 1): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) + +def conv_bn_no_relu(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + +def conv_bn1X1(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) + +def conv_dw(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.ReLU(inplace=True), + + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) + +class SSH(nn.Module): + def __init__(self, in_channel, out_channel): + super(SSH, self).__init__() + assert out_channel % 4 == 0 + leaky = 0 + if (out_channel <= 64): + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1) + + self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1) + self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1) + + self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1) + self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1) + + def forward(self, input): + conv3X3 = self.conv3X3(input) + + conv5X5_1 = self.conv5X5_1(input) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + out = F.relu(out) + return out + +class FPN(nn.Module): + def __init__(self,in_channels_list,out_channels): + super(FPN,self).__init__() + leaky = 0 + if (out_channels <= 64): + leaky = 0.1 + self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1) + self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1) + self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1) + + self.merge1 = conv_bn(out_channels, out_channels) + self.merge2 = conv_bn(out_channels, out_channels) + + def forward(self, input): + # names = list(input.keys()) + input = list(input.values()) + + output1 = self.output1(input[0]) + output2 = self.output2(input[1]) + output3 = self.output3(input[2]) + + up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest") + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest") + output1 = output1 + up2 + output1 = self.merge1(output1) + + out = [output1, output2, output3] + return out + + + +class MobileNetV1(nn.Module): + def __init__(self): + super(MobileNetV1, self).__init__() + self.stage1 = nn.Sequential( + conv_bn(3, 8, 2), # 3 + conv_dw(8, 16, 1), # 7 + conv_dw(16, 32, 2), # 11 + conv_dw(32, 32, 1), # 19 + conv_dw(32, 64, 2), # 27 + conv_dw(64, 64, 1), # 43 + ) + self.stage2 = nn.Sequential( + conv_dw(64, 128, 2), # 43 + 16 = 59 + conv_dw(128, 128, 1), # 59 + 32 = 91 + conv_dw(128, 128, 1), # 91 + 32 = 123 + conv_dw(128, 128, 1), # 123 + 32 = 155 + conv_dw(128, 128, 1), # 155 + 32 = 187 + conv_dw(128, 128, 1), # 187 + 32 = 219 + ) + self.stage3 = nn.Sequential( + conv_dw(128, 256, 2), # 219 +3 2 = 241 + conv_dw(256, 256, 1), # 241 + 64 = 301 + ) + self.avg = nn.AdaptiveAvgPool2d((1,1)) + self.fc = nn.Linear(256, 1000) + + def forward(self, x): + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.avg(x) + # x = self.model(x) + x = x.view(-1, 256) + x = self.fc(x) + return x + diff --git a/retinaface/models/net_rfb.py b/retinaface/models/net_rfb.py new file mode 100644 index 0000000000000000000000000000000000000000..0b79e959c1133a123e4a343b27ce406c859ebf4b --- /dev/null +++ b/retinaface/models/net_rfb.py @@ -0,0 +1,199 @@ +import torch +import torch.nn as nn +import torchvision.models.detection.backbone_utils as backbone_utils +import torchvision.models._utils as _utils +import torch.nn.functional as F +from collections import OrderedDict + +class BasicConv(nn.Module): + + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True): + super(BasicConv, self).__init__() + self.out_channels = out_planes + if bn: + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False) + self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) + self.relu = nn.ReLU(inplace=True) if relu else None + else: + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True) + self.bn = None + self.relu = nn.ReLU(inplace=True) if relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + + +class BasicRFB(nn.Module): + + def __init__(self, in_planes, out_planes, stride=1, scale=0.1, map_reduce=8, vision=1, groups=1): + super(BasicRFB, self).__init__() + self.scale = scale + self.out_channels = out_planes + inter_planes = in_planes // map_reduce + + self.branch0 = nn.Sequential( + BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False), + BasicConv(inter_planes, 2 * inter_planes, kernel_size=(3, 3), stride=stride, padding=(1, 1), groups=groups), + BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 1, dilation=vision + 1, relu=False, groups=groups) + ) + self.branch1 = nn.Sequential( + BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False), + BasicConv(inter_planes, 2 * inter_planes, kernel_size=(3, 3), stride=stride, padding=(1, 1), groups=groups), + BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 2, dilation=vision + 2, relu=False, groups=groups) + ) + self.branch2 = nn.Sequential( + BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False), + BasicConv(inter_planes, (inter_planes // 2) * 3, kernel_size=3, stride=1, padding=1, groups=groups), + BasicConv((inter_planes // 2) * 3, 2 * inter_planes, kernel_size=3, stride=stride, padding=1, groups=groups), + BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 4, dilation=vision + 4, relu=False, groups=groups) + ) + + self.ConvLinear = BasicConv(6 * inter_planes, out_planes, kernel_size=1, stride=1, relu=False) + self.shortcut = BasicConv(in_planes, out_planes, kernel_size=1, stride=stride, relu=False) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + + out = torch.cat((x0, x1, x2), 1) + out = self.ConvLinear(out) + short = self.shortcut(x) + out = out * self.scale + short + out = self.relu(out) + + return out + + + +def conv_bn(inp, oup, stride = 1): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) + +def depth_conv2d(inp, oup, kernel=1, stride=1, pad=0): + return nn.Sequential( + nn.Conv2d(inp, inp, kernel_size = kernel, stride = stride, padding=pad, groups=inp), + nn.ReLU(inplace=True), + nn.Conv2d(inp, oup, kernel_size=1) + ) + +def conv_dw(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.ReLU(inplace=True), + + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) + +class RFB(nn.Module): + def __init__(self, cfg = None, phase = 'train'): + """ + :param cfg: Network related settings. + :param phase: train or test. + """ + super(RFB,self).__init__() + self.phase = phase + self.num_classes = 2 + + self.conv1 = conv_bn(3, 16, 2) + self.conv2 = conv_dw(16, 32, 1) + self.conv3 = conv_dw(32, 32, 2) + self.conv4 = conv_dw(32, 32, 1) + self.conv5 = conv_dw(32, 64, 2) + self.conv6 = conv_dw(64, 64, 1) + self.conv7 = conv_dw(64, 64, 1) + self.conv8 = BasicRFB(64, 64, stride=1, scale=1.0) + + self.conv9 = conv_dw(64, 128, 2) + self.conv10 = conv_dw(128, 128, 1) + self.conv11 = conv_dw(128, 128, 1) + + self.conv12 = conv_dw(128, 256, 2) + self.conv13 = conv_dw(256, 256, 1) + + self.conv14 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=64, kernel_size=1), + nn.ReLU(inplace=True), + depth_conv2d(64, 256, kernel=3, stride=2, pad=1), + nn.ReLU(inplace=True) + ) + self.loc, self.conf, self.landm = self.multibox(self.num_classes); + + def multibox(self, num_classes): + loc_layers = [] + conf_layers = [] + landm_layers = [] + loc_layers += [depth_conv2d(64, 3 * 4, kernel=3, pad=1)] + conf_layers += [depth_conv2d(64, 3 * num_classes, kernel=3, pad=1)] + landm_layers += [depth_conv2d(64, 3 * 10, kernel=3, pad=1)] + + loc_layers += [depth_conv2d(128, 2 * 4, kernel=3, pad=1)] + conf_layers += [depth_conv2d(128, 2 * num_classes, kernel=3, pad=1)] + landm_layers += [depth_conv2d(128, 2 * 10, kernel=3, pad=1)] + + loc_layers += [depth_conv2d(256, 2 * 4, kernel=3, pad=1)] + conf_layers += [depth_conv2d(256, 2 * num_classes, kernel=3, pad=1)] + landm_layers += [depth_conv2d(256, 2 * 10, kernel=3, pad=1)] + + loc_layers += [nn.Conv2d(256, 3 * 4, kernel_size=3, padding=1)] + conf_layers += [nn.Conv2d(256, 3 * num_classes, kernel_size=3, padding=1)] + landm_layers += [nn.Conv2d(256, 3 * 10, kernel_size=3, padding=1)] + return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers), nn.Sequential(*landm_layers) + + + def forward(self,inputs): + detections = list() + loc = list() + conf = list() + landm = list() + + x1 = self.conv1(inputs) + x2 = self.conv2(x1) + x3 = self.conv3(x2) + x4 = self.conv4(x3) + x5 = self.conv5(x4) + x6 = self.conv6(x5) + x7 = self.conv7(x6) + x8 = self.conv8(x7) + detections.append(x8) + + x9 = self.conv9(x8) + x10 = self.conv10(x9) + x11 = self.conv11(x10) + detections.append(x11) + + x12 = self.conv12(x11) + x13 = self.conv13(x12) + detections.append(x13) + + x14= self.conv14(x13) + detections.append(x14) + + for (x, l, c, lam) in zip(detections, self.loc, self.conf, self.landm): + loc.append(l(x).permute(0, 2, 3, 1).contiguous()) + conf.append(c(x).permute(0, 2, 3, 1).contiguous()) + landm.append(lam(x).permute(0, 2, 3, 1).contiguous()) + + bbox_regressions = torch.cat([o.view(o.size(0), -1, 4) for o in loc], 1) + classifications = torch.cat([o.view(o.size(0), -1, 2) for o in conf], 1) + ldm_regressions = torch.cat([o.view(o.size(0), -1, 10) for o in landm], 1) + + + + if self.phase == 'train': + output = (bbox_regressions, classifications, ldm_regressions) + else: + output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) + return output diff --git a/retinaface/models/net_slim.py b/retinaface/models/net_slim.py new file mode 100644 index 0000000000000000000000000000000000000000..e74910e269df2519001c95d9e99ad577c390e08d --- /dev/null +++ b/retinaface/models/net_slim.py @@ -0,0 +1,132 @@ +import torch +import torch.nn as nn +import torchvision.models.detection.backbone_utils as backbone_utils +import torchvision.models._utils as _utils +import torch.nn.functional as F +from collections import OrderedDict + +def conv_bn(inp, oup, stride = 1): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) + +def depth_conv2d(inp, oup, kernel=1, stride=1, pad=0): + return nn.Sequential( + nn.Conv2d(inp, inp, kernel_size = kernel, stride = stride, padding=pad, groups=inp), + nn.ReLU(inplace=True), + nn.Conv2d(inp, oup, kernel_size=1) + ) + +def conv_dw(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.ReLU(inplace=True), + + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) + +class Slim(nn.Module): + def __init__(self, cfg = None, phase = 'train'): + """ + :param cfg: Network related settings. + :param phase: train or test. + """ + super(Slim, self).__init__() + self.phase = phase + self.num_classes = 2 + + self.conv1 = conv_bn(3, 16, 2) + self.conv2 = conv_dw(16, 32, 1) + self.conv3 = conv_dw(32, 32, 2) + self.conv4 = conv_dw(32, 32, 1) + self.conv5 = conv_dw(32, 64, 2) + self.conv6 = conv_dw(64, 64, 1) + self.conv7 = conv_dw(64, 64, 1) + self.conv8 = conv_dw(64, 64, 1) + + self.conv9 = conv_dw(64, 128, 2) + self.conv10 = conv_dw(128, 128, 1) + self.conv11 = conv_dw(128, 128, 1) + + self.conv12 = conv_dw(128, 256, 2) + self.conv13 = conv_dw(256, 256, 1) + + self.conv14 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=64, kernel_size=1), + nn.ReLU(inplace=True), + depth_conv2d(64, 256, kernel=3, stride=2, pad=1), + nn.ReLU(inplace=True) + ) + self.loc, self.conf, self.landm = self.multibox(self.num_classes); + + def multibox(self, num_classes): + loc_layers = [] + conf_layers = [] + landm_layers = [] + loc_layers += [depth_conv2d(64, 3 * 4, kernel=3, pad=1)] + conf_layers += [depth_conv2d(64, 3 * num_classes, kernel=3, pad=1)] + landm_layers += [depth_conv2d(64, 3 * 10, kernel=3, pad=1)] + + loc_layers += [depth_conv2d(128, 2 * 4, kernel=3, pad=1)] + conf_layers += [depth_conv2d(128, 2 * num_classes, kernel=3, pad=1)] + landm_layers += [depth_conv2d(128, 2 * 10, kernel=3, pad=1)] + + loc_layers += [depth_conv2d(256, 2 * 4, kernel=3, pad=1)] + conf_layers += [depth_conv2d(256, 2 * num_classes, kernel=3, pad=1)] + landm_layers += [depth_conv2d(256, 2 * 10, kernel=3, pad=1)] + + loc_layers += [nn.Conv2d(256, 3 * 4, kernel_size=3, padding=1)] + conf_layers += [nn.Conv2d(256, 3 * num_classes, kernel_size=3, padding=1)] + landm_layers += [nn.Conv2d(256, 3 * 10, kernel_size=3, padding=1)] + return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers), nn.Sequential(*landm_layers) + + + def forward(self,inputs): + detections = list() + loc = list() + conf = list() + landm = list() + + x1 = self.conv1(inputs) + x2 = self.conv2(x1) + x3 = self.conv3(x2) + x4 = self.conv4(x3) + x5 = self.conv5(x4) + x6 = self.conv6(x5) + x7 = self.conv7(x6) + x8 = self.conv8(x7) + detections.append(x8) + + x9 = self.conv9(x8) + x10 = self.conv10(x9) + x11 = self.conv11(x10) + detections.append(x11) + + x12 = self.conv12(x11) + x13 = self.conv13(x12) + detections.append(x13) + + x14= self.conv14(x13) + detections.append(x14) + + for (x, l, c, lam) in zip(detections, self.loc, self.conf, self.landm): + loc.append(l(x).permute(0, 2, 3, 1).contiguous()) + conf.append(c(x).permute(0, 2, 3, 1).contiguous()) + landm.append(lam(x).permute(0, 2, 3, 1).contiguous()) + + bbox_regressions = torch.cat([o.view(o.size(0), -1, 4) for o in loc], 1) + classifications = torch.cat([o.view(o.size(0), -1, 2) for o in conf], 1) + ldm_regressions = torch.cat([o.view(o.size(0), -1, 10) for o in landm], 1) + + + + if self.phase == 'train': + output = (bbox_regressions, classifications, ldm_regressions) + else: + output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) + return output diff --git a/retinaface/models/retinaface.py b/retinaface/models/retinaface.py new file mode 100644 index 0000000000000000000000000000000000000000..e229befc93d4552eb7c443ff6232b657b9ab325e --- /dev/null +++ b/retinaface/models/retinaface.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torchvision.models.detection.backbone_utils as backbone_utils +import torchvision.models._utils as _utils +import torch.nn.functional as F +from collections import OrderedDict + +from models.net import MobileNetV1 as MobileNetV1 +from models.net import FPN as FPN +from models.net import SSH as SSH + + + +class ClassHead(nn.Module): + def __init__(self,inchannels=512,num_anchors=3): + super(ClassHead,self).__init__() + self.num_anchors = num_anchors + self.conv1x1 = nn.Conv2d(inchannels,self.num_anchors*3,kernel_size=(1,1),stride=1,padding=0) + + def forward(self,x): + out = self.conv1x1(x) + return out + # out = out.permute(0,2,3,1).contiguous() + # return out.view(out.shape[0], -1, 3) + +class BboxHead(nn.Module): + def __init__(self,inchannels=512,num_anchors=3): + super(BboxHead,self).__init__() + self.conv1x1 = nn.Conv2d(inchannels,num_anchors*4,kernel_size=(1,1),stride=1,padding=0) + + def forward(self,x): + out = self.conv1x1(x) + return out + # out = out.permute(0,2,3,1).contiguous() + # return out.view(out.shape[0], -1, 4) + +class LandmarkHead(nn.Module): + def __init__(self,inchannels=512,num_anchors=3): + super(LandmarkHead,self).__init__() + self.conv1x1 = nn.Conv2d(inchannels,num_anchors*10,kernel_size=(1,1),stride=1,padding=0) + + def forward(self,x): + out = self.conv1x1(x) + return out + # out = out.permute(0,2,3,1).contiguous() + # return out.view(out.shape[0], -1, 10) + +class RetinaFace(nn.Module): + def __init__(self, cfg = None, phase = 'train'): + """ + :param cfg: Network related settings. + :param phase: train or test. + """ + super(RetinaFace,self).__init__() + self.phase = phase + backbone = None + if cfg['name'] == 'mobilenet0.25': + backbone = MobileNetV1() + if cfg['pretrain']: + checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu')) + from collections import OrderedDict + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + name = k[7:] # remove module. + new_state_dict[name] = v + # load params + backbone.load_state_dict(new_state_dict) + elif cfg['name'] == 'Resnet50': + import torchvision.models as models + backbone = models.resnet50(pretrained=cfg['pretrain']) + + self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers']) + in_channels_stage2 = cfg['in_channel'] + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + out_channels = cfg['out_channel'] + self.fpn = FPN(in_channels_list,out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel']) + self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) + self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) + + def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2): + classhead = nn.ModuleList() + for i in range(fpn_num): + classhead.append(ClassHead(inchannels,anchor_num)) + return classhead + + def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2): + bboxhead = nn.ModuleList() + for i in range(fpn_num): + bboxhead.append(BboxHead(inchannels,anchor_num)) + return bboxhead + + def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2): + landmarkhead = nn.ModuleList() + for i in range(fpn_num): + landmarkhead.append(LandmarkHead(inchannels,anchor_num)) + return landmarkhead + + def forward(self,inputs): + out = self.body(inputs) + + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + # return features + + bbox_regressions = [self.BboxHead[i](feature) for i, feature in enumerate(features)] + classifications = [self.ClassHead[i](feature) for i, feature in enumerate(features)] + ldm_regressions = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)] + + if self.phase == 'train': + output = (bbox_regressions, classifications, ldm_regressions) + else: + # bbox_regressions = decode_batch(bbox_regressions, self.prior_data, self.variance) + # ldm_regressions = decode_landm_batch(ldm_regressions, self.prior_data, self.variance) + output = (bbox_regressions, classifications, ldm_regressions) + return output \ No newline at end of file diff --git a/retinaface/output.jpg b/retinaface/output.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e3638f1d15adc20cea13ed8393c7db5418706252 Binary files /dev/null and b/retinaface/output.jpg differ diff --git a/retinaface/utils/__init__.py b/retinaface/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/retinaface/utils/box_utils.py b/retinaface/utils/box_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56562f8dde8da2f39600f4fc8b646cf3b4991f27 --- /dev/null +++ b/retinaface/utils/box_utils.py @@ -0,0 +1,379 @@ +import torch +import numpy as np + + +def point_form(boxes): + """ Convert prior_boxes to (xmin, ymin, xmax, ymax) + representation for comparison to point form ground truth data. + Args: + boxes: (tensor) center-size default boxes from priorbox layers. + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin + boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax + + +def center_size(boxes): + """ Convert prior_boxes to (cx, cy, w, h) + representation for comparison to center-size form ground truth data. + Args: + boxes: (tensor) point_form boxes + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy + boxes[:, 2:] - boxes[:, :2], 1) # w, h + + +def intersect(box_a, box_b): + """ We resize both tensors to [A,B,2] without new malloc: + [A,2] -> [A,1,2] -> [A,B,2] + [B,2] -> [1,B,2] -> [A,B,2] + Then we compute the area of intersect between box_a and box_b. + Args: + box_a: (tensor) bounding boxes, Shape: [A,4]. + box_b: (tensor) bounding boxes, Shape: [B,4]. + Return: + (tensor) intersection area, Shape: [A,B]. + """ + A = box_a.size(0) + B = box_b.size(0) + max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), + box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) + min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), + box_b[:, :2].unsqueeze(0).expand(A, B, 2)) + inter = torch.clamp((max_xy - min_xy), min=0) + return inter[:, :, 0] * inter[:, :, 1] + + +def jaccard(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap + is simply the intersection over union of two boxes. Here we operate on + ground truth boxes and default boxes. + E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] + box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] + Return: + jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] + """ + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2]-box_a[:, 0]) * + (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] + area_b = ((box_b[:, 2]-box_b[:, 0]) * + (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] + + +def matrix_iou(a, b): + """ + return iou of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + return area_i / (area_a[:, np.newaxis] + area_b - area_i) + + +def matrix_iof(a, b): + """ + return iof of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + return area_i / np.maximum(area_a[:, np.newaxis], 1) + + +def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx): + """Match each prior box with the ground truth box of the highest jaccard + overlap, encode the bounding boxes, then return the matched indices + corresponding to both confidence and location preds. + Args: + threshold: (float) The overlap threshold used when mathing boxes. + truths: (tensor) Ground truth boxes, Shape: [num_obj, 4]. + priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. + variances: (tensor) Variances corresponding to each prior coord, + Shape: [num_priors, 4]. + labels: (tensor) All the class labels for the image, Shape: [num_obj]. + landms: (tensor) Ground truth landms, Shape [num_obj, 10]. + loc_t: (tensor) Tensor to be filled w/ endcoded location targets. + conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. + landm_t: (tensor) Tensor to be filled w/ endcoded landm targets. + idx: (int) current batch index + Return: + The matched indices corresponding to 1)location 2)confidence 3)landm preds. + """ + # jaccard index + overlaps = jaccard( + truths, + point_form(priors) + ) + # (Bipartite Matching) + # [1,num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + + # ignore hard gt + valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 + best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] + if best_prior_idx_filter.shape[0] <= 0: + loc_t[idx] = 0 + conf_t[idx] = 0 + return + + # [1,num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_idx_filter.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes + best_truth_idx[best_prior_idx[j]] = j + matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来 + conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来 + conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本 + loc = encode(matches, priors, variances) + + matches_landm = landms[best_truth_idx] + landm = encode_landm(matches_landm, priors, variances) + loc_t[idx] = loc # [num_priors,4] encoded offsets to learn + conf_t[idx] = conf # [num_priors] top class label for each prior + landm_t[idx] = landm + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + +def encode_landm(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 10]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded landm (tensor), Shape: [num_priors, 10] + """ + + # dist b/t match center and prior's center + matched = torch.reshape(matched, (matched.size(0), 5, 2)) + priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2) + g_cxcy = matched[:, :, :2] - priors[:, :, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, :, 2:]) + # g_cxcy /= priors[:, :, 2:] + g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1) + # return target for smooth_l1_loss + return g_cxcy + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_np(loc, priors, variances): + boxes = np.concatenate(( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * np.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + landms = torch.cat((priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], + ), dim=1) + return landms + + +def decode_landm_np(pre, priors, variances): + landms = np.concatenate((priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], + ), axis=1) + return landms + + +def decode_batch(loc, priors, variances=[0.1, 0.2]): + priors = priors.expand(loc.shape[0], loc.shape[1], 4) + boxes = torch.cat(( + priors[..., :2] + loc[..., :2] * variances[0] * priors[..., 2:], + priors[..., 2:] * torch.exp(loc[..., 2:] * variances[1])), 2) + + # boxes = torch.cat(( + # boxes[..., :2] - boxes[..., 2:] / 2, + # boxes[..., 2:]), 2 + # ) + # boxes = torch.cat(( + # boxes[..., :2], + # boxes[..., 2:] + boxes[..., :2]), 2 + # ) + boxes[..., :2] -= boxes[..., 2:] / 2 + boxes[..., 2:] += boxes[..., :2] + return boxes + +def decode_landm_batch(pre, priors, variances=[0.1, 0.2]): + priors = priors.expand(pre.shape[0], pre.shape[1], 4) + landms = torch.cat((priors[..., :2] + pre[..., :2] * variances[0] * priors[..., 2:], + priors[..., :2] + pre[..., 2:4] * variances[0] * priors[..., 2:], + priors[..., :2] + pre[..., 4:6] * variances[0] * priors[..., 2:], + priors[..., :2] + pre[..., 6:8] * variances[0] * priors[..., 2:], + priors[..., :2] + pre[..., 8:10] * variances[0] * priors[..., 2:], + ), dim=2) + return landms + + +def log_sum_exp(x): + """Utility function for computing log_sum_exp while determining + This will be used to determine unaveraged confidence loss across + all examples in a batch. + Args: + x (Variable(tensor)): conf_preds from conf layers + """ + x_max = x.data.max() + return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max + + +# Original author: Francisco Massa: +# https://github.com/fmassa/object-detection.torch +# Ported to PyTorch by Max deGroot (02/01/2017) +def nms(boxes, scores, overlap=0.5, top_k=200): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + overlap: (float) The overlap thresh for suppressing unnecessary boxes. + top_k: (int) The Maximum number of box preds to consider. + Return: + The indices of the kept boxes with respect to num_priors. + """ + + keep = torch.Tensor(scores.size(0)).fill_(0).long() + if boxes.numel() == 0: + return keep + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + area = torch.mul(x2 - x1, y2 - y1) + v, idx = scores.sort(0) # sort in ascending order + # I = I[v >= 0.01] + idx = idx[-top_k:] # indices of the top-k largest vals + xx1 = boxes.new() + yy1 = boxes.new() + xx2 = boxes.new() + yy2 = boxes.new() + w = boxes.new() + h = boxes.new() + + # keep = torch.Tensor() + count = 0 + while idx.numel() > 0: + i = idx[-1] # index of current largest val + # keep.append(i) + keep[count] = i + count += 1 + if idx.size(0) == 1: + break + idx = idx[:-1] # remove kept element from view + # load bboxes of next highest vals + torch.index_select(x1, 0, idx, out=xx1) + torch.index_select(y1, 0, idx, out=yy1) + torch.index_select(x2, 0, idx, out=xx2) + torch.index_select(y2, 0, idx, out=yy2) + # store element-wise max with next highest score + xx1 = torch.clamp(xx1, min=x1[i]) + yy1 = torch.clamp(yy1, min=y1[i]) + xx2 = torch.clamp(xx2, max=x2[i]) + yy2 = torch.clamp(yy2, max=y2[i]) + w.resize_as_(xx2) + h.resize_as_(yy2) + w = xx2 - xx1 + h = yy2 - yy1 + # check sizes of xx1 and xx2.. after each iteration + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + inter = w*h + # IoU = i / (area(a) + area(b) - i) + rem_areas = torch.index_select(area, 0, idx) # load remaining areas) + union = (rem_areas - inter) + area[i] + IoU = inter/union # store result in iou + # keep only elements with an IoU <= overlap + idx = idx[IoU.le(overlap)] + return keep, count + + diff --git a/retinaface/utils/fps.py b/retinaface/utils/fps.py new file mode 100644 index 0000000000000000000000000000000000000000..b48b37e4ccb8cd9883fdb8972cb123948642c901 --- /dev/null +++ b/retinaface/utils/fps.py @@ -0,0 +1,46 @@ +# import the necessary packages +import datetime + +class FPS: + def __init__(self, nframes=20): + # store the start time, end time, and total number of frames + # that were examined between the start and end intervals + self._start = None + self._end = None + self._numFrames = 0 + + self._nframes = nframes + self._start_n = None + self.fps_n = 0.0 + + def start(self): + # start the timer + self._start = datetime.datetime.now() + return self + + def stop(self): + # stop the timer + self._end = datetime.datetime.now() + + def update(self): + # increment the total number of frames examined during the + # start and end intervals + if (self._numFrames) % self._nframes == 0: + self._start_n = datetime.datetime.now() + + self._numFrames += 1 + + def elapsed(self): + # return the total number of seconds between the start and + # end interval + return (self._end - self._start).total_seconds() + + def fps(self): + # compute the (approximate) frames per second + return self._numFrames / self.elapsed() + + def get_fps_n(self): + # run after update + if (self._numFrames) % self._nframes == 0: + self.fps_n = self._nframes / ((datetime.datetime.now() - self._start_n).total_seconds()) + return self.fps_n \ No newline at end of file diff --git a/retinaface/utils/infer_utils.py b/retinaface/utils/infer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0f15c78ac61afb989cc1f0b25fdb1ddb97738098 --- /dev/null +++ b/retinaface/utils/infer_utils.py @@ -0,0 +1,148 @@ +import numpy as np +import cv2 +# import tensorflow as tf +# from tensorflow_serving.apis.predict_pb2 import PredictRequest +# import grpc +# from tensorflow_serving.apis import prediction_service_pb2_grpc +# import time +# from tensorflow.keras.models import model_from_json +from imutils.paths import list_images +import os +import pickle +import torch +from skimage import transform +import cv2 +from pathlib import Path +from imutils.video import VideoStream, FileVideoStream + + +def check_keys(model, pretrained_state_dict): + ckpt_keys = set(pretrained_state_dict.keys()) + model_keys = set(model.state_dict().keys()) + used_pretrained_keys = model_keys & ckpt_keys + unused_pretrained_keys = ckpt_keys - model_keys + missing_keys = model_keys - ckpt_keys + print('Missing keys:{}'.format(len(missing_keys))) + print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys))) + print('Used keys:{}'.format(len(used_pretrained_keys))) + assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' + return True + + +def remove_prefix(state_dict, prefix): + ''' Old style model is stored with all names of parameters sharing common prefix 'module.' ''' + print('remove prefix \'{}\''.format(prefix)) + f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x + return {f(key): value for key, value in state_dict.items()} + + +def load_model(model, pretrained_path, load_to_cpu): + print('Loading pretrained model from {}'.format(pretrained_path)) + if load_to_cpu: + pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage) + else: + device = torch.cuda.current_device() + pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device)) + if "state_dict" in pretrained_dict.keys(): + pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.') + else: + pretrained_dict = remove_prefix(pretrained_dict, 'module.') + check_keys(model, pretrained_dict) + model.load_state_dict(pretrained_dict, strict=False) + return model + + +def judge_side_face(facial_landmarks): + wide_dist = np.linalg.norm(facial_landmarks[0] - facial_landmarks[1]) + high_dist = np.linalg.norm(facial_landmarks[0] - facial_landmarks[3]) + dist_rate = high_dist / wide_dist + + # cal std + vec_A = facial_landmarks[0] - facial_landmarks[2] + vec_B = facial_landmarks[1] - facial_landmarks[2] + vec_C = facial_landmarks[3] - facial_landmarks[2] + vec_D = facial_landmarks[4] - facial_landmarks[2] + dist_A = np.linalg.norm(vec_A) + dist_B = np.linalg.norm(vec_B) + dist_C = np.linalg.norm(vec_C) + dist_D = np.linalg.norm(vec_D) + + # cal rate + high_rate = dist_A / dist_C + width_rate = dist_C / dist_D + high_ratio_variance = np.fabs(high_rate - 1.1) # smaller is better + width_ratio_variance = np.fabs(width_rate - 1) + + if dist_rate < 1.3 and width_ratio_variance < 0.8: + return True, dist_rate, high_ratio_variance + return False, dist_rate, high_ratio_variance + + +def align_face(cv_img, dst): + """align face theo widerface + + Arguments: + cv_img {arr} -- Ảnh gốc + dst {arr}} -- landmark 5 điểm theo mtcnn + + Returns: + arr -- Ảnh face đã align + """ + face_img = np.zeros((112,112), dtype=np.uint8) + # Matrix standard lanmark same wider dataset + src = np.array([ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [41.5493, 92.3655], + [70.7299, 92.2041] ], dtype=np.float32) + + tform = transform.SimilarityTransform() + tform.estimate(dst, src) + M = tform.params[0:2,:] + face_img = cv2.warpAffine(cv_img, M, (112,112), borderValue=0.0) + return face_img + + +class LoadImages: + def __init__(self, dir): + self.image_paths = list(Path(dir).glob("*.jpg")) + self.index = -1 + + def __iter__(self): + return self + + def __next__(self): + self.index += 1 + if self.index >= len(self.image_paths): + self.index = -1 + raise StopIteration + else: + image_path = str(self.image_paths[self.index]) + image = cv2.imread(str(image_path)) + return image + + def close(self): + pass + + +class LoadStream: + def __init__(self, path): + self.isfile = True + if "rtsp" in path: + self.stream = VideoStream(path).start() + self.isfile = False + else: + self.stream = FileVideoStream(path).start() + + def __iter__(self): + return self + + def __next__(self): + image = self.stream.read() + if self.isfile and image is None: + raise StopIteration + return image + + def close(self): + self.stream.stop() diff --git a/retinaface/utils/nms/__init__.py b/retinaface/utils/nms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/retinaface/utils/nms/py_cpu_nms.py b/retinaface/utils/nms/py_cpu_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..54e7b25fef72b518df6dcf8d6fb78b986796c6e3 --- /dev/null +++ b/retinaface/utils/nms/py_cpu_nms.py @@ -0,0 +1,38 @@ +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- + +import numpy as np + +def py_cpu_nms(dets, thresh): + """Pure Python NMS baseline.""" + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep diff --git a/retinaface/utils/timer.py b/retinaface/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b3b8098a5ad41f8d18d42b6b2fedb694aa5508 --- /dev/null +++ b/retinaface/utils/timer.py @@ -0,0 +1,40 @@ +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- + +import time + + +class Timer(object): + """A simple timer.""" + def __init__(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. + + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + self.total_time += self.diff + self.calls += 1 + self.average_time = self.total_time / self.calls + if average: + return self.average_time + else: + return self.diff + + def clear(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. diff --git a/retinaface/weights/RBF_Final.pth b/retinaface/weights/RBF_Final.pth new file mode 100644 index 0000000000000000000000000000000000000000..5f5e03c1d4b777defe0a6314075df27d12c30bea --- /dev/null +++ b/retinaface/weights/RBF_Final.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b22c45855bff5c4eb84e9973e2ee8a980639aa194ec66f33a87ce712f9d44c28 +size 1504452 diff --git a/retinaface/weights/mobilenet0.25_Final.pth b/retinaface/weights/mobilenet0.25_Final.pth new file mode 100644 index 0000000000000000000000000000000000000000..36079ee388ef1adfa7a99b06142228d4773dd55a --- /dev/null +++ b/retinaface/weights/mobilenet0.25_Final.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed6fa7d45a5179f03cb5ce82c73b0c6711372968e53782e88395acee88d319e1 +size 1789735 diff --git a/retinaface/weights/mobilenet0.25_epoch_842.pth b/retinaface/weights/mobilenet0.25_epoch_842.pth new file mode 100644 index 0000000000000000000000000000000000000000..f4f4b66775704b99eebd2c055a0aa351f70b3c9e --- /dev/null +++ b/retinaface/weights/mobilenet0.25_epoch_842.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79208aec7020085085f364cd0a833f08b4083e635c2f71bea80dfe3523c4439e +size 1841787 diff --git a/retinaface/weights/slim_Final.pth b/retinaface/weights/slim_Final.pth new file mode 100644 index 0000000000000000000000000000000000000000..f0e23bab25f54a561d58b28f71f0adc0147bf029 --- /dev/null +++ b/retinaface/weights/slim_Final.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76cdb587b2cf5f91b3836b66a652a0f7a91891758cd4fb0b07e4b7527af1c651 +size 1427150 diff --git a/turn.py b/turn.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7e8a56988cd8962a7c3591a5ca8c559be7d149 --- /dev/null +++ b/turn.py @@ -0,0 +1,34 @@ +# Copied from streamlit-webrtc/sample_utils/turn.py +import logging +import os + +import streamlit as st +from twilio.rest import Client + +logger = logging.getLogger(__name__) + + +@st.cache_data # type: ignore +def get_ice_servers(): + """Use Twilio's TURN server because Streamlit Community Cloud has changed + its infrastructure and WebRTC connection cannot be established without TURN server now. # noqa: E501 + We considered Open Relay Project (https://www.metered.ca/tools/openrelay/) too, + but it is not stable and hardly works as some people reported like https://github.com/aiortc/aiortc/issues/832#issuecomment-1482420656 # noqa: E501 + See https://github.com/whitphx/streamlit-webrtc/issues/1213 + """ + + # Ref: https://www.twilio.com/docs/stun-turn/api + try: + account_sid = os.environ["TWILIO_ACCOUNT_SID"] + auth_token = os.environ["TWILIO_AUTH_TOKEN"] + except KeyError: + logger.warning( + "Twilio credentials are not set. Fallback to a free STUN server from Google." # noqa: E501 + ) + return [{"urls": ["stun:stun.l.google.com:19302"]}] + + client = Client(account_sid, auth_token) + + token = client.tokens.create() + + return token.ice_servers